feat(policies): add noise parameter to action prediction methods (#2063)
* feat(policies): add noise parameter to action prediction methods
  - Introduced the `ActionSelectKwargs` TypedDict for better type hinting.
  - Updated the `predict_action_chunk` and `select_action` methods in `PreTrainedPolicy` and its subclasses to accept an optional `noise` parameter.
  - Modified the `generate_actions` and `conditional_sample` methods in `DiffusionModel` to use the supplied `noise` as the initial sample for action generation, falling back to `torch.randn` when `noise` is `None`.
* refactor(policies): make ActionSelectKwargs TypedDict fields optional
  - Declared `ActionSelectKwargs` with `total=False`, so that all of its fields are optional.
This commit is contained in:
@@ -90,16 +90,16 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
self._queues[OBS_ENV_STATE] = deque(maxlen=self.config.n_obs_steps)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
# stack n latest observations from the queue
|
# stack n latest observations from the queue
|
||||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
actions = self.diffusion.generate_actions(batch)
|
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Select a single action given environment observations.
|
"""Select a single action given environment observations.
|
||||||
|
|
||||||
This method handles caching a history of observations and an action trajectory generated by the
|
This method handles caching a history of observations and an action trajectory generated by the
|
||||||
@@ -131,7 +131,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
self._queues = populate_queues(self._queues, batch)
|
self._queues = populate_queues(self._queues, batch)
|
||||||
|
|
||||||
if len(self._queues[ACTION]) == 0:
|
if len(self._queues[ACTION]) == 0:
|
||||||
actions = self.predict_action_chunk(batch)
|
actions = self.predict_action_chunk(batch, noise=noise)
|
||||||
self._queues[ACTION].extend(actions.transpose(0, 1))
|
self._queues[ACTION].extend(actions.transpose(0, 1))
|
||||||
|
|
||||||
action = self._queues[ACTION].popleft()
|
action = self._queues[ACTION].popleft()
|
||||||
@@ -199,17 +199,25 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
# ========= inference ============
|
# ========= inference ============
|
||||||
def conditional_sample(
|
def conditional_sample(
|
||||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
self,
|
||||||
|
batch_size: int,
|
||||||
|
global_cond: Tensor | None = None,
|
||||||
|
generator: torch.Generator | None = None,
|
||||||
|
noise: Tensor | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
device = get_device_from_parameters(self)
|
device = get_device_from_parameters(self)
|
||||||
dtype = get_dtype_from_parameters(self)
|
dtype = get_dtype_from_parameters(self)
|
||||||
|
|
||||||
# Sample prior.
|
# Sample prior.
|
||||||
sample = torch.randn(
|
sample = (
|
||||||
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
noise
|
||||||
dtype=dtype,
|
if noise is not None
|
||||||
device=device,
|
else torch.randn(
|
||||||
generator=generator,
|
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
generator=generator,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
self.noise_scheduler.set_timesteps(self.num_inference_steps)
|
||||||
@@ -264,7 +272,7 @@ class DiffusionModel(nn.Module):
|
|||||||
# Concatenate features then flatten to (B, global_cond_dim).
|
# Concatenate features then flatten to (B, global_cond_dim).
|
||||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||||
|
|
||||||
def generate_actions(self, batch: dict[str, Tensor]) -> Tensor:
|
def generate_actions(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""
|
"""
|
||||||
This function expects `batch` to have:
|
This function expects `batch` to have:
|
||||||
{
|
{
|
||||||
@@ -282,7 +290,7 @@ class DiffusionModel(nn.Module):
|
|||||||
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
|
||||||
|
|
||||||
# run sampling
|
# run sampling
|
||||||
actions = self.conditional_sample(batch_size, global_cond=global_cond)
|
actions = self.conditional_sample(batch_size, global_cond=global_cond, noise=noise)
|
||||||
|
|
||||||
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
# Extract `n_action_steps` steps worth of actions (from the current observation).
|
||||||
start = n_obs_steps - 1
|
start = n_obs_steps - 1
|
||||||
|
|||||||
@@ -253,7 +253,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
return super().from_pretrained(*args, **kwargs)
|
return super().from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
raise NotImplementedError("Currently not implemented for PI0")
|
raise NotImplementedError("Currently not implemented for PI0")
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ import os
|
|||||||
from importlib.resources import files
|
from importlib.resources import files
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryDirectory
|
from tempfile import TemporaryDirectory
|
||||||
from typing import TypeVar
|
from typing import TypedDict, TypeVar
|
||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
import safetensors
|
import safetensors
|
||||||
@@ -27,6 +27,7 @@ from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
@@ -36,6 +37,10 @@ from lerobot.utils.hub import HubMixin
|
|||||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||||
|
|
||||||
|
|
||||||
|
class ActionSelectKwargs(TypedDict, total=False):
|
||||||
|
noise: Tensor | None
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||||
"""
|
"""
|
||||||
Base class for policy models.
|
Base class for policy models.
|
||||||
@@ -181,7 +186,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||||
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
|
"""Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode.
|
||||||
|
|
||||||
Child classes using action chunking should use this method within `select_action` to form the action chunk
|
Child classes using action chunking should use this method within `select_action` to form the action chunk
|
||||||
@@ -190,7 +195,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||||
"""Return one action to run in the environment (potentially in batch mode).
|
"""Return one action to run in the environment (potentially in batch mode).
|
||||||
|
|
||||||
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
When the model uses a history of observations, or outputs a sequence of actions, this method deals
|
||||||
|
|||||||
Reference in New Issue
Block a user