Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -0,0 +1,4 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig

View File

@@ -15,9 +15,14 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamWConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("act")
@dataclass
class ACTConfig:
class ACTConfig(PreTrainedConfig):
"""Configuration class for the Action Chunking Transformers policy.
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
@@ -90,28 +95,11 @@ class ACTConfig:
chunk_size: int = 100
n_action_steps: int = 100
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.images.top": "mean_std",
"observation.state": "mean_std",
}
)
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
@@ -144,7 +132,14 @@ class ACTConfig:
dropout: float = 0.1
kl_weight: float = 10.0
# Training preset
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
@@ -164,8 +159,28 @@ class ACTConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if (
not any(k.startswith("observation.image") for k in self.input_shapes)
and "observation.environment_state" not in self.input_shapes
):
def get_optimizer_preset(self) -> AdamWConfig:
    """Build an `AdamWConfig` from this config's optimizer fields.

    Returns:
        An `AdamWConfig` whose `lr` is `optimizer_lr` and whose `weight_decay`
        is `optimizer_weight_decay`.
    """
    preset_kwargs = {
        "lr": self.optimizer_lr,
        "weight_decay": self.optimizer_weight_decay,
    }
    return AdamWConfig(**preset_kwargs)
def get_scheduler_preset(self) -> None:
    """Return the scheduler preset for this config, which is always `None`."""
    preset = None
    return preset
def validate_features(self) -> None:
    """Check that at least one image or the environment state is configured.

    Raises:
        ValueError: If both `image_features` and `env_state_feature` are falsy.
    """
    has_observation_input = self.image_features or self.env_state_feature
    if has_observation_input:
        return
    raise ValueError("You must provide at least one image or the environment state among the inputs.")
@property
def observation_delta_indices(self) -> None:
    """Observation delta indices for this config; always `None` here."""
    indices = None
    return indices
@property
def action_delta_indices(self) -> list:
    """Action delta indices: the integers `0 .. chunk_size - 1` as a list."""
    return [*range(self.chunk_size)]
@property
def reward_delta_indices(self) -> None:
    """Reward delta indices for this config; always `None` here."""
    indices = None
    return indices

View File

@@ -29,32 +29,27 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.misc import FrozenBatchNorm2d
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
class ACTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "act"],
):
class ACTPolicy(PreTrainedPolicy):
"""
Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost
Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act)
"""
config_class = ACTConfig
name = "act"
def __init__(
self,
config: ACTConfig | None = None,
config: ACTConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -64,30 +59,46 @@ class ACTPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = ACTConfig()
self.config: ACTConfig = config
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = ACT(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
def get_optim_params(self) -> list[dict]:
    """Split the trainable parameters into two optimizer parameter groups.

    Parameters whose name starts with "model.backbone" go in a second group
    that carries its own learning rate (`config.optimizer_lr_backbone`);
    every other parameter goes in the first group, which sets no "lr" key.
    Parameters with `requires_grad` set to False are left out of both groups.

    Returns:
        A two-element list of dicts: `{"params": [...]}` for the non-backbone
        parameters, then `{"params": [...], "lr": ...}` for the backbone ones.
    """
    # TODO(aliberts, rcadene): As of now, lr_backbone == lr
    # Should we remove this and just `return self.parameters()`?
    backbone_params = []
    other_params = []
    # Single pass over the parameters; order within each group follows
    # `named_parameters()`.
    for name, param in self.named_parameters():
        if not param.requires_grad:
            continue
        if name.startswith("model.backbone"):
            backbone_params.append(param)
        else:
            other_params.append(param)
    return [
        {"params": other_params},
        {"params": backbone_params, "lr": self.config.optimizer_lr_backbone},
    ]
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_coeff is not None:
@@ -106,9 +117,11 @@ class ACTPolicy(
self.eval()
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -134,9 +147,11 @@ class ACTPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -288,31 +303,30 @@ class ACT(nn.Module):
"""
def __init__(self, config: ACTConfig):
super().__init__()
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_robot_state = "observation.state" in config.input_shapes
self.use_images = any(k.startswith("observation.image") for k in config.input_shapes)
self.use_env_state = "observation.environment_state" in config.input_shapes
super().__init__()
self.config = config
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config, is_vae_encoder=True)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
if self.use_robot_state:
if self.config.robot_state_feature:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
config.output_shapes["action"][0], config.dim_model
self.config.action_feature.shape[0],
config.dim_model,
)
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_robot_state:
if self.config.robot_state_feature:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
@@ -320,7 +334,7 @@ class ACT(nn.Module):
)
# Backbone for image feature extraction.
if self.use_images:
if self.config.image_features:
backbone_model = getattr(torchvision.models, config.vision_backbone)(
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
weights=config.pretrained_backbone_weights,
@@ -337,27 +351,27 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, (robot_state), (env_state), (image_feature_map_pixels)].
if self.use_robot_state:
if self.config.robot_state_feature:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
self.config.robot_state_feature.shape[0], config.dim_model
)
if self.use_env_state:
if self.config.env_state_feature:
self.encoder_env_state_input_proj = nn.Linear(
config.input_shapes["observation.environment_state"][0], config.dim_model
self.config.env_state_feature.shape[0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
n_1d_tokens = 1 # for the latent
if self.use_robot_state:
if self.config.robot_state_feature:
n_1d_tokens += 1
if self.use_env_state:
if self.config.env_state_feature:
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.use_images:
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@@ -365,7 +379,7 @@ class ACT(nn.Module):
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0])
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self._reset_parameters()
@@ -380,13 +394,13 @@ class ACT(nn.Module):
`batch` should have the following structure:
{
"observation.state" (optional): (B, state_dim) batch of robot states.
[robot_state_feature] (optional): (B, state_dim) batch of robot states.
"observation.images": (B, n_cameras, C, H, W) batch of images.
[image_features]: (B, n_cameras, C, H, W) batch of images.
AND/OR
"observation.environment_state": (B, env_dim) batch of environment states.
[env_state_feature]: (B, env_dim) batch of environment states.
"action" (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
[action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions.
}
Returns:
@@ -411,12 +425,12 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.use_robot_state:
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
if self.use_robot_state:
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
@@ -430,7 +444,7 @@ class ACT(nn.Module):
# sequence depending whether we use the input states or not (cls and robot state)
# False means not a padding token.
cls_joint_is_pad = torch.full(
(batch_size, 2 if self.use_robot_state else 1),
(batch_size, 2 if self.config.robot_state_feature else 1),
False,
device=batch["observation.state"].device,
)
@@ -463,16 +477,16 @@ class ACT(nn.Module):
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)]
encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1))
# Robot state token.
if self.use_robot_state:
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
# Environment state token.
if self.use_env_state:
if self.config.env_state_feature:
encoder_in_tokens.append(
self.encoder_env_state_input_proj(batch["observation.environment_state"])
)
# Camera observation features and positional embeddings.
if self.use_images:
if self.config.image_features:
all_cam_features = []
all_cam_pos_embeds = []

