Add direct access to action chunks (#1020)

* fix: sharing predicted chunk with user

* [pre-commit.ci] pre-commit autoupdate (#1011)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Revert "[pre-commit.ci] pre-commit autoupdate" (#1025)

* fix(ci): Pin draccus (<0.10.0) and torch (<2.7) to fix pipeline (#1022)

Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* fix(ci): Pin `torchcodec` (==0.2.1) to fix pipeline temporarily (#1030)

* Update tutorial (#1021)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* Add description of motor order for SO-101 leader (#1051)

* feat(encoding): switching to PyAV for ffmpeg related tasks (#983)

* feat(docs): Add new docs build process (#1046)

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>

* Docs: adapt text + fix video code (#1064)

* Fix typos (#1070)

* docs: minor corrections and clean-up (#1089)

* Update 10_use_so100.md; use diff syntax (#944)

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Update 12_use_so101.md (#1081)

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>

* Bug fix for #1071: running control_robot failed when --display_data=true (#1073)

* Add editable -e for feetech install command (#1133)

* Fix: emptying action queue between resets (#1117)

* fix: typos and grammar (#1148)

* Update README.md (#1160)

* Update README.md (#1163)

* [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127)

* (hotfix): fix nightly CI by capping pymunk version below 7.0.0 (#1182)

* [pre-commit.ci] pre-commit autoupdate (#1048)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>

* Add SmolVLA (#1175)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: fracapuano <francesco.capuano@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>

* Fix SmolVLA loss not sent to wandb (#1198)

* Hardware API redesign (#777)

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>

* fix(smolvla): update record.py, fix populate_queues and remove unused dependencies (#1208)

* replaced OBS_ROBOT with OBS_STATE constant (#1211)

* Fix test_teleoperate (#1216)

* Fix LeKiwi example (#1217)

* Fix smolVLA dependencies (#1218)

* fix(pyserial): adding pyserial dependency to global ones (#1219)

* Update SmolVLA README.md (#1228)

* Fix unable to set camera width/height to non-default (#1225)

* Update tutorial link (#1250)

* update KochFollower.get_observation() so it returns the same observation structure as SO101 (#1248)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] pre-commit autoupdate (#1185)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* Proposal for fix for enter_pressed on Windows (#1230)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

* fix: update pi0 dependency version constraint (#1247)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Match motor names with ids lekiwi (#1261)

* fix issues: checkpoint keys mismatch and 'task' tokenisation in smolvla (#1256)

Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>

* fix(docs): update realsense documentation (#1268)

* Use HF Papers (#1120)

* Skip normalization parameters in load_smolvla (#1274)

* fix(record): no teleop needed when running with policy (#1284)

* Port HIL SERL (#644)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>

* fix(docs): SmolVLA fine-tuning getting started (#1201)

Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Francesco Capuano <francesco_capuano@aol.com>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>

* chore(teleop): print calibration path saved (#1286)

* chore(dependencies): add gamepad support with pygame and hidapi (#1287)

* Robot integration tutorial (#1285)

* fix(docs): update send_feedback docstrings

* Add sim tutorial, fix lekiwi motor config, add notebook links (#1275)

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>

* Fixes on robot integration tutorial (#1290)

* Add keyboard teleop device to control the end effector robot (#1289)

* Improve type hints (#1293)

* fix(record): no teleop arg in reset environment (#1294)

* `learner.py` import so101_leader instead of so100 (#1295)

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>

* Fixing `PI0` Policy (#1297)

* `gym_manipulator.py` Remove None value action_intervention of BaseLeaderTeleoperator (#1299)

* (chore): incorrect resume parameter in recording documentation (#1301)

* Update lekiwi.mdx (#1229)

* bump `pi0` and `hil` transformers version (#1298)

* docs: fix imitation learning robots docs command (#1308)

* fix(benchmarks): remove .numpy() from frame in benchmark script (#1354)

* add smolvla to the supported policies to run tests (:

* add: chunk-level access for the policy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add: smolvla in available policies

* remove: smolvla from library supported policies

* fix: change env for training, xarm is broken as of now

* add: predict_action_chunk to all supported policies

* fix: add robot type constants

* add: predict action chunk in base policy class

* restore original Makefile

* fix: minor

* fix: dict keys come from lerobot/constants

* fix: improve act encapsulation, properly supporting temporal ensembling

* fix: smolvla action chunking

* fix: very minor, but very annoying

* fix: minor

* fix minor naming

Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

* fix: refactoring inference for single actions and chunks into different components

* fix: minor

* fix: temporal ensembling

* fix: moving populate queues out of modular component for batch preparation

* fix: minor for CI

* fix: smolvla debug

* fix: reward classifier, maybe the last policy lacking?

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: omahs <73983677+omahs@users.noreply.github.com>
Co-authored-by: CharlesCNorton <135471798+CharlesCNorton@users.noreply.github.com>
Co-authored-by: masato-ka <jp6uzv@gmail.com>
Co-authored-by: Ragnar <rodiondenmark@gmail.com>
Co-authored-by: mshukor <mustafa.shukor97@gmail.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Ben Zhang <5977478+ben-z@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Dhruva <51377003+utterwqlnut@users.noreply.github.com>
Co-authored-by: Daisuke Sato <tiryoh@gmail.com>
Co-authored-by: Sarunas Kalade <sarunas.kalade@amd.com>
Co-authored-by: koenvanwijk <koenvanwijk@users.noreply.github.com>
Co-authored-by: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com>
Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
Co-authored-by: tidely <43219534+tidely@users.noreply.github.com>
Co-authored-by: David <17435126+DavidLMS@users.noreply.github.com>
Authored by Francesco Capuano on 2025-06-27 10:19:19 +02:00, committed by GitHub
parent 0b2285d1ec
commit f3d931e1b2
11 changed files with 176 additions and 109 deletions
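
Taken together, the diff below adds an abstract `predict_action_chunk` method to `PreTrainedPolicy` and implements it for ACT, Diffusion, SmolVLA, TDMPC, and VQ-BeT (PI0, PI0FAST, SAC, and the reward classifier raise `NotImplementedError`), while `select_action` keeps returning single actions from an internal queue. A minimal usage sketch, assuming an ACT checkpoint and ALOHA-style observation keys (both are placeholders, not fixed by this diff):

```python
import torch

from lerobot.common.policies.act.modeling_act import ACTPolicy

# Placeholder checkpoint id; any ACT checkpoint with matching features would do.
policy = ACTPolicy.from_pretrained("lerobot/act_aloha_sim_transfer_cube_human")
policy.eval()

# Placeholder observation: keys and shapes must match the checkpoint's input features.
observation = {
    "observation.state": torch.zeros(1, 14),
    "observation.images.top": torch.zeros(1, 3, 480, 640),
}

with torch.no_grad():
    # New in this PR: the whole chunk, shape (batch_size, chunk_size, action_dim).
    chunk = policy.predict_action_chunk(observation)
    # Unchanged behaviour: one action at a time, served from the internal queue.
    action = policy.select_action(observation)
```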

View File

@@ -25,6 +25,7 @@ ACTION = "action"
REWARD = "next.reward"
ROBOTS = "robots"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
# files & directories

View File

@@ -33,6 +33,7 @@ from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.constants import ACTION, OBS_IMAGES
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -114,46 +115,49 @@ class ACTPolicy(PreTrainedPolicy):
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
actions = self.predict_action_chunk(batch)
action = self.temporal_ensembler.update(actions)
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
actions = self.model(batch)[0][:, : self.config.n_action_steps]
# TODO(rcadene): make _forward return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
self.eval()
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
actions = self.model(batch)[0]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
l1_loss = (
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
).mean()
loss_dict = {"l1_loss": l1_loss.item()}

View File

@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -99,6 +99,18 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -124,23 +136,15 @@ class DiffusionPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues["action"]) == 0:
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch)
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1))
# TODO(rcadene): make above methods return output dictionary?
actions = self.unnormalize_outputs({"action": actions})["action"]
self._queues["action"].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
@@ -148,9 +152,7 @@ class DiffusionPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None

View File

@@ -260,6 +260,11 @@ class PI0Policy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -192,6 +192,11 @@ class PI0FASTPolicy(PreTrainedPolicy):
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0FAST")
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.

View File

@@ -171,6 +171,15 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
raise NotImplementedError
@abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk
cached for selection.
"""
raise NotImplementedError
@abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).

View File

@@ -76,6 +76,11 @@ class SACPolicy(
"""Reset the policy"""
pass
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""

View File

@@ -308,6 +308,13 @@ class Classifier(PreTrainedPolicy):
"""
raise NotImplementedError("Reward classifiers do not select actions")
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.
The reward classifier is not an actor and does not produce action chunks.
"""
raise NotImplementedError("Reward classifiers do not predict action chunks")
def reset(self):
"""
This method is required by PreTrainedPolicy but not used for reward classifiers.

View File

@@ -383,6 +383,45 @@ class SmolVLAPolicy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
return actions
def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
return batch
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
self.eval()
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
actions = self._get_action_chunk(batch, noise)
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
@@ -392,38 +431,18 @@ class SmolVLAPolicy(PreTrainedPolicy):
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch = self.normalize_inputs(batch)
batch = self._prepare_batch(batch)
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._queues[ACTION]) == 0:
for k in batch:
if k in self._queues:
batch[k] = torch.stack(list(self._queues[k]), dim=1)
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens, lang_masks = self.prepare_language(batch)
actions = self._get_action_chunk(batch, noise)
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({"action": actions})["action"]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
return self._queues[ACTION].popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:

View File

@@ -35,7 +35,7 @@ import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -110,52 +110,58 @@ class TDMPCPolicy(PreTrainedPolicy):
# CEM for the next step.
self._prev_mean: torch.Tensor | None = None
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
encode_keys = []
if self.config.image_features:
encode_keys.append(OBS_IMAGE)
if self.config.env_state_feature:
encode_keys.append(OBS_ENV_STATE)
encode_keys.append(OBS_STATE)
z = self.model.encode({k: batch[k] for k in encode_keys})
if self.config.use_mpc: # noqa: SIM108
actions = self.plan(z) # (horizon, batch, action_dim)
else:
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
# sequence dimension like in the MPC branch.
actions = self.model.pi(z).unsqueeze(0)
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
# When the action queue is depleted, populate it again by querying the policy.
if len(self._queues["action"]) == 0:
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
# Remove the time dimensions as it is not handled yet.
for key in batch:
assert batch[key].shape[1] == 1
batch[key] = batch[key][:, 0]
# NOTE: Order of observations matters here.
encode_keys = []
if self.config.image_features:
encode_keys.append("observation.image")
if self.config.env_state_feature:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
if self.config.use_mpc: # noqa: SIM108
actions = self.plan(z) # (horizon, batch, action_dim)
else:
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
# sequence dimension like in the MPC branch.
actions = self.model.pi(z).unsqueeze(0)
actions = torch.clamp(actions, -1, +1)
actions = self.unnormalize_outputs({"action": actions})["action"]
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
if self.config.n_action_repeats > 1:
for _ in range(self.config.n_action_repeats):
self._queues["action"].append(actions[0])
self._queues[ACTION].append(actions[0])
else:
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
self._queues["action"].extend(actions[: self.config.n_action_steps])
self._queues[ACTION].extend(actions[: self.config.n_action_steps])
action = self._queues["action"].popleft()
action = self._queues[ACTION].popleft()
return action
@torch.no_grad()
@@ -312,7 +318,7 @@ class TDMPCPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
@@ -322,15 +328,15 @@ class TDMPCPolicy(PreTrainedPolicy):
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
action = batch["action"] # (t, b, action_dim)
reward = batch["next.reward"] # (t, b)
action = batch[ACTION] # (t, b, action_dim)
reward = batch[REWARD] # (t, b)
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
observations[OBS_IMAGE] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
observations[OBS_IMAGE],
)
# Get the current observation for predicting trajectories, and all future observations for use in
@@ -340,7 +346,7 @@ class TDMPCPolicy(PreTrainedPolicy):
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self.config.image_features else "observation.environment_state"
OBS_IMAGE if self.config.image_features else OBS_ENV_STATE
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.

View File

@@ -27,6 +27,7 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from torch import Tensor, nn
from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_STATE
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -118,11 +119,18 @@ class VQBeTPolicy(PreTrainedPolicy):
queues are populated during rollout of the policy, they contain the n latest observations and actions
"""
self._queues = {
"observation.images": deque(maxlen=self.config.n_obs_steps),
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.action_chunk_size),
OBS_IMAGES: deque(maxlen=self.config.n_obs_steps),
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
ACTION: deque(maxlen=self.config.action_chunk_size),
}
@torch.no_grad
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
return actions
@torch.no_grad
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
@@ -144,23 +152,19 @@ class VQBeTPolicy(PreTrainedPolicy):
stacklevel=1,
)
if len(self._queues["action"]) == 0:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
# the dimension of returned action is (batch_size, action_chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
# since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
self._queues["action"].extend(actions.transpose(0, 1))
self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues["action"].popleft()
action = self._queues[ACTION].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -168,7 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
# n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
# n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION])
)
return loss, {
"n_different_codes": n_different_codes,
@@ -404,7 +408,7 @@ class VQBeTModel(nn.Module):
)
# else, it calculate overall loss (bin prediction loss, and offset loss)
else:
output = batch["action"][:, self.select_target_actions_indices]
output = batch[ACTION][:, self.select_target_actions_indices]
loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
return action_head_output, loss