From f3d931e1b2e4b5b68a988c413fcc74825c02dd01 Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Fri, 27 Jun 2025 10:19:19 +0200 Subject: [PATCH] Add direct access to action chunks (#1020) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: sharing predicted chunk with user * [pre-commit.ci] pre-commit autoupdate (#1011) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Revert "[pre-commit.ci] pre-commit autoupdate" (#1025) * fix(ci): Pin draccus (<0.10.0) and torch (<2.7) to fix pipeline (#1022) Co-authored-by: imstevenpmwork Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * fix(ci): Pin `torchcodec` (==0.2.1) to fix pipeline temporarly (#1030) * Update tutorial (#1021) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * Add description motor order SO-101 leader (#1051) * feat(encoding): switching to PyAV for ffmpeg related tasks (#983) * feat(docs): Add new docs build process (#1046) Co-authored-by: Mishig Davaadorj Co-authored-by: Steven Palma * Docs: adapt text + fix video code (#1064) * Fix typos (#1070) * docs: minor corrections and clean-up (#1089) * Update 10_use_so100.md; use diff syntax (#944) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * Update 12_use_so101.md (#1081) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> * bug fix for #1071 When --display_data=true, Failed running control_robot. (#1073) * Add editable -e for feetech install command (#1133) * Fix: emptying action queue between resets (#1117) * fix: typos and grammar (#1148) * Update README.md (#1160) * Update README.md (#1163) * [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127) * (hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182) * [pre-commit.ci] pre-commit autoupdate (#1048) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert * Add SmolVLA (#1175) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: fracapuano Co-authored-by: Steven Palma Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com> Co-authored-by: Remi * Fix SmolVLA loss not sent to wandb (#1198) * Hardware API redesign (#777) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma Co-authored-by: Adil Zouitine Co-authored-by: Pepijn * fix(smolvla): update record.py, fix populate_queues and remove unused dependencies (#1208) * replaced OBS_ROBOT with OBS_STATE constant (#1211) * Fix test_teleoperate (#1216) * Fix LeKiwi example (#1217) * Fix smolVLA dependencies (#1218) * fix(pyserial): adding pyserial dependency to global ones (#1219) * Update SmolVLA README.md (#1228) * Fix unable to set camera width/height to non-default (#1225) * Update tutorial link (#1250) * update KochFollower.get_observation() so it returns same observation structure as SO101 (#1248) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] pre-commit autoupdate (#1185) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * Proposal for fix for enter_pressed on Windows (#1230) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> * fix: update pi0 dependency version constraint (#1247) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Match motor names with ids lekiwi (#1261) * fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256) Co-authored-by: danaaubakirova Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Simon Alibert * fix(docs): update realsense documentation (#1268) * Use HF Papers (#1120) * Skip normalization parameters in load_smolvla (#1274) * fix(record): no teleop needed when running with policy (#1284) * Port HIL SERL (#644) Co-authored-by: Michel Aractingi Co-authored-by: Eugene Mironov Co-authored-by: s1lent4gnt Co-authored-by: Ke Wang Co-authored-by: Yoel Chornton Co-authored-by: imstevenpmwork Co-authored-by: Simon Alibert * fix(docs): SmolVLA fine-tuning getting started (#1201) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: danaaubakirova Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Francesco Capuano Co-authored-by: Steven Palma * chore(teleop): print calibration path saved (#1286) * chore(dependencies): add gamepad support with pygame and hidapi (#1287) * Robot integration tutorial (#1285) * fix(docs): update send_feedback docstrings * Add sim tutorial, fix lekiwi motor config, add notebook links (#1275) Co-authored-by: AdilZouitine Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi Co-authored-by: s1lent4gnt Co-authored-by: Michel Aractingi Co-authored-by: Eugene Mironov Co-authored-by: imstevenpmwork Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Steven Palma * Fixes on robot integration tutorial (#1290) * Add keyboard teleop device to control the end effector robot (#1289) * Improve type hints (#1293) * fix(record): no teleop arg in reset environment (#1294) * `learner.py` import so101_leader instead of so100 (#1295) Co-authored-by: Adil Zouitine * Fixing `PI0` Policy (#1297) * `gym_manipulator.py` Remove None value action_intervention of BaseLeaderTeleoperator (#1299) * (chore): incorrect resume parameter in recording documentation (#1301) * Update lekiwi.mdx (#1229) * bump `pi0` and `hil` transformers version (#1298) * docs: fix imitation learning robots docs command (#1308) * fix(benchmarks): remove .numpy() from frame in benchmark script (#1354) * add smolvla to the supported policies to run tests (: * add: chunk-level access for the policy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add: smolvla in availables * remove: smolvla from library supported policies * fix: change env for training, xarm is broken as of now * add: predict_action_chunk to all supported policies * fix: add robot type constants * add: predict action chunk in base policy class * restore original Makefile * fix: minor * fix: dict keys come from lerobot/constants * fix: improve act encapsulation, properly supporting temporal ensembling * fix: smolvla action chunking * fix: very minor, but very annoying * fix: minor * fix minor naming Co-authored-by: Steven Palma Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> * fix: refactoring inference for single actions and chunks into different components * fix: minor * fix: temporal ensembling * fix: moving populate queues out of modular component for batch preparation * fix: minor for CI * fix: smovla debug * fix: reward classifier, maybe the last policy lacking? --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Adil Zouitine Co-authored-by: imstevenpmwork Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Caroline Pascal Co-authored-by: Mishig Davaadorj Co-authored-by: omahs <73983677+omahs@users.noreply.github.com> Co-authored-by: CharlesCNorton <135471798+CharlesCNorton@users.noreply.github.com> Co-authored-by: masato-ka Co-authored-by: Ragnar Co-authored-by: mshukor Co-authored-by: Simon Alibert Co-authored-by: Steven Palma Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com> Co-authored-by: Remi Co-authored-by: Ben Zhang <5977478+ben-z@users.noreply.github.com> Co-authored-by: Pepijn Co-authored-by: Dhruva <51377003+utterwqlnut@users.noreply.github.com> Co-authored-by: Daisuke Sato Co-authored-by: Sarunas Kalade Co-authored-by: koenvanwijk Co-authored-by: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com> Co-authored-by: danaaubakirova Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Michel Aractingi Co-authored-by: Eugene Mironov Co-authored-by: s1lent4gnt Co-authored-by: Ke Wang Co-authored-by: Yoel Chornton Co-authored-by: Michel Aractingi Co-authored-by: tidely <43219534+tidely@users.noreply.github.com> Co-authored-by: David <17435126+DavidLMS@users.noreply.github.com> --- lerobot/common/constants.py | 1 + lerobot/common/policies/act/modeling_act.py | 36 +++++---- .../policies/diffusion/modeling_diffusion.py | 36 +++++---- lerobot/common/policies/pi0/modeling_pi0.py | 5 ++ .../policies/pi0fast/modeling_pi0fast.py | 5 ++ lerobot/common/policies/pretrained.py | 9 +++ lerobot/common/policies/sac/modeling_sac.py | 5 ++ .../sac/reward_model/modeling_classifier.py | 7 ++ .../policies/smolvla/modeling_smolvla.py | 69 ++++++++++------ .../common/policies/tdmpc/modeling_tdmpc.py | 80 ++++++++++--------- .../common/policies/vqbet/modeling_vqbet.py | 32 ++++---- 11 files changed, 176 insertions(+), 109 deletions(-) diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index 990f2aa1..30777239 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -25,6 +25,7 @@ ACTION = "action" REWARD = "next.reward" ROBOTS = "robots" +ROBOT_TYPE = "robot_type" TELEOPERATORS = "teleoperators" # files & directories diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index e7e74bf3..12206657 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -33,6 +33,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d +from lerobot.common.constants import ACTION, OBS_IMAGES from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -114,46 +115,49 @@ class ACTPolicy(PreTrainedPolicy): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - self.eval() + self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed - batch = self.normalize_inputs(batch) - if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] - - # If we are doing temporal ensembling, do online updates where we keep track of the number of actions - # we are ensembling over. if self.config.temporal_ensemble_coeff is not None: - actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) - actions = self.unnormalize_outputs({"action": actions})["action"] + actions = self.predict_action_chunk(batch) action = self.temporal_ensembler.update(actions) return action # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: - actions = self.model(batch)[0][:, : self.config.n_action_steps] - - # TODO(rcadene): make _forward return output dictionary? - actions = self.unnormalize_outputs({"action": actions})["action"] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] + + actions = self.model(batch)[0] + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) ).mean() loss_dict = {"l1_loss": l1_loss.item()} diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 446e2cb6..038136d0 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor, nn -from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -99,6 +99,18 @@ class DiffusionPolicy(PreTrainedPolicy): if self.config.env_state_feature: self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + # stack n latest observations from the queue + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.diffusion.generate_actions(batch) + + # TODO(rcadene): make above methods return output dictionary? + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -124,23 +136,15 @@ class DiffusionPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) - if len(self._queues["action"]) == 0: - # stack n latest observations from the queue - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.diffusion.generate_actions(batch) + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) - # TODO(rcadene): make above methods return output dictionary? - actions = self.unnormalize_outputs({"action": actions})["action"] - - self._queues["action"].extend(actions.transpose(0, 1)) - - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: @@ -148,9 +152,7 @@ class DiffusionPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 1d8a5055..97e66a27 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -260,6 +260,11 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("Currently not implemented for PI0") + @torch.no_grad def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 7102bdde..dbf5266b 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -192,6 +192,11 @@ class PI0FASTPolicy(PreTrainedPolicy): actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) return actions + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("Currently not implemented for PI0FAST") + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index 58eef9ba..bc9276d0 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -171,6 +171,15 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ raise NotImplementedError + @abc.abstractmethod + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode. + + Child classes using action chunking should use this method within `select_action` to form the action chunk + cached for selection. + """ + raise NotImplementedError + @abc.abstractmethod def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py index b588115e..1ca46935 100644 --- a/lerobot/common/policies/sac/modeling_sac.py +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -76,6 +76,11 @@ class SACPolicy( """Reset the policy""" pass + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!") + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select action for inference/evaluation""" diff --git a/lerobot/common/policies/sac/reward_model/modeling_classifier.py b/lerobot/common/policies/sac/reward_model/modeling_classifier.py index f537e3ae..7fec67f1 100644 --- a/lerobot/common/policies/sac/reward_model/modeling_classifier.py +++ b/lerobot/common/policies/sac/reward_model/modeling_classifier.py @@ -308,6 +308,13 @@ class Classifier(PreTrainedPolicy): """ raise NotImplementedError("Reward classifiers do not select actions") + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not produce action chunks. + """ + raise NotImplementedError("Reward classifiers do not predict action chunks") + def reset(self): """ This method is required by PreTrainedPolicy but not used for reward classifiers. diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index 5e0a9622..36199984 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -383,6 +383,45 @@ class SmolVLAPolicy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) + + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + return actions + + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + batch = self.normalize_inputs(batch) + + return batch + + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + self.eval() + + batch = self._prepare_batch(batch) + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + actions = self._get_action_chunk(batch, noise) + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. @@ -392,38 +431,18 @@ class SmolVLAPolicy(PreTrainedPolicy): queue is empty. """ self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - + batch = self._prepare_batch(batch) self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._queues[ACTION]) == 0: - for k in batch: - if k in self._queues: - batch[k] = torch.stack(list(self._queues[k]), dim=1) - images, img_masks = self.prepare_images(batch) - state = self.prepare_state(batch) - lang_tokens, lang_masks = self.prepare_language(batch) + actions = self._get_action_chunk(batch, noise) - actions = self.model.sample_actions( - images, img_masks, lang_tokens, lang_masks, state, noise=noise - ) - # Unpad actions - original_action_dim = self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + return self._queues[ACTION].popleft() def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index 476e6dec..4bb564f8 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -35,7 +35,7 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig @@ -110,52 +110,58 @@ class TDMPCPolicy(PreTrainedPolicy): # CEM for the next step. self._prev_mean: torch.Tensor | None = None + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + encode_keys = [] + if self.config.image_features: + encode_keys.append(OBS_IMAGE) + if self.config.env_state_feature: + encode_keys.append(OBS_ENV_STATE) + encode_keys.append(OBS_STATE) + z = self.model.encode({k: batch[k] for k in encode_keys}) + if self.config.use_mpc: # noqa: SIM108 + actions = self.plan(z) # (horizon, batch, action_dim) + else: + # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a + # sequence dimension like in the MPC branch. + actions = self.model.pi(z).unsqueeze(0) + + actions = torch.clamp(actions, -1, +1) + + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[next(iter(self.config.image_features))] + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] self._queues = populate_queues(self._queues, batch) # When the action queue is depleted, populate it again by querying the policy. - if len(self._queues["action"]) == 0: - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} - - # Remove the time dimensions as it is not handled yet. - for key in batch: - assert batch[key].shape[1] == 1 - batch[key] = batch[key][:, 0] - - # NOTE: Order of observations matters here. - encode_keys = [] - if self.config.image_features: - encode_keys.append("observation.image") - if self.config.env_state_feature: - encode_keys.append("observation.environment_state") - encode_keys.append("observation.state") - z = self.model.encode({k: batch[k] for k in encode_keys}) - if self.config.use_mpc: # noqa: SIM108 - actions = self.plan(z) # (horizon, batch, action_dim) - else: - # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a - # sequence dimension like in the MPC branch. - actions = self.model.pi(z).unsqueeze(0) - - actions = torch.clamp(actions, -1, +1) - - actions = self.unnormalize_outputs({"action": actions})["action"] + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) if self.config.n_action_repeats > 1: for _ in range(self.config.n_action_repeats): - self._queues["action"].append(actions[0]) + self._queues[ACTION].append(actions[0]) else: # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action. - self._queues["action"].extend(actions[: self.config.n_action_steps]) + self._queues[ACTION].extend(actions[: self.config.n_action_steps]) - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action @torch.no_grad() @@ -312,7 +318,7 @@ class TDMPCPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[next(iter(self.config.image_features))] + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] batch = self.normalize_targets(batch) info = {} @@ -322,15 +328,15 @@ class TDMPCPolicy(PreTrainedPolicy): if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) - action = batch["action"] # (t, b, action_dim) - reward = batch["next.reward"] # (t, b) + action = batch[ACTION] # (t, b, action_dim) + reward = batch[REWARD] # (t, b) observations = {k: v for k, v in batch.items() if k.startswith("observation.")} # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: - observations["observation.image"] = flatten_forward_unflatten( + observations[OBS_IMAGE] = flatten_forward_unflatten( partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), - observations["observation.image"], + observations[OBS_IMAGE], ) # Get the current observation for predicting trajectories, and all future observations for use in @@ -340,7 +346,7 @@ class TDMPCPolicy(PreTrainedPolicy): current_observation[k] = observations[k][0] next_observations[k] = observations[k][1:] horizon, batch_size = next_observations[ - "observation.image" if self.config.image_features else "observation.environment_state" + OBS_IMAGE if self.config.image_features else OBS_ENV_STATE ].shape[:2] # Run latent rollout using the latent dynamics model and policy model. diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 44006a5b..a76bea2a 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -27,6 +27,7 @@ import torch.nn.functional as F # noqa: N812 import torchvision from torch import Tensor, nn +from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues @@ -118,11 +119,18 @@ class VQBeTPolicy(PreTrainedPolicy): queues are populated during rollout of the policy, they contain the n latest observations and actions """ self._queues = { - "observation.images": deque(maxlen=self.config.n_obs_steps), - "observation.state": deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.action_chunk_size), + OBS_IMAGES: deque(maxlen=self.config.n_obs_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), + ACTION: deque(maxlen=self.config.action_chunk_size), } + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -144,23 +152,19 @@ class VQBeTPolicy(PreTrainedPolicy): stacklevel=1, ) - if len(self._queues["action"]) == 0: - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] - - # the dimension of returned action is (batch_size, action_chunk_size, action_dim) - actions = self.unnormalize_outputs({"action": actions})["action"] + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue - self._queues["action"].extend(actions.transpose(0, 1)) + self._queues[ACTION].extend(actions.transpose(0, 1)) - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): @@ -168,7 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy): # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). loss, n_different_codes, n_different_combinations, recon_l1_error = ( - self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION]) ) return loss, { "n_different_codes": n_different_codes, @@ -404,7 +408,7 @@ class VQBeTModel(nn.Module): ) # else, it calculate overall loss (bin prediction loss, and offset loss) else: - output = batch["action"][:, self.select_target_actions_indices] + output = batch[ACTION][:, self.select_target_actions_indices] loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") return action_head_output, loss