Add observation queue to ACT + refactor into _queues
This commit is contained in:
@@ -99,6 +99,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
||||
|
||||
del df["timestamp_utc"]
|
||||
|
||||
# sanity check
|
||||
has_nan = df.isna().any().any()
|
||||
assert not has_nan
|
||||
|
||||
# sanity check: episode indices go from 0 to n-1
|
||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||
|
||||
@@ -36,6 +36,9 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
||||
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.utils import (
|
||||
populate_queues,
|
||||
)
|
||||
|
||||
|
||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
@@ -73,6 +76,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# Queues are populated during rollout of the policy; they contain the n latest observations and actions.
|
||||
self._queues = None
|
||||
|
||||
self.model = ACT(config)
|
||||
|
||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
@@ -84,7 +90,11 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
self._ensembled_actions = None
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
self._queues = {
|
||||
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
||||
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||
"action": deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -98,10 +108,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
# Note: It's important that this happens after stacking the images into a single key.
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
if self._ensembled_actions is None:
|
||||
@@ -121,7 +136,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
if len(self._queues["action"]) == 0:
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
|
||||
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||
|
||||
# TODO(rcadene): make _forward return output dictionary?
|
||||
@@ -129,8 +147,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
self._queues["action"].extend(actions.transpose(0, 1))
|
||||
return self._queues["action"].popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
|
||||
Reference in New Issue
Block a user