feat(policies): add noise parameter to action prediction methods (#2063)

* feat(policies): add noise parameter to action prediction methods

- Introduced `ActionSelectKwargs` TypedDict for better type hinting.
- Updated `predict_action_chunk` and `select_action` methods in `PreTrainedPolicy` and its subclasses to accept a `noise` parameter.
- Modified `generate_actions` and `conditional_sample` methods in `DiffusionModel` to utilize the new noise parameter for action generation.

* refactor(policies): make ActionSelectKwargs TypedDict fields optional

- Updated `ActionSelectKwargs` to be declared with `total=False`, making all of its fields optional.
This commit is contained in:
Adil Zouitine
2025-09-29 17:02:19 +02:00
committed by GitHub
parent 2d3a605b3c
commit 1ad2da403d
3 changed files with 29 additions and 16 deletions

View File

@@ -90,16 +90,16 @@ class DiffusionPolicy(PreTrainedPolicy):
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps) self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue # stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch) actions = self.diffusion.generate_actions(batch, noise=noise)
return actions return actions
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations. """Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the This method handles caching a history of observations and an action trajectory generated by the
@@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy):
self._queues = populate_queues(self._queues, batch) self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0: if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch) actions = self.predict_action_chunk(batch, noise=noise)
self._queues[ACTION].extend(actions.transpose(0, 1)) self._queues[ACTION].extend(actions.transpose(0, 1))
action = self._queues[ACTION].popleft() action = self._queues[ACTION].popleft()
@@ -199,17 +199,25 @@ class DiffusionModel(nn.Module):
# ========= inference ============ # ========= inference ============
def conditional_sample( def conditional_sample(
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None self,
batch_size: int,
global_cond: Tensor | None = None,
generator: torch.Generator | None = None,
noise: Tensor | None = None,
) -> Tensor: ) -> Tensor:
device = get_device_from_parameters(self) device = get_device_from_parameters(self)
dtype = get_dtype_from_parameters(self) dtype = get_dtype_from_parameters(self)
# Sample prior. # Sample prior.
sample = torch.randn( sample = (
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]), noise
dtype=dtype, if noise is not None
device=device, else torch.randn(
generator=generator, size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
)
) )
self.noise_scheduler.set_timesteps(self.num_inference_steps) self.noise_scheduler.set_timesteps(self.num_inference_steps)
@@ -264,7 +272,7 @@ class DiffusionModel(nn.Module):
# Concatenate features then flatten to (B, global_cond_dim). # Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor: def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
""" """
This function expects `batch` to have: This function expects `batch` to have:
{ {
@@ -282,7 +290,7 @@ class DiffusionModel(nn.Module):
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim) global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# run sampling # run sampling
actions = self.conditional_sample(batch_size, global_cond=global_cond) actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
# Extract `n_action_steps` steps worth of actions (from the current observation). # Extract `n_action_steps` steps worth of actions (from the current observation).
start = n_obs_steps - 1 start = n_obs_steps - 1

View File

@@ -253,7 +253,7 @@ class PI0Policy(PreTrainedPolicy):
return super().from_pretrained(*args, **kwargs) return super().from_pretrained(*args, **kwargs)
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0") raise NotImplementedError("Currently not implemented for PI0")

View File

@@ -18,7 +18,7 @@ import os
from importlib.resources import files from importlib.resources import files
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import TypeVar from typing import TypedDict, TypeVar
import packaging import packaging
import safetensors import safetensors
@@ -27,6 +27,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
from torch import Tensor, nn from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
@@ -36,6 +37,10 @@ from lerobot.utils.hub import HubMixin
T = TypeVar("T", bound="PreTrainedPolicy") T = TypeVar("T", bound="PreTrainedPolicy")
class ActionSelectKwargs(TypedDict, total=False):
noise: Tensor | None
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
""" """
Base class for policy models. Base class for policy models.
@@ -181,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode. """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
Child classes using action chunking should use this method within `select_action` to form the action chunk Child classes using action chunking should use this method within `select_action` to form the action chunk
@@ -190,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode). """Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals When the model uses a history of observations, or outputs a sequence of actions, this method deals