View File

@@ -16,9 +16,15 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import DiffuserSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("diffusion")
@dataclass
class DiffusionConfig:
class DiffusionConfig(PreTrainedConfig):
"""Configuration class for DiffusionPolicy.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -102,26 +108,17 @@ class DiffusionConfig:
horizon: int = 16
n_action_steps: int = 8
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# The original implementation doesn't sample frames for the last 7 steps,
# which avoids excessive padding and leads to improved training results.
drop_n_last_frames: int = 7 # horizon - n_action_steps - n_obs_steps + 1
# Architecture / modeling.
# Vision backbone.
@@ -154,39 +151,23 @@ class DiffusionConfig:
# Loss computation
do_mask_loss_for_padding: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
scheduler_name: str = "cosine"
scheduler_warmup_steps: int = 500
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
if not self.vision_backbone.startswith("resnet"):
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
if len(image_keys) > 0:
if self.crop_shape is not None:
for image_key in image_keys:
if (
self.crop_shape[0] > self.input_shapes[image_key][1]
or self.crop_shape[1] > self.input_shapes[image_key][2]
):
raise ValueError(
f"`crop_shape` should fit within `input_shapes[{image_key}]`. Got {self.crop_shape} "
f"for `crop_shape` and {self.input_shapes[image_key]} for "
"`input_shapes[{image_key}]`."
)
# Check that all input images have the same shape.
first_image_key = next(iter(image_keys))
for image_key in image_keys:
if self.input_shapes[image_key] != self.input_shapes[first_image_key]:
raise ValueError(
f"`input_shapes[{image_key}]` does not match `input_shapes[{first_image_key}]`, but we "
"expect all image shapes to match."
)
supported_prediction_types = ["epsilon", "sample"]
if self.prediction_type not in supported_prediction_types:
raise ValueError(
@@ -207,3 +188,50 @@ class DiffusionConfig:
"The horizon should be an integer multiple of the downsampling factor (which is determined "
f"by `len(down_dims)`). Got {self.horizon=} and {self.down_dims=}"
)
def get_optimizer_preset(self) -> AdamConfig:
    """Build an `AdamConfig` from this config's optimizer fields.

    Returns:
        An `AdamConfig` populated with `optimizer_lr`, `optimizer_betas`,
        `optimizer_eps` and `optimizer_weight_decay`.
    """
    preset_kwargs = {
        "lr": self.optimizer_lr,
        "betas": self.optimizer_betas,
        "eps": self.optimizer_eps,
        "weight_decay": self.optimizer_weight_decay,
    }
    return AdamConfig(**preset_kwargs)
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
    """Build a `DiffuserSchedulerConfig` from this config's scheduler fields.

    Returns:
        A `DiffuserSchedulerConfig` whose `name` is `scheduler_name` and whose
        `num_warmup_steps` is `scheduler_warmup_steps`.
    """
    preset_kwargs = {
        "name": self.scheduler_name,
        "num_warmup_steps": self.scheduler_warmup_steps,
    }
    return DiffuserSchedulerConfig(**preset_kwargs)
def validate_features(self) -> None:
    """Validate the configured input features.

    Checks, in order:
      1. At least one image feature or the environment state is configured.
      2. If `crop_shape` is set, it fits within every image feature
         (compared against `shape[1]` and `shape[2]` of each image feature).
      3. All image features share the same shape.

    Checks 2 and 3 are skipped when there are no image features: a config
    with only an environment state is accepted by check 1, and there is then
    no "first image" to compare the others against.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if len(self.image_features) == 0 and self.env_state_feature is None:
        raise ValueError("You must provide at least one image or the environment state among the inputs.")

    # Without this guard, `next(iter(...))` below raises StopIteration for an
    # env-state-only config.
    if not self.image_features:
        return

    if self.crop_shape is not None:
        for key, image_ft in self.image_features.items():
            if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                raise ValueError(
                    f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                    f"for `crop_shape` and {image_ft.shape} for "
                    f"`{key}`."
                )

    # Check that all input images have the same shape.
    first_image_key, first_image_ft = next(iter(self.image_features.items()))
    for key, image_ft in self.image_features.items():
        if image_ft.shape != first_image_ft.shape:
            raise ValueError(
                f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
            )
@property
def observation_delta_indices(self) -> list:
    """Observation delta indices: the integers `1 - n_obs_steps .. 0` as a list."""
    first_index = 1 - self.n_obs_steps
    return [*range(first_index, 1)]
@property
def action_delta_indices(self) -> list:
    """Action delta indices: `horizon` consecutive integers starting at `1 - n_obs_steps`."""
    first_index = 1 - self.n_obs_steps
    stop_index = first_index + self.horizon
    return [*range(first_index, stop_index)]
@property
def reward_delta_indices(self) -> None:
    """Reward delta indices for this config; always `None` here."""
    indices = None
    return indices

View File

