forked from tangger/lerobot
Merge remote-tracking branch 'Cadene/user/rcadene/2024_03_31_remove_torchrl' into refactor_act_remove_torchrl
This commit is contained in:
@@ -168,21 +168,15 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_actions(self, observation, step_count):
|
||||
# TODO(rcadene): remove unused step_count
|
||||
del step_count
|
||||
|
||||
def select_actions(self, batch, *_):
|
||||
# TODO(now): Implement queueing mechanism.
|
||||
self.eval()
|
||||
self._preprocess_batch(batch)
|
||||
|
||||
# TODO(rcadene): remove hack
|
||||
# add 1 camera dimension
|
||||
observation["image", "top"] = observation["image", "top"].unsqueeze(1)
|
||||
|
||||
obs_dict = {
|
||||
"image": observation["image", "top"],
|
||||
"agent_pos": observation["state"],
|
||||
}
|
||||
action = self._forward(qpos=obs_dict["agent_pos"] * 0.182, image=obs_dict["image"])
|
||||
# TODO(now): What's up with this 0.182?
|
||||
action = self.forward(
|
||||
robot_state=batch["observation.state"] * 0.182, image=batch["observation.images.top"]
|
||||
)
|
||||
|
||||
if self.cfg.temporal_agg:
|
||||
# TODO(rcadene): implement temporal aggregation
|
||||
@@ -197,9 +191,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
# exp_weights = torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||
# raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
|
||||
|
||||
# take first predicted action or n first actions
|
||||
action = action[: self.n_action_steps]
|
||||
return action
|
||||
return action[: self.n_action_steps]
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO(now): Temporary bridge.
|
||||
|
||||
Reference in New Issue
Block a user