From 8c1dd0e263617548f1d0785b22e16d497832f431 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 29 May 2024 10:21:50 +0000 Subject: [PATCH] Add observation queue to ACT + refactor into _queues --- .../push_dataset_to_hub/aloha_dora_format.py | 4 +++ lerobot/common/policies/act/modeling_act.py | 26 ++++++++++++++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py index b897bd13..b10e7b04 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_dora_format.py @@ -99,6 +99,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int): del df["timestamp_utc"] + # sanity check + has_nan = df.isna().any().any() + assert not has_nan + # sanity check episode indices go from 0 to n-1 ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")] expected_ep_ids = list(range(df["episode_index"].max() + 1)) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 0ba85c44..fed4fb67 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -36,6 +36,9 @@ from torchvision.ops.misc import FrozenBatchNorm2d from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.utils import ( + populate_queues, +) class ACTPolicy(nn.Module, PyTorchModelHubMixin): @@ -73,6 +76,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): config.output_shapes, config.output_normalization_modes, dataset_stats ) + # queues are populated during rollout of the policy, they contain the n latest observations and actions + self._queues = None + self.model = ACT(config) self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] @@ -84,7 +90,11 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): if self.config.temporal_ensemble_momentum is not None: self._ensembled_actions = None else: - self._action_queue = deque([], maxlen=self.config.n_action_steps) + self._queues = { + "observation.images": deque(maxlen=self.config.n_obs_steps), + "observation.state": deque(maxlen=self.config.n_obs_steps), + "action": deque(maxlen=self.config.n_action_steps), + } @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -98,10 +108,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): batch = self.normalize_inputs(batch) batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + # Note: It's important that this happens after stacking the images into a single key. + self._queues = populate_queues(self._queues, batch) # If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return # the first action. if self.config.temporal_ensemble_momentum is not None: + # stack n latest observations from the queue + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] if self._ensembled_actions is None: @@ -121,7 +136,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. - if len(self._action_queue) == 0: + if len(self._queues["action"]) == 0: + # stack n latest observations from the queue + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.model(batch)[0][:, : self.config.n_action_steps] # TODO(rcadene): make _forward return output dictionary? @@ -129,8 +147,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin): # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() + self._queues["action"].extend(actions.transpose(0, 1)) + return self._queues["action"].popleft() def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation."""