@@ -31,35 +31,32 @@ import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import (
get_device_from_parameters,
get_dtype_from_parameters,
get_output_shape,
populate_queues,
)
class DiffusionPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "diffusion-policy"],
):
class DiffusionPolicy(PreTrainedPolicy):
"""
Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
(paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy).
"""
config_class = DiffusionConfig
name = "diffusion"
def __init__(
self,
config: DiffusionConfig | None = None,
config: DiffusionConfig,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -69,18 +66,16 @@ class DiffusionPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = DiffusionConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
@@ -88,20 +83,20 @@ class DiffusionPolicy(
self.diffusion = DiffusionModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.use_env_state = "observation.environment_state" in config.input_shapes
self.reset()
def get_optim_params(self) -> dict:
    """Return the result of `self.diffusion.parameters()`.

    NOTE(review): the `dict` return annotation does not obviously match what
    `parameters()` returns — confirm against the base-class contract.
    """
    diffusion_params = self.diffusion.parameters()
    return diffusion_params
def reset(self):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
"observation.state": deque(maxlen=self.config.n_obs_steps),
"action": deque(maxlen=self.config.n_action_steps),
}
if len(self.expected_image_keys) > 0:
if self.config.image_features:
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
if self.use_env_state:
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
@@ -127,9 +122,11 @@ class DiffusionPolicy(
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -149,9 +146,11 @@ class DiffusionPolicy(
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if len(self.expected_image_keys) > 0:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack(
[batch[key] for key in self.config.image_features], dim=-4
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
@@ -176,12 +175,9 @@ class DiffusionModel(nn.Module):
self.config = config
# Build observation encoders (depending on which observations are provided).
global_cond_dim = config.input_shapes["observation.state"][0]
num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self._use_images = False
self._use_env_state = False
if num_images > 0:
self._use_images = True
global_cond_dim = self.config.robot_state_feature.shape[0]
if self.config.image_features:
num_images = len(self.config.image_features)
if self.config.use_separate_rgb_encoder_per_camera:
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
self.rgb_encoder = nn.ModuleList(encoders)
@@ -189,9 +185,8 @@ class DiffusionModel(nn.Module):
else:
self.rgb_encoder = DiffusionRgbEncoder(config)
global_cond_dim += self.rgb_encoder.feature_dim * num_images
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
global_cond_dim += config.input_shapes["observation.environment_state"][0]
if self.config.env_state_feature:
global_cond_dim += self.config.env_state_feature.shape[0]
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
@@ -220,7 +215,7 @@ class DiffusionModel(nn.Module):
# Sample prior.
sample = torch.randn(
size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]),
size=(batch_size, self.config.horizon, self.config.action_feature.shape[0]),
dtype=dtype,
device=device,
generator=generator,
@@ -242,10 +237,10 @@ class DiffusionModel(nn.Module):
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
"""Encode image features and concatenate them all together along with the state vector."""
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
global_cond_feats = [batch["observation.state"]]
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
global_cond_feats = [batch[OBS_ROBOT]]
# Extract image features.
if self._use_images:
if self.config.image_features:
if self.config.use_separate_rgb_encoder_per_camera:
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
@@ -272,8 +267,8 @@ class DiffusionModel(nn.Module):
)
global_cond_feats.append(img_features)
if self._use_env_state:
global_cond_feats.append(batch["observation.environment_state"])
if self.config.env_state_feature:
global_cond_feats.append(batch[OBS_ENV])
# Concatenate features then flatten to (B, global_cond_dim).
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
@@ -443,7 +438,7 @@ class SpatialSoftmax(nn.Module):
class DiffusionRgbEncoder(nn.Module):
"""Encoder an RGB image into a 1D feature vector.
"""Encodes an RGB image into a 1D feature vector.
Includes the ability to normalize and crop the image first.
"""
@@ -482,19 +477,16 @@ class DiffusionRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# height and width from `config.image_features`.
# Note: we have a check in the config class to make sure all images have the same shape.
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -611,7 +603,7 @@ class DiffusionConditionalUnet1d(nn.Module):
# In channels / out channels for each downsampling block in the Unet's encoder. For the decoder, we
# just reverse these.
in_out = [(config.output_shapes["action"][0], config.down_dims[0])] + list(
in_out = [(config.action_feature.shape[0], config.down_dims[0])] + list(
zip(config.down_dims[:-1], config.down_dims[1:], strict=True)
)
@@ -666,7 +658,7 @@ class DiffusionConditionalUnet1d(nn.Module):
self.final_conv = nn.Sequential(
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1),
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
)
def forward(self, x: Tensor, timestep: Tensor | int, global_cond=None) -> Tensor:

View File

@@ -13,99 +13,132 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import logging
from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import get_safe_torch_device
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.envs.configs import EnvConfig
from lerobot.common.envs.utils import env_to_policy_features
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType
def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg):
    """Instantiate `policy_cfg_class` from the `policy` node of a Hydra config.

    Only keys that match a parameter of `policy_cfg_class`'s signature are
    forwarded; other keys in `hydra_cfg.policy` are ignored. A warning is
    logged when the Hydra config does not provide every signature parameter.

    Args:
        policy_cfg_class: The config class to instantiate.
        hydra_cfg: Config object exposing a `policy` node.

    Returns:
        An instance of `policy_cfg_class`.
    """
    accepted_names = set(inspect.signature(policy_cfg_class).parameters)
    provided_names = set(hydra_cfg.policy)
    if not provided_names.issuperset(accepted_names):
        logging.warning(
            f"Hydra config is missing arguments: {set(accepted_names).difference(hydra_cfg.policy)}"
        )

    # OmegaConf.to_container returns lists where sequences are found, but our dataclasses use tuples to
    # avoid issues with mutable defaults, so every top-level list value is converted to a tuple.
    resolved_policy = OmegaConf.to_container(hydra_cfg.policy, resolve=True)
    init_kwargs = {}
    for key, value in resolved_policy.items():
        if key not in accepted_names:
            continue
        init_kwargs[key] = tuple(value) if isinstance(value, list) else value

    return policy_cfg_class(**init_kwargs)
def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
def get_policy_class(name: str) -> PreTrainedPolicy:
"""Get the policy's class and config class given a name (matching the policy class' `name` attribute)."""
if name == "tdmpc":
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
return TDMPCPolicy, TDMPCConfig
return TDMPCPolicy
elif name == "diffusion":
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
return DiffusionPolicy, DiffusionConfig
return DiffusionPolicy
elif name == "act":
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.act.modeling_act import ACTPolicy
return ACTPolicy, ACTConfig
return ACTPolicy
elif name == "vqbet":
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy, VQBeTConfig
return VQBeTPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
if policy_type == "tdmpc":
return TDMPCConfig(**kwargs)
elif policy_type == "diffusion":
return DiffusionConfig(**kwargs)
elif policy_type == "act":
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
def make_policy(
hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
) -> Policy:
cfg: PreTrainedConfig,
device: str | torch.device,
ds_meta: LeRobotDatasetMetadata | None = None,
env_cfg: EnvConfig | None = None,
) -> PreTrainedPolicy:
"""Make an instance of a policy class.
This function exists because (for now) we need to parse features from either a dataset or an environment
in order to properly dimension and instantiate a policy for that dataset or environment.
Args:
hydra_cfg: A parsed Hydra configuration (see scripts). If `pretrained_policy_name_or_path` is
provided, only `hydra_cfg.policy.name` is used while everything else is ignored.
pretrained_policy_name_or_path: Either the repo ID of a model hosted on the Hub or a path to a
directory containing weights saved using `Policy.save_pretrained`. Note that providing this
argument overrides everything in `hydra_cfg.policy` apart from `hydra_cfg.policy.name`.
dataset_stats: Dataset statistics to use for (un)normalization of inputs/outputs in the policy. Must
be provided when initializing a new policy, and must not be provided when loading a pretrained
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will
be loaded with the weights from that path.
device (str): the device to load the policy onto.
ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and
statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None.
env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be
provided if ds_meta is not. Defaults to None.
Raises:
ValueError: Exactly one of ds_meta and env_cfg must be provided.
NotImplementedError: if the policy.type is 'vqbet' and the device 'mps' (due to an incompatibility)
Returns:
PreTrainedPolicy: The instantiated policy, moved to `device`.
"""
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
raise ValueError(
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
if bool(ds_meta) == bool(env_cfg):
raise ValueError("Either one of a dataset metadata or a sim env must be provided.")
# NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error.
# TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies?
# NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If
# you want this op to be added in priority during the prototype phase of this feature, please comment on
# https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment
# variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be
# slower than running natively on MPS.
if cfg.type == "vqbet" and str(device) == "mps":
raise NotImplementedError(
"Current implementation of VQBeT does not support `mps` backend. "
"Please use `cpu` or `cuda` backend."
)
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
policy_cls = get_policy_class(cfg.type)
policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
if pretrained_policy_name_or_path is None:
# Make a fresh policy.
policy = policy_cls(policy_cfg, dataset_stats)
kwargs = {}
if ds_meta is not None:
features = dataset_to_policy_features(ds_meta.features)
kwargs["dataset_stats"] = ds_meta.stats
else:
if not cfg.pretrained_path:
logging.warning(
"You are instantiating a policy from scratch and its features are parsed from an environment "
"rather than a dataset. Normalization modules inside the policy will have infinite values "
"by default without stats from a dataset."
)
features = env_to_policy_features(env_cfg)
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg
if cfg.pretrained_path:
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
# hyperparameters that we want to vary).
# TODO(alexander-soare): This hack makes use of huggingface_hub's tooling to load the policy with,
# pretrained weights which are then loaded into a fresh policy with the desired config. This PR in
# huggingface_hub should make it possible to avoid the hack:
# https://github.com/huggingface/huggingface_hub/pull/2274.
policy = policy_cls(policy_cfg)
policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict())
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
policy = policy_cls.from_pretrained(**kwargs)
else:
# Make a fresh policy.
policy = policy_cls(**kwargs)
policy.to(get_safe_torch_device(hydra_cfg.device))
policy.to(device)
assert isinstance(policy, nn.Module)
return policy

View File

@@ -16,10 +16,12 @@
import torch
from torch import Tensor, nn
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
def create_stats_buffers(
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
@@ -34,12 +36,16 @@ def create_stats_buffers(
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
for key, ft in features.items():
norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
shape = tuple(shapes[key])
assert isinstance(norm_mode, NormalizationMode)
if "image" in key:
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
# sanity checks
assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
@@ -52,7 +58,7 @@ def create_stats_buffers(
# we assert they are not infinity anymore.
buffer = {}
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
@@ -61,7 +67,7 @@ def create_stats_buffers(
"std": nn.Parameter(std, requires_grad=False),
}
)
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict(
@@ -71,15 +77,15 @@ def create_stats_buffers(
}
)
if stats is not None:
if stats:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone()
buffer["std"].data = stats[key]["std"].clone()
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone()
buffer["max"].data = stats[key]["max"].clone()
@@ -99,8 +105,8 @@ class Normalize(nn.Module):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -122,10 +128,10 @@ class Normalize(nn.Module):
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.features = features
self.norm_map = norm_map
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -133,16 +139,20 @@ class Normalize(nn.Module):
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
for key, ft in self.features.items():
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -152,7 +162,7 @@ class Normalize(nn.Module):
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
raise ValueError(norm_mode)
return batch
@@ -164,8 +174,8 @@ class Unnormalize(nn.Module):
def __init__(
self,
shapes: dict[str, list[int]],
modes: dict[str, str],
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
@@ -187,11 +197,11 @@ class Unnormalize(nn.Module):
dataset is not needed to get the stats, since they are already in the policy state_dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.features = features
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(shapes, modes, stats)
stats_buffers = create_stats_buffers(features, norm_map, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -199,16 +209,20 @@ class Unnormalize(nn.Module):
@torch.no_grad
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) # shallow copy avoids mutating the input batch
for key, mode in self.modes.items():
for key, ft in self.features.items():
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
if norm_mode is NormalizationMode.IDENTITY:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
if norm_mode is NormalizationMode.MEAN_STD:
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif mode == "min_max":
elif norm_mode is NormalizationMode.MIN_MAX:
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
@@ -216,5 +230,5 @@ class Unnormalize(nn.Module):
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
raise ValueError(norm_mode)
return batch

View File

@@ -1,75 +0,0 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A protocol that all policies should follow.
This provides a mechanism for type-hinting and isinstance checks without requiring the policy classes
to subclass a base class.
The protocol structure, method signatures, and docstrings should be used by developers as a reference for
how to implement new policies.
"""
from typing import Protocol, runtime_checkable
from torch import Tensor
@runtime_checkable
class Policy(Protocol):
"""The required interface for implementing a policy.
We also expect all policies to subclass torch.nn.Module and PyTorchModelHubMixin.
"""
name: str
def __init__(self, cfg, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
dataset_stats: Dataset statistics to be used for normalization.
"""
def reset(self):
"""To be called whenever the environment is reset.
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
@runtime_checkable
class PolicyWithUpdate(Policy, Protocol):
    """A `Policy` that additionally exposes an `update` hook."""

    def update(self):
        """Hook to be called after each training optimization step.

        Performs any additional updates the model parameters may need (for example, an EMA step for a
        target model, or incrementing an internal buffer).
        """
        ...

View File

@@ -0,0 +1,182 @@
import abc
import logging
import os
from pathlib import Path
from typing import Type, TypeVar
import packaging
import safetensors
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor
from safetensors.torch import save_model as save_model_as_safetensor
from torch import Tensor, nn
from lerobot.common.utils.hub import HubMixin
from lerobot.configs.policies import PreTrainedConfig
# Type variable bound to `PreTrainedPolicy`; used to annotate classmethods so that they are typed as
# returning the subclass they were called on.
T = TypeVar("T", bound="PreTrainedPolicy")
# Default model-card text with `{{ card_data }}` and `{{ docs_url }}` placeholders.
# NOTE(review): the placeholder syntax looks like a Jinja-style template — confirm how it is rendered.
DEFAULT_POLICY_CARD = """
---
# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
# Doc / guide: https://huggingface.co/docs/hub/model-cards
{{ card_data }}
---
This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot):
- Docs: {{ docs_url | default("[More Information Needed]", true) }}
"""
class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
    """
    Base class for policy models.

    Concrete subclasses must define the class attributes `config_class` and `name` (enforced in
    `__init_subclass__`) and implement the abstract methods `get_optim_params`, `reset`, `forward` and
    `select_action`.
    """

    # Configuration class associated with the policy; must be set by every subclass.
    config_class: Type[PreTrainedConfig]
    # Identifier of the policy; must be set by every subclass.
    name: str

    def __init__(self, config: PreTrainedConfig, *inputs, **kwargs):
        """
        Args:
            config: Configuration instance for this policy. Stored as `self.config`.
            *inputs: Accepted for signature compatibility; not used here.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            ValueError: If `config` is not a `PreTrainedConfig` instance.
        """
        super().__init__()
        if not isinstance(config, PreTrainedConfig):
            raise ValueError(
                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
                "`PreTrainedConfig`. To create a model from a pretrained model use "
                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.config = config

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define truthy `config_class` and `name` attributes."""
        super().__init_subclass__(**kwargs)
        if not getattr(cls, "config_class", None):
            raise TypeError(f"Class {cls.__name__} must define 'config_class'")
        if not getattr(cls, "name", None):
            raise TypeError(f"Class {cls.__name__} must define 'name'")

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save the config and the model weights (single safetensors file) into `save_directory`."""
        self.config._save_pretrained(save_directory)
        # NOTE(review): presumably unwraps a DataParallel/DDP-style wrapper exposing `.module` — confirm.
        model_to_save = self.module if hasattr(self, "module") else self
        save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))

    @classmethod
    def from_pretrained(
        cls: Type[T],
        pretrained_name_or_path: str | Path,
        *,
        config: PreTrainedConfig | None = None,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        map_location: str = "cpu",
        strict: bool = False,
        **kwargs,
    ) -> T:
        """
        The policy is set in evaluation mode by default using `policy.eval()` (dropout modules are
        deactivated). To train it, you should first set it back in training mode with `policy.train()`.

        Args:
            pretrained_name_or_path: A local directory containing the weights file, or otherwise an
                identifier passed to `hf_hub_download` as `repo_id`.
            config: Config used to instantiate the policy. When None, it is loaded with
                `PreTrainedConfig.from_pretrained` from the same location.
            force_download, resume_download, proxies, token, cache_dir, local_files_only, revision:
                Forwarded to the config loading (when `config` is None) and to `hf_hub_download`.
            map_location: Device the weights are loaded onto and the policy is moved to.
            strict: Forwarded to the safetensors model loader.
            **kwargs: Forwarded both to the config loading (when `config` is None) and to the policy
                constructor.

        Returns:
            The policy instance with loaded weights, on `map_location`, in evaluation mode.

        Raises:
            FileNotFoundError: If downloading the weights file from the Hub fails with an
                `HfHubHTTPError`.
        """
        if config is None:
            config = PreTrainedConfig.from_pretrained(
                pretrained_name_or_path=pretrained_name_or_path,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                token=token,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
                revision=revision,
                **kwargs,
            )
        model_id = str(pretrained_name_or_path)
        instance = cls(config, **kwargs)
        if os.path.isdir(model_id):
            print("Loading weights from local directory")
            model_file = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
        else:
            # Keep the `try` body limited to the download: only that call is expected to raise
            # `HfHubHTTPError`.
            try:
                model_file = hf_hub_download(
                    repo_id=model_id,
                    filename=SAFETENSORS_SINGLE_FILE,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    token=token,
                    local_files_only=local_files_only,
                )
            except HfHubHTTPError as e:
                raise FileNotFoundError(
                    f"{SAFETENSORS_SINGLE_FILE} not found on the HuggingFace Hub in {model_id}"
                ) from e
        policy = cls._load_as_safetensor(instance, model_file, map_location, strict)

        policy.to(map_location)
        policy.eval()
        return policy

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
        """Load the weights stored in `model_file` into `model` and return it on `map_location`."""
        # `import packaging` alone does not guarantee that the `packaging.version` submodule is loaded;
        # import it explicitly so the version check does not depend on other modules' imports.
        import packaging.version

        if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
            # Older safetensors: load first, then move to the target device.
            load_model_as_safetensor(model, model_file, strict=strict)
            if map_location != "cpu":
                logging.warning(
                    "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
                    " This means that the model is loaded on 'cpu' first and then copied to the device."
                    " This leads to a slower loading time."
                    " Please update safetensors to version 0.4.3 or above for improved performance."
                )
                model.to(map_location)
        else:
            load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
        return model

    # def generate_model_card(self, *args, **kwargs) -> ModelCard:
    #     card = ModelCard.from_template(
    #         card_data=self._hub_mixin_info.model_card_data,
    #         template_str=self._hub_mixin_info.model_card_template,
    #         repo_url=self._hub_mixin_info.repo_url,
    #         docs_url=self._hub_mixin_info.docs_url,
    #         **kwargs,
    #     )
    #     return card

    @abc.abstractmethod
    def get_optim_params(self) -> dict:
        """
        Returns the policy-specific parameters dict to be passed on to the optimizer.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def reset(self):
        """To be called whenever the environment is reset.

        Does things like clearing caches.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def forward(self, batch: dict[str, Tensor]) -> dict:
        """Run the batch through the model and compute the loss for training or validation.

        Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
        other items should be logging-friendly, native Python types.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Return one action to run in the environment (potentially in batch mode).

        When the model uses a history of observations, or outputs a sequence of actions, this method deals
        with caching.
        """
        raise NotImplementedError

View File

@@ -16,9 +16,14 @@
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("tdmpc")
@dataclass
class TDMPCConfig:
class TDMPCConfig(PreTrainedConfig):
"""Configuration class for TDMPCPolicy.
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
@@ -102,27 +107,19 @@ class TDMPCConfig:
"""
# Input / output structure.
n_obs_steps: int = 1
n_action_repeats: int = 2
horizon: int = 5
n_action_steps: int = 1
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 84, 84],
"observation.state": [4],
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"ENV": NormalizationMode.IDENTITY,
"ACTION": NormalizationMode.MIN_MAX,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
# Architecture / modeling.
# Neural networks.
@@ -159,32 +156,27 @@ class TDMPCConfig:
# Target model.
target_model_momentum: float = 0.995
# Training presets
optimizer_lr: float = 3e-4
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
# There should only be one image key.
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
if len(image_keys) > 1:
raise ValueError(
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
)
if len(image_keys) > 0:
image_key = next(iter(image_keys))
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
# TODO(alexander-soare): This limitation is solely because of code in the random shift
# augmentation. It should be able to be removed.
raise ValueError(
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
)
if self.n_gaussian_samples <= 0:
raise ValueError(
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
)
if self.output_normalization_modes != {"action": "min_max"}:
if self.normalization_mapping["ACTION"] is not NormalizationMode.MIN_MAX:
raise ValueError(
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
"information."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.n_action_steps > 1:
if self.n_action_repeats != 1:
raise ValueError(
@@ -194,3 +186,35 @@ class TDMPCConfig:
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
if self.n_action_steps > self.horizon:
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
def get_optimizer_preset(self) -> AdamConfig:
    """Return the preset Adam optimizer config, using `optimizer_lr` as the learning rate."""
    learning_rate = self.optimizer_lr
    return AdamConfig(lr=learning_rate)
def get_scheduler_preset(self) -> None:
    """Return the scheduler preset; this config defines none, so the result is always None."""
    preset = None
    return preset
def validate_features(self) -> None:
    """Check the image features of this config.

    Raises:
        ValueError: If there is more than one image feature, or if the single image feature is not
            square (its last two shape dimensions differ).
    """
    image_fts = self.image_features
    n_images = len(image_fts)

    # There should only be one image key.
    if n_images > 1:
        raise ValueError(
            f"{self.__class__.__name__} handles at most one image for now. Got image keys {self.image_features}."
        )

    # Nothing more to validate when there is no image input.
    if n_images == 0:
        return

    image_ft = next(iter(image_fts.values()))
    height, width = image_ft.shape[-2], image_ft.shape[-1]
    if height != width:
        # TODO(alexander-soare): This limitation is solely because of code in the random shift
        # augmentation. It should be able to be removed.
        raise ValueError(f"Only square images are handled now. Got image shape {image_ft.shape}.")
@property
def observation_delta_indices(self) -> list:
    """Indices 0 through `horizon` inclusive, i.e. `horizon + 1` entries."""
    n_steps = self.horizon + 1
    return [*range(n_steps)]
@property
def action_delta_indices(self) -> list:
    """Indices 0 through `horizon - 1`, i.e. `horizon` entries."""
    n_steps = self.horizon
    return [*range(n_steps)]
@property
def reward_delta_indices(self) -> list:
    """Indices 0 through `horizon - 1`, i.e. `horizon` entries.

    The return annotation is `list` (not `None`): the property always returns a list, like the
    sibling `action_delta_indices`.
    """
    return list(range(self.horizon))

View File

@@ -33,21 +33,16 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
class TDMPCPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "tdmpc"],
):
class TDMPCPolicy(PreTrainedPolicy):
"""Implementation of TD-MPC learning + inference.
Please note several warnings for this policy.
@@ -65,11 +60,10 @@ class TDMPCPolicy(
match our xarm environment.
"""
config_class = TDMPCConfig
name = "tdmpc"
def __init__(
self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
@@ -77,42 +71,28 @@ class TDMPCPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = TDMPCConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_features, config.normalization_mapping, dataset_stats
)
self.model = TDMPCTOLD(config)
self.model_target = deepcopy(self.model)
for param in self.model_target.parameters():
param.requires_grad = False
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
self._use_image = False
self._use_env_state = False
if len(image_keys) > 0:
assert len(image_keys) == 1
self._use_image = True
self.input_image_key = image_keys[0]
if "observation.environment_state" in config.input_shapes:
self._use_env_state = True
self.reset()
def get_optim_params(self) -> dict:
return self.parameters()
def reset(self):
"""
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
@@ -122,9 +102,9 @@ class TDMPCPolicy(
"observation.state": deque(maxlen=1),
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
}
if self._use_image:
if self.config.image_features:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=1)
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
# CEM for the next step.
@@ -134,9 +114,9 @@ class TDMPCPolicy(
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
self._queues = populate_queues(self._queues, batch)
@@ -151,9 +131,9 @@ class TDMPCPolicy(
# NOTE: Order of observations matters here.
encode_keys = []
if self._use_image:
if self.config.image_features:
encode_keys.append("observation.image")
if self._use_env_state:
if self.config.env_state_feature:
encode_keys.append("observation.environment_state")
encode_keys.append("observation.state")
z = self.model.encode({k: batch[k] for k in encode_keys})
@@ -196,7 +176,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_pi_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=device,
)
if self.config.n_pi_samples > 0:
@@ -215,7 +195,7 @@ class TDMPCPolicy(
# algorithm.
# The initial mean and standard deviation for the cross-entropy method (CEM).
mean = torch.zeros(
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
)
# Maybe warm start CEM with the mean from the previous step.
if self._prev_mean is not None:
@@ -228,7 +208,7 @@ class TDMPCPolicy(
self.config.horizon,
self.config.n_gaussian_samples,
batch_size,
self.config.output_shapes["action"][0],
self.config.action_feature.shape[0],
device=std.device,
)
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
@@ -330,9 +310,9 @@ class TDMPCPolicy(
device = get_device_from_parameters(self)
batch = self.normalize_inputs(batch)
if self._use_image:
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.image"] = batch[self.input_image_key]
batch["observation.image"] = batch[next(iter(self.config.image_features))]
batch = self.normalize_targets(batch)
info = {}
@@ -347,7 +327,7 @@ class TDMPCPolicy(
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
# Apply random image augmentations.
if self._use_image and self.config.max_random_shift_ratio > 0:
if self.config.image_features and self.config.max_random_shift_ratio > 0:
observations["observation.image"] = flatten_forward_unflatten(
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
observations["observation.image"],
@@ -360,7 +340,7 @@ class TDMPCPolicy(
current_observation[k] = observations[k][0]
next_observations[k] = observations[k][1:]
horizon, batch_size = next_observations[
"observation.image" if self._use_image else "observation.environment_state"
"observation.image" if self.config.image_features else "observation.environment_state"
].shape[:2]
# Run latent rollout using the latent dynamics model and policy model.
@@ -543,7 +523,7 @@ class TDMPCTOLD(nn.Module):
self.config = config
self._encoder = TDMPCObservationEncoder(config)
self._dynamics = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -554,7 +534,7 @@ class TDMPCTOLD(nn.Module):
nn.Sigmoid(),
)
self._reward = nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -569,12 +549,12 @@ class TDMPCTOLD(nn.Module):
nn.Linear(config.mlp_dim, config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Mish(),
nn.Linear(config.mlp_dim, config.output_shapes["action"][0]),
nn.Linear(config.mlp_dim, config.action_feature.shape[0]),
)
self._Qs = nn.ModuleList(
[
nn.Sequential(
nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
nn.LayerNorm(config.mlp_dim),
nn.Tanh(),
nn.Linear(config.mlp_dim, config.mlp_dim),
@@ -714,10 +694,13 @@ class TDMPCObservationEncoder(nn.Module):
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
if config.image_features:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
next(iter(config.image_features.values())).shape[0],
config.image_encoder_hidden_dim,
7,
stride=2,
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
@@ -727,9 +710,8 @@ class TDMPCObservationEncoder(nn.Module):
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
@@ -738,19 +720,19 @@ class TDMPCObservationEncoder(nn.Module):
nn.Sigmoid(),
)
)
if "observation.state" in config.input_shapes:
if config.robot_state_feature:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Sigmoid(),
)
if "observation.environment_state" in config.input_shapes:
if config.env_state_feature:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
@@ -765,12 +747,16 @@ class TDMPCObservationEncoder(nn.Module):
"""
feat = []
# NOTE: Order of observations matters here.
if "observation.image" in self.config.input_shapes:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
if self.config.image_features:
feat.append(
flatten_forward_unflatten(
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
)
)
if self.config.env_state_feature:
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
if self.config.robot_state_feature:
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
return torch.stack(feat, dim=0).mean(0)

