Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl

This commit is contained in:
Alexander Soare
2024-04-05 17:52:39 +01:00
49 changed files with 256 additions and 2563 deletions

View File

@@ -168,21 +168,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
nn.init.xavier_uniform_(p)
@torch.no_grad()
def select_actions(self, observation, step_count):
# TODO(rcadene): remove unused step_count
del step_count
def select_actions(self, batch, *_):
# TODO(now): Implement queueing mechanism.
self.eval()
self._preprocess_batch(batch)
# TODO(rcadene): remove hack
# add 1 camera dimension
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
obs_dict = {
"image": observation["image", "top"],
"agent_pos": observation["state"],
}
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
# TODO(now): What's up with this 0.182?
action = self.forward(
robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
)
if self.cfg.temporal_agg:
# TODO(rcadene): implement temporal aggregation
@@ -197,9 +191,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
# take first predicted action or n first actions
action = action[: self.n_action_steps]
return action
return action[: self.n_action_steps]
def __call__(self, *args, **kwargs):
# TODO(now): Temporary bridge.