forked from tangger/lerobot
Add observation queue to ACT + refactor into _queues
This commit is contained in:
@@ -99,6 +99,10 @@ def load_from_raw(raw_dir: Path, out_dir: Path, fps: int):
|
|||||||
|
|
||||||
del df["timestamp_utc"]
|
del df["timestamp_utc"]
|
||||||
|
|
||||||
|
# sanity check
|
||||||
|
has_nan = df.isna().any().any()
|
||||||
|
assert not has_nan
|
||||||
|
|
||||||
# sanity check episode indices go from 0 to n-1
|
# sanity check episode indices go from 0 to n-1
|
||||||
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
|
||||||
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
expected_ep_ids = list(range(df["episode_index"].max() + 1))
|
||||||
|
|||||||
@@ -36,6 +36,9 @@ from torchvision.ops.misc import FrozenBatchNorm2d
|
|||||||
|
|
||||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||||
|
from lerobot.common.policies.utils import (
|
||||||
|
populate_queues,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||||
@@ -73,6 +76,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||||
|
self._queues = None
|
||||||
|
|
||||||
self.model = ACT(config)
|
self.model = ACT(config)
|
||||||
|
|
||||||
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||||
@@ -84,7 +90,11 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_momentum is not None:
|
||||||
self._ensembled_actions = None
|
self._ensembled_actions = None
|
||||||
else:
|
else:
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._queues = {
|
||||||
|
"observation.images": deque(maxlen=self.config.n_obs_steps),
|
||||||
|
"observation.state": deque(maxlen=self.config.n_obs_steps),
|
||||||
|
"action": deque(maxlen=self.config.n_action_steps),
|
||||||
|
}
|
||||||
|
|
||||||
@torch.no_grad
|
@torch.no_grad
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
@@ -98,10 +108,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
|
|
||||||
batch = self.normalize_inputs(batch)
|
batch = self.normalize_inputs(batch)
|
||||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||||
|
# Note: It's important that this happens after stacking the images into a single key.
|
||||||
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||||
# the first action.
|
# the first action.
|
||||||
if self.config.temporal_ensemble_momentum is not None:
|
if self.config.temporal_ensemble_momentum is not None:
|
||||||
|
# stack n latest observations from the queue
|
||||||
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
|
|
||||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||||
if self._ensembled_actions is None:
|
if self._ensembled_actions is None:
|
||||||
@@ -121,7 +136,10 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
|
|
||||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||||
# querying the policy.
|
# querying the policy.
|
||||||
if len(self._action_queue) == 0:
|
if len(self._queues["action"]) == 0:
|
||||||
|
# stack n latest observations from the queue
|
||||||
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
|
|
||||||
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||||
|
|
||||||
# TODO(rcadene): make _forward return output dictionary?
|
# TODO(rcadene): make _forward return output dictionary?
|
||||||
@@ -129,8 +147,8 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
|||||||
|
|
||||||
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||||
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
|
||||||
self._action_queue.extend(actions.transpose(0, 1))
|
self._queues["action"].extend(actions.transpose(0, 1))
|
||||||
return self._action_queue.popleft()
|
return self._queues["action"].popleft()
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
|
|||||||
Reference in New Issue
Block a user