tests are passing for aloha/act policies, removes abstract policy

This commit is contained in:
Cadene
2024-04-09 03:28:56 +00:00
parent 73dfa3c8e3
commit 6902e01db0
6 changed files with 90 additions and 167 deletions

View File

@@ -150,6 +150,8 @@ class TDMPCPolicy(nn.Module):
t0 = step == 0
self.eval()
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
@@ -171,7 +173,7 @@ class TDMPCPolicy(nn.Module):
actions.append(action)
action = torch.stack(actions)
# self.act returns an action for 1 timestep only, so we copy it over `n_action_steps` time
# tdmpc returns an action for 1 timestep only, so we copy it over `n_action_steps` time
if i in range(self.n_action_steps):
self._queues["action"].append(action)