View File

@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
@@ -47,3 +48,20 @@ def get_dtype_from_parameters(module: nn.Module) -> torch.dtype:
Note: assumes that all parameters have the same dtype.
"""
return next(iter(module.parameters())).dtype
def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
    """
    Calculates the output shape of a PyTorch module given an input shape.

    A zero-filled dummy tensor of shape `input_shape` is passed through `module` under
    `torch.inference_mode()` and the shape of the result is returned.

    The dummy tensor is created on the device of the module's first parameter, and with that
    parameter's dtype when it is a floating point dtype, so that modules which have been cast
    (e.g. `.double()`) or moved to another device can still be probed. Modules without parameters
    get a default (CPU, default dtype) tensor.

    Note: the module is called in whatever train/eval mode it currently is in.

    Args:
        module (nn.Module): a PyTorch module
        input_shape (tuple): A tuple representing the input shape, e.g., (batch_size, channels, height, width)

    Returns:
        tuple: The output shape of the module.
    """
    factory_kwargs = {}
    first_param = next(iter(module.parameters()), None)
    if first_param is not None:
        factory_kwargs["device"] = first_param.device
        # Only mirror floating point dtypes; an integer-typed first parameter says nothing
        # about the dtype the module expects as input.
        if first_param.is_floating_point():
            factory_kwargs["dtype"] = first_param.dtype

    dummy_input = torch.zeros(size=input_shape, **factory_kwargs)
    with torch.inference_mode():
        output = module(dummy_input)
    return tuple(output.shape)

View File

@@ -18,9 +18,15 @@
from dataclasses import dataclass, field
from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
@PreTrainedConfig.register_subclass("vqbet")
@dataclass
class VQBeTConfig:
class VQBeTConfig(PreTrainedConfig):
"""Configuration class for VQ-BeT.
Defaults are configured for training with PushT providing proprioceptive and single camera observations.
@@ -90,26 +96,13 @@ class VQBeTConfig:
n_action_pred_token: int = 3
action_chunk_size: int = 5
input_shapes: dict[str, list[int]] = field(
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"observation.image": [3, 96, 96],
"observation.state": [2],
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [2],
}
)
# Normalization / Unnormalization
input_normalization_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "min_max",
}
)
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
# Architecture / modeling.
# Vision backbone.
@@ -139,29 +132,69 @@ class VQBeTConfig:
bet_softmax_temperature: float = 0.1
sequentially_select: bool = False
# Training presets
optimizer_lr: float = 1e-4
optimizer_betas: tuple = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-6
optimizer_vqvae_lr: float = 1e-3
optimizer_vqvae_weight_decay: float = 1e-4
scheduler_warmup_steps: int = 500
def __post_init__(self):
    """Input validation (not exhaustive).

    Runs the parent class's `__post_init__` first, then checks the vision backbone name.

    Raises:
        ValueError: if `vision_backbone` does not start with "resnet".
    """
    super().__post_init__()
    if not self.vision_backbone.startswith("resnet"):
        raise ValueError(
            f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
        )
