Add observation queue to ACT + refactor into _queues

This commit is contained in:
Remi Cadene
2024-05-29 10:21:50 +00:00
parent b24f2d0028
commit 8c1dd0e263
2 changed files with 26 additions and 4 deletions

View File

@@ -99,6 +99,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
del df["timestamp_utc"]
# sanity check
has_nan = df.isna().any().any()
assert not has_nan
# sanity check episode indices go from 0 to n-1
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
expected_ep_ids = list(range(df["episode_index"].max() + 1))

View File

@@ -36,6 +36,9 @@ from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import (
populate_queues,
)
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
@@ -73,6 +76,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
config.output_shapes, config.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
@@ -84,7 +90,11 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
if self.config.temporal_ensemble_momentum is not None:
self._ensembled_actions = None
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
self._queues = {
"observation.images": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
@@ -98,10 +108,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
batch = self.normalize_inputs(batch)
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
# the first action.
if self.config.temporal_ensemble_momentum is not None:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None:
@@ -121,7 +136,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.model(batch)[0][:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
@@ -129,8 +147,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
self._queues["action"].extend(actions.transpose(0, 1))
return self._queues["action"].popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""