From 1ad2da403d5526c5d8933d805c726acc98ec561b Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 29 Sep 2025 17:02:19 +0200 Subject: [PATCH] feat(policies): add noise parameter to action prediction methods (#2063) * feat(policies): add noise parameter to action prediction methods - Introduced `ActionSelectKwargs` TypedDict for better type hinting. - Updated `predict_action_chunk` and `select_action` methods in `PreTrainedPolicy` and its subclasses to accept a `noise` parameter. - Modified `generate_actions` and `conditional_sample` methods in `DiffusionModel` to utilize the new noise parameter for action generation. * refactor(policies): make ActionSelectKwargs TypedDict fields optional - Updated `ActionSelectKwargs` to inherit with `total=False`, allowing for optional fields. --- .../policies/diffusion/modeling_diffusion.py | 32 ++++++++++++------- src/lerobot/policies/pi0/modeling_pi0.py | 2 +- src/lerobot/policies/pretrained.py | 11 +++++-- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index ad808d7c7..3ab6719cb 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -90,16 +90,16 @@ class DiffusionPolicy(PreTrainedPolicy): self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Predict a chunk of actions given environment observations.""" # stack n latest observations from the queue batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.diffusion.generate_actions(batch) + actions = self.diffusion.generate_actions(batch, noise=noise) return actions @torch.no_grad() - def select_action(self, batch: dict[str, Tensor]) -> Tensor: + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. This method handles caching a history of observations and an action trajectory generated by the @@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy): self._queues = populate_queues(self._queues, batch) if len(self._queues[ACTION]) == 0: - actions = self.predict_action_chunk(batch) + actions = self.predict_action_chunk(batch, noise=noise) self._queues[ACTION].extend(actions.transpose(0, 1)) action = self._queues[ACTION].popleft() @@ -199,17 +199,25 @@ class DiffusionModel(nn.Module): # ========= inference ============ def conditional_sample( - self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + self, + batch_size: int, + global_cond: Tensor | None = None, + generator: torch.Generator | None = None, + noise: Tensor | None = None, ) -> Tensor: device = get_device_from_parameters(self) dtype = get_dtype_from_parameters(self) # Sample prior. - sample = torch.randn( - size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), - dtype=dtype, - device=device, - generator=generator, + sample = ( + noise + if noise is not None + else torch.randn( + size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), + dtype=dtype, + device=device, + generator=generator, + ) ) self.noise_scheduler.set_timesteps(self.num_inference_steps) @@ -264,7 +272,7 @@ class DiffusionModel(nn.Module): # Concatenate features then flatten to (B, global_cond_dim). return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) - def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: + def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """ This function expects `batch` to have: { @@ -282,7 +290,7 @@ class DiffusionModel(nn.Module): global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) # run sampling - actions = self.conditional_sample(batch_size, global_cond=global_cond) + actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise) # Extract `n_action_steps` steps worth of actions (from the current observation). start = n_obs_steps - 1 diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 4d3f4ffa1..8406f94fe 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -253,7 +253,7 @@ class PI0Policy(PreTrainedPolicy): return super().from_pretrained(*args, **kwargs) @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Predict a chunk of actions given environment observations.""" raise NotImplementedError("Currently not implemented for PI0") diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index b770c980b..3f5d89ec5 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -18,7 +18,7 @@ import os from importlib.resources import files from pathlib import Path from tempfile import TemporaryDirectory -from typing import TypeVar +from typing import TypedDict, TypeVar import packaging import safetensors @@ -27,6 +27,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from torch import Tensor, nn +from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig @@ -36,6 +37,10 @@ from lerobot.utils.hub import HubMixin T = TypeVar("T", bound="PreTrainedPolicy") +class ActionSelectKwargs(TypedDict, total=False): + noise: Tensor | None + + class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ Base class for policy models. @@ -181,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): raise NotImplementedError @abc.abstractmethod - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode. Child classes using action chunking should use this method within `select_action` to form the action chunk @@ -190,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): raise NotImplementedError @abc.abstractmethod - def select_action(self, batch: dict[str, Tensor]) -> Tensor: + def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). When the model uses a history of observations, or outputs a sequence of actions, this method deals