def get_optimizer_preset(self) -> AdamConfig:
    """Return an `AdamConfig` populated from this config's `optimizer_*` training presets."""
    preset_kwargs = {
        "lr": self.optimizer_lr,
        "betas": self.optimizer_betas,
        "eps": self.optimizer_eps,
        "weight_decay": self.optimizer_weight_decay,
    }
    return AdamConfig(**preset_kwargs)
def get_scheduler_preset(self) -> VQBeTSchedulerConfig:
    """Return a `VQBeTSchedulerConfig` built from the warmup and VQ-VAE training step counts."""
    warmup_steps = self.scheduler_warmup_steps
    vqvae_steps = self.n_vqvae_training_steps
    return VQBeTSchedulerConfig(num_warmup_steps=warmup_steps, num_vqvae_training_steps=vqvae_steps)
def validate_features(self) -> None:
    """Validate the image features declared on this config.

    Checks, in order, that:
      1. exactly one image feature is provided,
      2. `crop_shape` (when set) fits within the height/width of every image feature,
      3. all image features share the same shape.

    Raises:
        ValueError: if any of the checks above fails.
    """
    # Note: this check was previously performed inside VQBeTRgbEncoder in the form of
    # assert len(image_keys) == 1
    if len(self.image_features) != 1:
        raise ValueError("You must provide only one image among the inputs.")

    if self.crop_shape is not None:
        # Image shapes are indexed as (channels, height, width) here: crop_shape[0] is compared
        # against shape[1] and crop_shape[1] against shape[2].
        for key, image_ft in self.image_features.items():
            if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
                raise ValueError(
                    f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
                    f"for `crop_shape` and {image_ft.shape} for "
                    f"`{key}`."
                )

    # Check that all input images have the same shape.
    first_image_key, first_image_ft = next(iter(self.image_features.items()))
    for key, image_ft in self.image_features.items():
        if image_ft.shape != first_image_ft.shape:
            raise ValueError(
                f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
            )
