Add direct access to action chunks (#1020)
* fix: sharing predicted chunk with user
* [pre-commit.ci] pre-commit autoupdate (#1011) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Revert "[pre-commit.ci] pre-commit autoupdate" (#1025)
* fix(ci): Pin draccus (<0.10.0) and torch (<2.7) to fix pipeline (#1022) Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
* fix(ci): Pin `torchcodec` (==0.2.1) to fix pipeline temporarly (#1030)
* Update tutorial (#1021) Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
* Add description motor order SO-101 leader (#1051)
* feat(encoding): switching to PyAV for ffmpeg related tasks (#983)
* feat(docs): Add new docs build process (#1046) Co-authored-by: Mishig Davaadorj <dmishig@gmail.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
* Docs: adapt text + fix video code (#1064)
* Fix typos (#1070)
* docs: minor corrections and clean-up (#1089)
* Update 10_use_so100.md; use diff syntax (#944) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
* Update 12_use_so101.md (#1081) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
* bug fix for #1071 When --display_data=true, Failed running control_robot. (#1073)
* Add editable -e for feetech install command (#1133)
* Fix: emptying action queue between resets (#1117)
* fix: typos and grammar (#1148)
* Update README.md (#1160)
* Update README.md (#1163)
* [Fix] Unpin torch beyond 2.6.0 & torchcodec beyond 0.2.1 (#1127)
* (hotfix): nightly CI by clipping pymunk version below 7.0.0 (#1182)
* [pre-commit.ci] pre-commit autoupdate (#1048) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
* Add SmolVLA (#1175) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: fracapuano <francesco.capuano@huggingface.co> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com> Co-authored-by: Remi <remi.cadene@huggingface.co>
* Fix SmolVLA loss not sent to wandb (#1198)
* Hardware API redesign (#777) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co> Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com> Co-authored-by: Pepijn <pepijn@huggingface.co>
* fix(smolvla): update record.py, fix populate_queues and remove unused dependencies (#1208)
* replaced OBS_ROBOT with OBS_STATE constant (#1211)
* Fix test_teleoperate (#1216)
* Fix LeKiwi example (#1217)
* Fix smolVLA dependencies (#1218)
* fix(pyserial): adding pyserial dependency to global ones (#1219)
* Update SmolVLA README.md (#1228)
* Fix unable to set camera width/height to non-default (#1225)
* Update tutorial link (#1250)
* update KochFollower.get_observation() so it returns same observation structure as SO101 (#1248) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [pre-commit.ci] pre-commit autoupdate (#1185) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
* Proposal for fix for enter_pressed on Windows (#1230) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
* fix: update pi0 dependency version constraint (#1247) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Match motor names with ids lekiwi (#1261)
* fix issues: checkpoints keys mismatch and 'task' tokenisation in smolvla (#1256) Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
* fix(docs): update realsense documentation (#1268)
* Use HF Papers (#1120)
* Skip normalization parameters in load_smolvla (#1274)
* fix(record): no teleop needed when running with policy (#1284)
* Port HIL SERL (#644) Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Ke Wang <superwk1017@gmail.com> Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
* fix(docs): SmolVLA fine-tuning getting started (#1201) Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Francesco Capuano <francesco_capuano@aol.com> Co-authored-by: Steven Palma <steven.palma@huggingface.co>
* chore(teleop): print calibration path saved (#1286)
* chore(dependencies): add gamepad support with pygame and hidapi (#1287)
* Robot integration tutorial (#1285)
* fix(docs): update send_feedback docstrings
* Add sim tutorial, fix lekiwi motor config, add notebook links (#1275) Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co> Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com> Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com> Co-authored-by: Eugene Mironov <helper2424@gmail.com> Co-authored-by: imstevenpmwork <steven.palma@huggingface.co> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
* Fixes on robot integration tutorial (#1290)
* Add keyboard teleop device to control the end effector robot (#1289)
* Improve type hints (#1293)
* fix(record): no teleop arg in reset environment (#1294)
* `learner.py` import so101_leader instead of so100 (#1295) Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
* Fixing `PI0` Policy (#1297)
* `gym_manipulator.py` Remove None value action_intervention of BaseLeaderTeleoperator (#1299)
* (chore): incorrect resume parameter in recording documentation (#1301)
* Update lekiwi.mdx (#1229)
* bump `pi0` and `hil` transformers version (#1298)
* docs: fix imitation learning robots docs command (#1308)
* fix(benchmarks): remove .numpy() from frame in benchmark script (#1354)
* add smolvla to the supported policies to run tests (:
* add: chunk-level access for the policy
* [pre-commit.ci] auto fixes from pre-commit.com hooks, for more information see https://pre-commit.ci
* add: smolvla in availables
* remove: smolvla from library supported policies
* fix: change env for training, xarm is broken as of now
* add: predict_action_chunk to all supported policies
* fix: add robot type constants
* add: predict action chunk in base policy class
* restore original Makefile
* fix: minor
* fix: dict keys come from lerobot/constants
* fix: improve act encapsulation, properly supporting temporal ensembling
* fix: smolvla action chunking
* fix: very minor, but very annoying
* fix: minor
* fix minor naming Co-authored-by: Steven Palma <imstevenpmwork@ieee.org> Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
* fix: refactoring inference for single actions and chunks into different components
* fix: minor
* fix: temporal ensembling
* fix: moving populate queues out of modular component for batch preparation
* fix: minor for CI
* fix: smovla debug
* fix: reward classifier, maybe the last policy lacking?

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: imstevenpmwork <steven.palma@huggingface.co>
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: omahs <73983677+omahs@users.noreply.github.com>
Co-authored-by: CharlesCNorton <135471798+CharlesCNorton@users.noreply.github.com>
Co-authored-by: masato-ka <jp6uzv@gmail.com>
Co-authored-by: Ragnar <rodiondenmark@gmail.com>
Co-authored-by: mshukor <mustafa.shukor97@gmail.com>
Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Dana Aubakirova <118912928+danaaubakirova@users.noreply.github.com>
Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: Ben Zhang <5977478+ben-z@users.noreply.github.com>
Co-authored-by: Pepijn <pepijn@huggingface.co>
Co-authored-by: Dhruva <51377003+utterwqlnut@users.noreply.github.com>
Co-authored-by: Daisuke Sato <tiryoh@gmail.com>
Co-authored-by: Sarunas Kalade <sarunas.kalade@amd.com>
Co-authored-by: koenvanwijk <koenvanwijk@users.noreply.github.com>
Co-authored-by: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com>
Co-authored-by: danaaubakirova <d.aubakirova@alumni.edu.kz>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Eugene Mironov <helper2424@gmail.com>
Co-authored-by: s1lent4gnt <kmeftah.khalil@gmail.com>
Co-authored-by: Ke Wang <superwk1017@gmail.com>
Co-authored-by: Yoel Chornton <yoel.chornton@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@gmail.com>
Co-authored-by: tidely <43219534+tidely@users.noreply.github.com>
Co-authored-by: David <17435126+DavidLMS@users.noreply.github.com>
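What this PR adds, in short: every policy now exposes a `predict_action_chunk` method next to the existing `select_action`, so callers can read the full predicted action chunk instead of popping one action at a time. The sketch below is a toy illustration of that contract (plain PyTorch, not LeRobot code; names such as ToyChunkPolicy are invented for this example):

    from collections import deque

    import torch
    from torch import Tensor, nn


    class ToyChunkPolicy(nn.Module):
        """Toy stand-in for an action-chunking policy, mirroring the API this PR standardizes."""

        def __init__(self, state_dim: int = 4, action_dim: int = 2, chunk_size: int = 5, n_action_steps: int = 3):
            super().__init__()
            self.net = nn.Linear(state_dim, chunk_size * action_dim)
            self.chunk_size = chunk_size
            self.action_dim = action_dim
            self.n_action_steps = n_action_steps
            self._action_queue = deque(maxlen=n_action_steps)

        @torch.no_grad()
        def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
            """Return the whole chunk, shape (batch_size, chunk_size, action_dim)."""
            out = self.net(batch["observation.state"])
            return out.view(-1, self.chunk_size, self.action_dim)

        @torch.no_grad()
        def select_action(self, batch: dict[str, Tensor]) -> Tensor:
            """Return one action per call, refilling the queue from the chunk when it runs dry."""
            if len(self._action_queue) == 0:
                chunk = self.predict_action_chunk(batch)[:, : self.n_action_steps]
                # the queue effectively holds (n_action_steps, batch_size, action_dim), hence the transpose
                self._action_queue.extend(chunk.transpose(0, 1))
            return self._action_queue.popleft()


    policy = ToyChunkPolicy()
    obs = {"observation.state": torch.randn(1, 4)}
    chunk = policy.predict_action_chunk(obs)   # direct access to the whole chunk
    action = policy.select_action(obs)         # single-step access backed by the same chunk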
commit f3d931e1b2 (parent 0b2285d1ec), committed via GitHub
@@ -25,6 +25,7 @@ ACTION = "action"
 REWARD = "next.reward"

 ROBOTS = "robots"
 ROBOT_TYPE = "robot_type"
 TELEOPERATORS = "teleoperators"

 # files & directories
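For orientation: these constants are plain string keys, and much of the diff below simply swaps hard-coded literals for them. A trivial illustrative snippet (the tensor is a placeholder):

    import torch

    from lerobot.common.constants import ACTION, OBS_IMAGES

    # ACTION == "action" and OBS_IMAGES == "observation.images", so the constants can be
    # used anywhere the old literal keys appeared.
    batch = {ACTION: torch.zeros(1, 10, 6)}  # instead of batch["action"]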
@@ -33,6 +33,7 @@ from torch import Tensor, nn
 from torchvision.models._utils import IntermediateLayerGetter
 from torchvision.ops.misc import FrozenBatchNorm2d

+from lerobot.common.constants import ACTION, OBS_IMAGES
 from lerobot.common.policies.act.configuration_act import ACTConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -114,46 +115,49 @@ class ACTPolicy(PreTrainedPolicy):
         environment. It works by managing the actions in a queue and only calling `select_actions` when the
         queue is empty.
         """
-        self.eval()
+        self.eval()  # keeping the policy in eval mode as it could be set to train mode while queue is consumed

-        batch = self.normalize_inputs(batch)
-        if self.config.image_features:
-            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
-
         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
         if self.config.temporal_ensemble_coeff is not None:
-            actions = self.model(batch)[0]  # (batch_size, chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)
             action = self.temporal_ensembler.update(actions)
             return action

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._action_queue) == 0:
-            actions = self.model(batch)[0][:, : self.config.n_action_steps]
-
-            # TODO(rcadene): make _forward return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+            actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps]

             # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._action_queue.extend(actions.transpose(0, 1))
         return self._action_queue.popleft()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        self.eval()
+
+        batch = self.normalize_inputs(batch)
+        if self.config.image_features:
+            batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]
+
+        actions = self.model(batch)[0]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+            batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features]

         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

         l1_loss = (
-            F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
+            F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
         ).mean()

         loss_dict = {"l1_loss": l1_loss.item()}
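With the refactor above, ACT's queue-based path is just a replay of the chunk returned by `predict_action_chunk`. Reusing the toy policy defined after the commit message, a quick check of that property (the same equivalence should hold for a deterministic ACT policy when temporal ensembling is disabled):

    policy = ToyChunkPolicy()
    obs = {"observation.state": torch.randn(1, 4)}

    chunk = policy.predict_action_chunk(obs)[:, : policy.n_action_steps]
    stepwise = torch.stack([policy.select_action(obs) for _ in range(policy.n_action_steps)], dim=1)
    assert torch.allclose(chunk, stepwise)  # the queue path replays the chunk step by step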
@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from torch import Tensor, nn

-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
@@ -99,6 +99,18 @@ class DiffusionPolicy(PreTrainedPolicy):
         if self.config.env_state_feature:
             self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        # stack n latest observations from the queue
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.diffusion.generate_actions(batch)
+
+        # TODO(rcadene): make above methods return output dictionary?
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -124,23 +136,15 @@ class DiffusionPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         # Note: It's important that this happens after stacking the images into a single key.
         self._queues = populate_queues(self._queues, batch)

-        if len(self._queues["action"]) == 0:
-            # stack n latest observations from the queue
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.diffusion.generate_actions(batch)
-
-            # TODO(rcadene): make above methods return output dictionary?
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            self._queues["action"].extend(actions.transpose(0, 1))
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
+            self._queues[ACTION].extend(actions.transpose(0, 1))

-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
@@ -148,9 +152,7 @@ class DiffusionPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         loss = self.diffusion.compute_loss(batch)
         # no output_dict so returning None
@@ -260,6 +260,11 @@ class PI0Policy(PreTrainedPolicy):
     def get_optim_params(self) -> dict:
         return self.parameters()

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
@@ -192,6 +192,11 @@ class PI0FASTPolicy(PreTrainedPolicy):
             actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
         return actions

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("Currently not implemented for PI0FAST")
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -171,6 +171,15 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
         """
         raise NotImplementedError

+    @abc.abstractmethod
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
+
+        Child classes using action chunking should use this method within `select_action` to form the action chunk
+        cached for selection.
+        """
+        raise NotImplementedError
+
     @abc.abstractmethod
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Return one action to run in the environment (potentially in batch mode).
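Because `predict_action_chunk` is now abstract on `PreTrainedPolicy`, every policy must provide it; the non-chunking policies in this PR (SAC, PI0, PI0FAST, the reward classifier) satisfy the contract by raising `NotImplementedError`, as the hunks around this one show. A hypothetical helper a caller might write to cope with both kinds of policy (the function name is invented for this sketch):

    from torch import Tensor


    def get_actions(policy, batch: dict[str, Tensor]) -> Tensor:
        """Prefer the chunk API; fall back to single-step selection for non-chunking policies."""
        try:
            return policy.predict_action_chunk(batch)        # (batch_size, chunk_size, action_dim)
        except NotImplementedError:
            return policy.select_action(batch).unsqueeze(1)  # (batch_size, 1, action_dim)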
@@ -76,6 +76,11 @@ class SACPolicy(
         """Reset the policy"""
         pass

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!")
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select action for inference/evaluation"""
@@ -308,6 +308,13 @@ class Classifier(PreTrainedPolicy):
         """
         raise NotImplementedError("Reward classifiers do not select actions")

+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """
+        This method is required by PreTrainedPolicy but not used for reward classifiers.
+        The reward classifier is not an actor and does not produce action chunks.
+        """
+        raise NotImplementedError("Reward classifiers do not predict action chunks")
+
     def reset(self):
         """
         This method is required by PreTrainedPolicy but not used for reward classifiers.
@@ -383,6 +383,45 @@ class SmolVLAPolicy(PreTrainedPolicy):
     def get_optim_params(self) -> dict:
         return self.parameters()

+    def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        for k in batch:
+            if k in self._queues:
+                batch[k] = torch.stack(list(self._queues[k]), dim=1)
+
+        images, img_masks = self.prepare_images(batch)
+        state = self.prepare_state(batch)
+        lang_tokens, lang_masks = self.prepare_language(batch)
+
+        actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
+
+        # Unpad actions
+        original_action_dim = self.config.action_feature.shape[0]
+        actions = actions[:, :, :original_action_dim]
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+
+        if self.config.adapt_to_pi_aloha:
+            actions = self._pi_aloha_encode_actions(actions)
+
+        return actions
+
+    def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
+        if self.config.adapt_to_pi_aloha:
+            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
+
+        batch = self.normalize_inputs(batch)
+
+        return batch
+
+    def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
+        self.eval()
+
+        batch = self._prepare_batch(batch)
+        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
+
+        actions = self._get_action_chunk(batch, noise)
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
         """Select a single action given environment observations.
@@ -392,38 +431,18 @@ class SmolVLAPolicy(PreTrainedPolicy):
         queue is empty.
         """
         self.eval()

-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        batch = self.normalize_inputs(batch)
-
+        batch = self._prepare_batch(batch)
         self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])

         # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
         # querying the policy.
         if len(self._queues[ACTION]) == 0:
-            for k in batch:
-                if k in self._queues:
-                    batch[k] = torch.stack(list(self._queues[k]), dim=1)
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens, lang_masks = self.prepare_language(batch)
-
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
-
-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
+            actions = self._get_action_chunk(batch, noise)
+
+            # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
             # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
             self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])

         return self._queues[ACTION].popleft()

     def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
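A note on the optional `noise` argument visible in the new SmolVLA signatures: it is forwarded straight to `model.sample_actions`, so passing the same noise tensor should make the sampled chunk reproducible across calls as long as the observation queues hold the same content. A hedged sketch, assuming a loaded SmolVLAPolicy `policy` and a matching observation dict `batch`; the noise shape is an assumption based on the `chunk_size` and `max_action_dim` config fields and may differ in your version:

    import torch

    noise = torch.randn(1, policy.config.chunk_size, policy.config.max_action_dim)

    chunk_a = policy.predict_action_chunk(batch, noise=noise)
    chunk_b = policy.predict_action_chunk(batch, noise=noise)
    # chunk_a and chunk_b should match, which is handy for tests and debugging.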
@@ -35,7 +35,7 @@ import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 from torch import Tensor

-from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
+from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
@@ -110,52 +110,58 @@ class TDMPCPolicy(PreTrainedPolicy):
         # CEM for the next step.
         self._prev_mean: torch.Tensor | None = None

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        """Predict a chunk of actions given environment observations."""
+        batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
+
+        # Remove the time dimensions as it is not handled yet.
+        for key in batch:
+            assert batch[key].shape[1] == 1
+            batch[key] = batch[key][:, 0]
+
+        # NOTE: Order of observations matters here.
+        encode_keys = []
+        if self.config.image_features:
+            encode_keys.append(OBS_IMAGE)
+        if self.config.env_state_feature:
+            encode_keys.append(OBS_ENV_STATE)
+        encode_keys.append(OBS_STATE)
+        z = self.model.encode({k: batch[k] for k in encode_keys})
+        if self.config.use_mpc:  # noqa: SIM108
+            actions = self.plan(z)  # (horizon, batch, action_dim)
+        else:
+            # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
+            # sequence dimension like in the MPC branch.
+            actions = self.model.pi(z).unsqueeze(0)
+
+        actions = torch.clamp(actions, -1, +1)
+
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     @torch.no_grad()
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations."""
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.image"] = batch[next(iter(self.config.image_features))]
+            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]

         self._queues = populate_queues(self._queues, batch)

         # When the action queue is depleted, populate it again by querying the policy.
-        if len(self._queues["action"]) == 0:
-            batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues}
-
-            # Remove the time dimensions as it is not handled yet.
-            for key in batch:
-                assert batch[key].shape[1] == 1
-                batch[key] = batch[key][:, 0]
-
-            # NOTE: Order of observations matters here.
-            encode_keys = []
-            if self.config.image_features:
-                encode_keys.append("observation.image")
-            if self.config.env_state_feature:
-                encode_keys.append("observation.environment_state")
-            encode_keys.append("observation.state")
-            z = self.model.encode({k: batch[k] for k in encode_keys})
-            if self.config.use_mpc:  # noqa: SIM108
-                actions = self.plan(z)  # (horizon, batch, action_dim)
-            else:
-                # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
-                # sequence dimension like in the MPC branch.
-                actions = self.model.pi(z).unsqueeze(0)
-
-            actions = torch.clamp(actions, -1, +1)
-
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)

             if self.config.n_action_repeats > 1:
                 for _ in range(self.config.n_action_repeats):
-                    self._queues["action"].append(actions[0])
+                    self._queues[ACTION].append(actions[0])
             else:
                 # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
-                self._queues["action"].extend(actions[: self.config.n_action_steps])
+                self._queues[ACTION].extend(actions[: self.config.n_action_steps])

-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action

     @torch.no_grad()
@@ -312,7 +318,7 @@ class TDMPCPolicy(PreTrainedPolicy):
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.image"] = batch[next(iter(self.config.image_features))]
+            batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]
         batch = self.normalize_targets(batch)

         info = {}
@@ -322,15 +328,15 @@ class TDMPCPolicy(PreTrainedPolicy):
             if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
                 batch[key] = batch[key].transpose(1, 0)

-        action = batch["action"]  # (t, b, action_dim)
-        reward = batch["next.reward"]  # (t, b)
+        action = batch[ACTION]  # (t, b, action_dim)
+        reward = batch[REWARD]  # (t, b)
         observations = {k: v for k, v in batch.items() if k.startswith("observation.")}

         # Apply random image augmentations.
         if self.config.image_features and self.config.max_random_shift_ratio > 0:
-            observations["observation.image"] = flatten_forward_unflatten(
+            observations[OBS_IMAGE] = flatten_forward_unflatten(
                 partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
-                observations["observation.image"],
+                observations[OBS_IMAGE],
             )

         # Get the current observation for predicting trajectories, and all future observations for use in
@@ -340,7 +346,7 @@ class TDMPCPolicy(PreTrainedPolicy):
             current_observation[k] = observations[k][0]
             next_observations[k] = observations[k][1:]
         horizon, batch_size = next_observations[
-            "observation.image" if self.config.image_features else "observation.environment_state"
+            OBS_IMAGE if self.config.image_features else OBS_ENV_STATE
         ].shape[:2]

         # Run latent rollout using the latent dynamics model and policy model.
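One convention worth flagging, taken from the comments in the hunk above: TDMPC's `predict_action_chunk` returns the planner output with shape (horizon, batch_size, action_dim), unlike the (batch_size, chunk_size, action_dim) layout used by ACT, Diffusion, SmolVLA and VQBeT, so callers comparing chunks across policy types may need a transpose. Sketch, assuming a loaded TDMPC `policy` whose observation queues are already populated and an observation dict `batch`:

    chunk = policy.predict_action_chunk(batch)  # (horizon, batch_size, action_dim) for TDMPC
    chunk_batch_first = chunk.transpose(0, 1)   # (batch_size, horizon, action_dim)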
@@ -27,6 +27,7 @@ import torch.nn.functional as F  # noqa: N812
 import torchvision
 from torch import Tensor, nn

+from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_STATE
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
@@ -118,11 +119,18 @@ class VQBeTPolicy(PreTrainedPolicy):
         queues are populated during rollout of the policy, they contain the n latest observations and actions
         """
         self._queues = {
-            "observation.images": deque(maxlen=self.config.n_obs_steps),
-            "observation.state": deque(maxlen=self.config.n_obs_steps),
-            "action": deque(maxlen=self.config.action_chunk_size),
+            OBS_IMAGES: deque(maxlen=self.config.n_obs_steps),
+            OBS_STATE: deque(maxlen=self.config.n_obs_steps),
+            ACTION: deque(maxlen=self.config.action_chunk_size),
         }

+    @torch.no_grad
+    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
+        batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
+        actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
+        actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
+        return actions
+
     @torch.no_grad
     def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         """Select a single action given environment observations.
@@ -144,23 +152,19 @@ class VQBeTPolicy(PreTrainedPolicy):
                 stacklevel=1,
             )

-        if len(self._queues["action"]) == 0:
-            batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
-            actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size]
-
-            # the dimension of returned action is (batch_size, action_chunk_size, action_dim)
-            actions = self.unnormalize_outputs({"action": actions})["action"]
+        if len(self._queues[ACTION]) == 0:
+            actions = self.predict_action_chunk(batch)
             # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue
-            self._queues["action"].extend(actions.transpose(0, 1))
+            self._queues[ACTION].extend(actions.transpose(0, 1))

-        action = self._queues["action"].popleft()
+        action = self._queues[ACTION].popleft()
         return action

     def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         """Run the batch through the model and compute the loss for training or validation."""
         batch = self.normalize_inputs(batch)
         batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-        batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
+        batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
         batch = self.normalize_targets(batch)
         # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181)
         if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -168,7 +172,7 @@ class VQBeTPolicy(PreTrainedPolicy):
             # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`.
             # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree).
             loss, n_different_codes, n_different_combinations, recon_l1_error = (
-                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
+                self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION])
             )
             return loss, {
                 "n_different_codes": n_different_codes,
@@ -404,7 +408,7 @@ class VQBeTModel(nn.Module):
             )
         # else, it calculate overall loss (bin prediction loss, and offset loss)
         else:
-            output = batch["action"][:, self.select_target_actions_indices]
+            output = batch[ACTION][:, self.select_target_actions_indices]
             loss = self.action_head.loss_fn(action_head_output, output, reduction="mean")
         return action_head_output, loss