@property
def observation_delta_indices(self) -> list:
    """Consecutive integer offsets from `1 - n_obs_steps` up to and including 0."""
    oldest_offset = 1 - self.n_obs_steps
    return [*range(oldest_offset, 1)]
@property
def action_delta_indices(self) -> list:
    """Consecutive integer offsets from `1 - n_obs_steps` up to (excluding)
    `n_action_pred_token + action_chunk_size - 1`.
    """
    first_offset = 1 - self.n_obs_steps
    stop_offset = self.n_action_pred_token + self.action_chunk_size - 1
    return [*range(first_offset, stop_offset)]
@property
def reward_delta_indices(self) -> None:
    """Always `None`: this config defines no reward delta indices."""
    return None

View File

@@ -16,7 +16,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import warnings
from collections import deque
from typing import Callable, List
@@ -26,29 +25,23 @@ import numpy as np
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from torch.optim.lr_scheduler import LambdaLR
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
# ruff: noqa: N806
class VQBeTPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vqbet"],
):
class VQBeTPolicy(PreTrainedPolicy):
"""
VQ-BeT Policy as per "Behavior Generation with Latent Actions"
"""
config_class = VQBeTConfig
name = "vqbet"
def __init__(
@@ -63,26 +56,62 @@ class VQBeTPolicy(
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
config = VQBeTConfig()
super().__init__(config)
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
config.output_features, config.normalization_mapping, dataset_stats
)
self.vqbet = VQBeTModel(config)
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
self.reset()
def get_optim_params(self) -> list[dict]:
    """Build the optimizer parameter groups for this policy.

    Returns:
        A list of three parameter-group dicts, in this order:
          1. parameters using the optimizer's default settings (the "decay" parameters returned
             by the GPT's `configure_parameters()`, plus the RGB encoder, the state and RGB
             feature projectors, the action query token, the offset head and the bin
             prediction head(s)),
          2. the VQ-VAE encoder / decoder / VQ-layer parameters, with their own `lr` and
             `weight_decay` taken from `config.optimizer_vqvae_lr` and
             `config.optimizer_vqvae_weight_decay`,
          3. the "no decay" parameters returned by `configure_parameters()`, with
             `weight_decay` forced to 0.0.
    """
    model = self.vqbet
    head = model.action_head
    vqvae = head.vqvae_model

    vqvae_params = [
        *vqvae.encoder.parameters(),
        *vqvae.decoder.parameters(),
        *vqvae.vq_layer.parameters(),
    ]

    decay_params, no_decay_params = model.policy.configure_parameters()
    decay_params = [
        *decay_params,
        *model.rgb_encoder.parameters(),
        *model.state_projector.parameters(),
        *model.rgb_feature_projector.parameters(),
        model.action_token,
        *head.map_to_cbet_preds_offset.parameters(),
    ]

    # Which bin prediction heads are read depends on the code selection strategy.
    if self.config.sequentially_select:
        decay_params += [
            *head.map_to_cbet_preds_primary_bin.parameters(),
            *head.map_to_cbet_preds_secondary_bin.parameters(),
        ]
    else:
        decay_params += list(head.map_to_cbet_preds_bin.parameters())

    return [
        {
            "params": decay_params,
        },
        {
            "params": vqvae_params,
            "weight_decay": self.config.optimizer_vqvae_weight_decay,
            "lr": self.config.optimizer_vqvae_lr,
        },
        {
            "params": no_decay_params,
            "weight_decay": 0.0,
        },
    ]
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
@@ -105,7 +134,7 @@ class VQBeTPolicy(
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
@@ -131,7 +160,7 @@ class VQBeTPolicy(
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
batch = self.normalize_targets(batch)
# VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181)
if not self.vqbet.action_head.vqvae_model.discretized.item():
@@ -288,14 +317,14 @@ class VQBeTModel(nn.Module):
self.config = config
self.rgb_encoder = VQBeTRgbEncoder(config)
self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")])
self.num_images = len(self.config.image_features)
# This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above.
# Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results.
self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim))
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
self.state_projector = MLP(
config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim]
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
)
self.rgb_feature_projector = MLP(
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
@@ -350,10 +379,10 @@ class VQBeTModel(nn.Module):
# get action features (pass through GPT)
features = self.policy(input_tokens)
# len(self.config.input_shapes) is the number of different observation modes.
# len(self.config.input_features) is the number of different observation modes.
# this line gets the index of action prompt tokens.
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len(
self.config.input_shapes
historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_features) + 1) + len(
self.config.input_features
)
# only extract the output tokens at the position of action query:
@@ -392,7 +421,7 @@ class VQBeTHead(nn.Module):
self.map_to_cbet_preds_offset: output the predicted offsets for all the codes in all the layers.
The input dimension of ` self.map_to_cbet_preds_offset` is same with the output of GPT,
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.output_shapes["action"][0]`.
and the output dimension of ` self.map_to_cbet_preds_offset` is `self.vqvae_model.vqvae_num_layers (=fixed as 2) * self.config.vqvae_n_embed * config.action_chunk_size * config.action_feature.shape[0]`.
"""
super().__init__()
@@ -419,7 +448,7 @@ class VQBeTHead(nn.Module):
self.vqvae_model.vqvae_num_layers
* self.config.vqvae_n_embed
* config.action_chunk_size
* config.output_shapes["action"][0],
* config.action_feature.shape[0],
],
)
# loss
@@ -623,84 +652,6 @@ class VQBeTHead(nn.Module):
return loss_dict
class VQBeTOptimizer(torch.optim.Adam):
    """Adam optimizer with three VQ-BeT specific parameter groups.

    The groups are, in order: parameters trained with `cfg.training.adam_weight_decay`, the
    VQ-VAE parameters (own learning rate `cfg.training.vqvae_lr`, weight decay 0.0001), and the
    parameters that are excluded from weight decay.
    """

    def __init__(self, policy, cfg):
        model = policy.vqbet
        head = model.action_head
        vqvae = head.vqvae_model

        vqvae_params = [
            *vqvae.encoder.parameters(),
            *vqvae.decoder.parameters(),
            *vqvae.vq_layer.parameters(),
        ]

        decay_params, no_decay_params = model.policy.configure_parameters()
        decay_params = decay_params + [
            *model.rgb_encoder.parameters(),
            *model.state_projector.parameters(),
            *model.rgb_feature_projector.parameters(),
            model.action_token,
            *head.map_to_cbet_preds_offset.parameters(),
        ]

        # The bin prediction head(s) in use depend on the code selection strategy.
        if cfg.policy.sequentially_select:
            bin_params = [
                *head.map_to_cbet_preds_primary_bin.parameters(),
                *head.map_to_cbet_preds_secondary_bin.parameters(),
            ]
        else:
            bin_params = list(head.map_to_cbet_preds_bin.parameters())
        decay_params = decay_params + bin_params

        decay_group = {
            "params": decay_params,
            "weight_decay": cfg.training.adam_weight_decay,
            "lr": cfg.training.lr,
        }
        vqvae_group = {
            "params": vqvae_params,
            "weight_decay": 0.0001,
            "lr": cfg.training.vqvae_lr,
        }
        no_decay_group = {
            "params": no_decay_params,
            "weight_decay": 0.0,
            "lr": cfg.training.lr,
        }

        super().__init__(
            [decay_group, vqvae_group, no_decay_group],
            lr=cfg.training.lr,
            betas=cfg.training.adam_betas,
            eps=cfg.training.adam_eps,
        )
class VQBeTScheduler(nn.Module):
    """Learning rate schedule for VQ-BeT, wrapping a `LambdaLR`.

    The multiplicative factor applied to the optimizer's learning rate is:
      - 1 for the first `cfg.training.n_vqvae_training_steps` steps,
      - then a linear ramp from 0 to 1 over `cfg.training.lr_warmup_steps` steps,
      - then a half-cosine decay whose progress is measured against
        `cfg.training.offline_steps - cfg.training.lr_warmup_steps`, clamped at 0.
    """

    def __init__(self, optimizer, cfg):
        super().__init__()
        training_cfg = cfg.training
        vqvae_steps = training_cfg.n_vqvae_training_steps
        warmup_steps = training_cfg.lr_warmup_steps
        total_steps = training_cfg.offline_steps
        cycles = 0.5

        def lr_lambda(current_step):
            # Constant factor while the VQ-VAE is being trained.
            if current_step < vqvae_steps:
                return float(1)

            # From here on, count steps relative to the end of the VQ-VAE phase.
            step = current_step - vqvae_steps
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))

            progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(cycles) * 2.0 * progress)))

        self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1)

    def step(self):
        """Advance the wrapped `LambdaLR` by one step."""
        self.lr_scheduler.step()
class VQBeTRgbEncoder(nn.Module):
"""Encode an RGB image into a 1D feature vector.
@@ -743,19 +694,15 @@ class VQBeTRgbEncoder(nn.Module):
# Set up pooling and final layers.
# Use a dry run to get the feature map shape.
# The dummy input should take the number of image channels from `config.input_shapes` and it should
# The dummy input should take the number of image channels from `config.image_features` and it should
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
# height and width from `config.input_shapes`.
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
assert len(image_keys) == 1
image_key = image_keys[0]
dummy_input_h_w = (
config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:]
)
dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w))
with torch.inference_mode():
dummy_feature_map = self.backbone(dummy_input)
feature_map_shape = tuple(dummy_feature_map.shape[1:])
# height and width from `config.image_features`.
images_shape = next(iter(config.image_features.values())).shape
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
self.feature_dim = config.spatial_softmax_num_keypoints * 2
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
@@ -844,7 +791,7 @@ class VqVae(nn.Module):
)
self.encoder = MLP(
in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size,
in_channels=self.config.action_feature.shape[0] * self.config.action_chunk_size,
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
@@ -856,7 +803,7 @@ class VqVae(nn.Module):
hidden_channels=[
config.vqvae_enc_hidden_dim,
config.vqvae_enc_hidden_dim,
self.config.output_shapes["action"][0] * self.config.action_chunk_size,
self.config.action_feature.shape[0] * self.config.action_chunk_size,
],
)
@@ -872,9 +819,9 @@ class VqVae(nn.Module):
# given latent vector, this function outputs the decoded action.
output = self.decoder(latent)
if self.config.action_chunk_size == 1:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
else:
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0])
return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0])
def get_code(self, state):
# in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181)