diff --git a/lerobot/__init__.py b/lerobot/__init__.py index f8acafce3..fe7e037d9 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -168,12 +168,7 @@ available_datasets = sorted( ) # lists all available policies from `lerobot/common/policies` -available_policies = [ - "act", - "diffusion", - "tdmpc", - "vqbet", -] +available_policies = ["act", "diffusion", "tdmpc", "vqbet"] # lists all available robots from `lerobot/common/robot_devices/robots` available_robots = [ diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index b73ba5f4e..9cb0f6234 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -15,5 +15,6 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a35..3aade0665 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -27,6 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.configs.policies import PreTrainedConfig @@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "smolvla": + from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + return SmolVLAPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "smolvla": + return SmolVLAConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/smolvla/configuration_smolvla.py b/lerobot/common/policies/smolvla/configuration_smolvla.py new file mode 100644 index 000000000..5996cf2e7 --- /dev/null +++ b/lerobot/common/policies/smolvla/configuration_smolvla.py @@ -0,0 +1,154 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@PreTrainedConfig.register_subclass("smolvla") +@dataclass +class SmolVLAConfig(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (512, 512) + + # Add empty images. Used by smolvla_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. + use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + + # Decoding + num_steps: int = 10 + + # Attention utils + use_cache: bool = True + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = True + train_state_proj: bool = True + + # Training presets + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + optimizer_grad_clip_norm: float = 10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. + load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights + + add_image_special_tokens: bool = False # Whether to use special image tokens around image features. + + attention_mode: str = "cross_attn" + + prefix_length: int = -1 + + pad_language_to: str = "longest" # "max_length" + + num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers. + num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers) + self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers + expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM) + + min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding + max_period: float = 4.0 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.use_delta_joint_actions_aloha: + raise NotImplementedError( + "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot." + ) + + def validate_features(self) -> None: + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list: + return [0] + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py new file mode 100644 index 000000000..efdf602d4 --- /dev/null +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SmolVLA: + +[Paper](https://huggingface.co/papers/2506.01844) + +Designed by Hugging Face. + +Install smolvla extra dependencies: +```bash +pip install -e ".[smolvla]" +``` + +Example of finetuning the smolvla pretrained model (`smolvla_base`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/smolvla_base \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, +and an action expert. +```bash +python lerobot/scripts/train.py \ +--policy.type=smolvla \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of using the smolvla pretrained model outside LeRobot training framework: +```python +policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") +``` + +""" + +import math +from collections import deque + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from transformers import AutoProcessor + +from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.policies.normalize import ( + Normalize, + Unnormalize, +) +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel +from lerobot.common.policies.utils import ( + populate_queues, +) +from lerobot.common.utils.utils import get_safe_dtype + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb + + +def sample_beta(alpha, beta, bsize, device): + gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) + gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) + return gamma1 / (gamma1 + gamma2) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + return att_2d_masks + + +def resize_with_pad(img, width, height, pad_value=-1): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_vector(vector, new_dim): + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with smolvla which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class SmolVLAPolicy(PreTrainedPolicy): + """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" + + config_class = SmolVLAConfig + name = "smolvla" + + def __init__( + self, + config: SmolVLAConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer + self.model = VLAFlowMatching(config) + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def get_optim_params(self) -> dict: + return self.parameters() + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + + batch = self.normalize_inputs(batch) + + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._queues[ACTION]) == 0: + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions( + images, img_masks, lang_tokens, lang_masks, state, noise=noise + ) + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({"action": actions})["action"] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + return self._queues[ACTION].popleft() + + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: + """Do a full training forward pass to compute the loss""" + if self.config.adapt_to_pi_aloha: + batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + actions = self.prepare_action(batch) + actions_is_pad = batch.get("actions_id_pad") + loss_dict = {} + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + loss_dict["losses_after_forward"] = losses.clone() + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + losses = losses * in_episode_bound.unsqueeze(-1) + loss_dict["losses_after_in_ep_bound"] = losses.clone() + + # Remove padding + losses = losses[:, :, : self.config.max_action_dim] + loss_dict["losses_after_rm_padding"] = losses.clone() + + # For backward pass + loss = losses.mean() + # For backward pass + loss_dict["loss"] = loss + return loss, loss_dict + + def prepare_images(self, batch): + """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and + convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + """ + images = [] + img_masks = [] + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key] + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expacted by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + if f"{key}_padding_mask" in batch: + mask = batch[f"{key}_padding_mask"].bool() + else: + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + return images, img_masks + + def prepare_language(self, batch) -> tuple[Tensor, Tensor]: + """Tokenize the text input""" + device = batch[OBS_ROBOT].device + tasks = batch["task"] + if len(tasks) == 1: + tasks = [tasks[0] for _ in range(batch[OBS_ROBOT].shape[0])] + + tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + tokenized_prompt = self.language_tokenizer.__call__( + tasks, + padding=self.config.pad_language_to, + padding_side="right", + max_length=self.config.tokenizer_max_length, + return_tensors="pt", + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + def prepare_state(self, batch): + """Pad state""" + state = batch[OBS_ROBOT][:, -1, :] if batch[OBS_ROBOT].ndim > 2 else batch[OBS_ROBOT] + state = pad_vector(state, self.config.max_state_dim) + return state + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + +def pad_tensor(tensor, max_len, pad_value=0): + """ + Efficiently pads a tensor along sequence dimension to match max_len. + + Args: + tensor (torch.Tensor): Shape (B, L, ...) or (B, L). + max_len (int): Fixed sequence length. + pad_value (int/float): Value for padding. + + Returns: + torch.Tensor: Shape (B, max_len, ...) or (B, max_len). + """ + b, d = tensor.shape[:2] + + # Create a padded tensor of max_len and copy the existing values + padded_tensor = torch.full( + (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, :d] = tensor # Efficient in-place copy + + return padded_tensor + + +class VLAFlowMatching(nn.Module): + """ + SmolVLA + + [Paper]() + + Designed by Hugging Face. + ┌──────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌─────────┐ ┌─|────┐ │ + │ | │────► │ │ │ + │ | │ kv │ │ │ + │ | │────► │Action│ │ + │ | VLM │cache │Expert│ | + │ │ │────► | │ │ + │ │ │ │ │ │ + │ └▲──▲───▲─┘ └───▲──┘ | + │ │ | | │ | + │ | | | noise │ + │ │ │ state │ + │ │ language tokens │ + │ image(s) │ + └──────────────────────────────┘ + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.vlm_with_expert = SmolVLMWithExpertModel( + model_id=self.config.vlm_model_name, + freeze_vision_encoder=self.config.freeze_vision_encoder, + train_expert_only=self.config.train_expert_only, + load_vlm_weights=self.config.load_vlm_weights, + attention_mode=self.config.attention_mode, + num_expert_layers=self.config.num_expert_layers, + num_vlm_layers=self.config.num_vlm_layers, + self_attn_every_n_layers=self.config.self_attn_every_n_layers, + expert_width_multiplier=self.config.expert_width_multiplier, + ) + self.state_proj = nn.Linear( + self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size + ) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) + self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) + + self.action_time_mlp_in = nn.Linear( + self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size + ) + self.action_time_mlp_out = nn.Linear( + self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size + ) + + self.set_requires_grad() + self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id + self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id + self.global_image_start_token = torch.tensor( + [self.fake_image_token, self.global_image_token], dtype=torch.long + ) + + self.add_image_special_tokens = self.config.add_image_special_tokens + self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) + self.prefix_length = self.config.prefix_length + + def set_requires_grad(self): + for params in self.state_proj.parameters(): + params.requires_grad = self.config.train_state_proj + + def sample_noise(self, shape, device): + noise = torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + return noise + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for SmolVLM transformer processing. + """ + embs = [] + pad_masks = [] + att_masks = [] + for _img_idx, ( + img, + img_mask, + ) in enumerate(zip(images, img_masks, strict=False)): + if self.add_image_special_tokens: + image_start_token = ( + self.vlm_with_expert.embed_language_tokens( + self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_start_mask = torch.ones_like( + image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device + ) + att_masks += [0] * (image_start_mask.shape[-1]) + embs.append(image_start_token) + pad_masks.append(image_start_mask) + + img_emb = self.vlm_with_expert.embed_image(img) + img_emb = img_emb + + # Normalize image embeddings + img_emb_dim = img_emb.shape[-1] + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + att_masks += [0] * (num_img_embs) + if self.add_image_special_tokens: + image_end_token = ( + self.vlm_with_expert.embed_language_tokens( + self.image_end_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_end_mask = torch.ones_like( + image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device + ) + embs.append(image_end_token) + pad_masks.append(image_end_mask) + att_masks += [0] * (image_end_mask.shape[1]) + lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + state_emb = self.state_proj(state) + state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + embs.append(state_emb) + bsize = state_emb.shape[0] + device = state_emb.device + + states_seq_len = state_emb.shape[1] + state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] * (states_seq_len) + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :] + + seq_len = pad_masks.shape[1] + if seq_len < self.prefix_length: + embs = pad_tensor(embs, self.prefix_length, pad_value=0) + pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) + att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) + + att_masks = att_masks.expand(bsize, -1) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Fuse timestep + action information using an MLP + action_emb = self.action_in_proj(noisy_actions) + device = action_emb.device + bsize = action_emb.shape[0] + dtype = action_emb.dtype + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.vlm_with_expert.expert_hidden_size, + self.config.min_period, + self.config.max_period, + device=device, + ) + time_emb = time_emb.type(dtype=dtype) + + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] * self.config.chunk_size + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + return embs, pad_masks, att_masks + + def forward( + self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + ) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + (_, suffix_out), _ = self.vlm_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + suffix_out = suffix_out[:, -self.config.chunk_size :] + # Original openpi code, upcast attention output + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + losses = F.mse_loss(u_t, v_t, reduction="none") + return losses + + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + # Compute image and language key value cache + _, past_key_values = self.vlm_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=self.config.use_cache, + fill_kv_cache=True, + ) + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + # Euler step + x_t += dt * v_t + time += dt + return x_t + + def denoise_step( + self, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.vlm_with_expert.forward( + attention_mask=full_att_2d_masks, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=self.config.use_cache, + fill_kv_cache=False, + ) + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + return v_t diff --git a/lerobot/common/policies/smolvla/smolvlm_with_expert.py b/lerobot/common/policies/smolvla/smolvlm_with_expert.py new file mode 100644 index 000000000..07eae8089 --- /dev/null +++ b/lerobot/common/policies/smolvla/smolvlm_with_expert.py @@ -0,0 +1,550 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Optional + +import torch +from torch import nn +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForImageTextToText, + AutoProcessor, + SmolVLMForConditionalGeneration, +) + + +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + +def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim + + +class SmolVLMWithExpertModel(nn.Module): + def __init__( + self, + model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", + load_vlm_weights: bool = True, + train_expert_only: bool = True, + freeze_vision_encoder: bool = False, + attention_mode: str = "self_attn", + num_expert_layers: int = -1, + num_vlm_layers: int = -1, + self_attn_every_n_layers: int = -1, + expert_width_multiplier: float = 0.5, + ): + super().__init__() + if load_vlm_weights: + print(f"Loading {model_id} weights ...") + self.vlm = AutoModelForImageTextToText.from_pretrained( + model_id, + device_map="auto", + torch_dtype="bfloat16", + low_cpu_mem_usage=True, + ) + config = self.vlm.config + else: + config = AutoConfig.from_pretrained(model_id) + self.vlm = SmolVLMForConditionalGeneration(config=config) + self.processor = AutoProcessor.from_pretrained(model_id) + if num_vlm_layers > 0: + print(f"Reducing the number of VLM layers to {num_vlm_layers} ...") + self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers] + self.num_vlm_layers = len(self.get_vlm_model().text_model.layers) + self.config = config + # Smaller lm expert + lm_expert_config = copy.deepcopy(config.text_config) + hidden_size = lm_expert_config.hidden_size + lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 + lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) + lm_expert_config.num_hidden_layers = self.num_vlm_layers + if num_expert_layers > 0: + assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( + f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" + ) + lm_expert_config.num_hidden_layers = num_expert_layers + self.lm_expert = AutoModel.from_config(lm_expert_config) + + self.num_expert_layers = len(self.lm_expert.layers) + self.self_attn_every_n_layers = self_attn_every_n_layers + if "cross" in attention_mode: + # Reshape qkv projections to have the same input dimension as the vlm + for layer_idx in range(len(self.lm_expert.layers)): + if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: + continue + self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + # Remove unused embed_tokens + self.lm_expert.embed_tokens = None + + self.num_attention_heads = self.config.text_config.num_attention_heads + self.num_key_value_heads = self.config.text_config.num_key_value_heads + + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + self.attention_mode = attention_mode + self.expert_hidden_size = lm_expert_config.hidden_size + self.set_requires_grad() + + def get_vlm_model(self): + return self.vlm.model + + def set_requires_grad(self): + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + for params in self.get_vlm_model().vision_model.parameters(): + params.requires_grad = False + if self.train_expert_only: + self.vlm.eval() + for params in self.vlm.parameters(): + params.requires_grad = False + else: + # To avoid unused params issue with distributed training + last_layers = [self.num_vlm_layers - 1] + if ( + self.num_vlm_layers != self.num_expert_layers + and self.num_vlm_layers % self.num_expert_layers == 0 + ): + last_layers.append(self.num_vlm_layers - 2) + frozen_layers = [ + "lm_head", + "text_model.model.norm.weight", + ] + for layer in last_layers: + frozen_layers.append(f"text_model.model.layers.{layer}.") + + for name, params in self.vlm.named_parameters(): + if any(k in name for k in frozen_layers): + params.requires_grad = False + # To avoid unused params issue with distributed training + for name, params in self.lm_expert.named_parameters(): + if "lm_head" in name: + params.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + + if self.train_expert_only: + self.vlm.eval() + + def embed_image(self, image: torch.Tensor): + patch_attention_mask = None + # Get sequence from the vision encoder + image_hidden_states = ( + self.get_vlm_model() + .vision_model( + pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), + patch_attention_mask=patch_attention_mask, + ) + .last_hidden_state + ) + # Modality projection & resampling + image_hidden_states = self.get_vlm_model().connector(image_hidden_states) + return image_hidden_states + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.get_vlm_model().text_model.get_input_embeddings()(tokens) + + def forward_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + query_states = [] + key_states = [] + value_states = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + if hidden_states is None or layer is None: + continue + hidden_states = layer.input_layernorm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + seq_len = query_states.shape[1] + if seq_len < position_ids.shape[1]: + _position_ids = position_ids[:, :seq_len] + _attention_mask = attention_mask[:, :seq_len, :seq_len] + else: + _position_ids = position_ids + _attention_mask = attention_mask + + attention_mask_ = _attention_mask + position_ids_ = _position_ids + + query_states = apply_rope(query_states, position_ids_) + key_states = apply_rope(key_states, position_ids_) + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) + + attention_interface = self.get_attention_interface() + + att_output = attention_interface( + attention_mask_, batch_size, head_dim, query_states, key_states, value_states + ) + return [att_output], past_key_values + + def forward_cross_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + attention_interface = self.get_attention_interface() + + att_outputs = [] + assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( + f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" + ) + + if len(inputs_embeds) == 2 and not past_key_values: + # Prefix attention + seq_len = inputs_embeds[0].shape[1] + position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] + prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] + + layer = model_layers[0][layer_idx] + + hidden_states = layer.input_layernorm(inputs_embeds[0]) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + # B,L,H,D with L sequence length, H number of heads, D head dim + query_states = apply_rope(query_state, position_id) + key_states = apply_rope(key_state, position_id) + + att_output = attention_interface( + prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states + ) + att_outputs.append(att_output) + else: + expert_position_id = position_ids + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = past_key_values[layer_idx]["key_states"] + value_states = past_key_values[layer_idx]["value_states"] + + # Expert + expert_layer = model_layers[1][layer_idx] + if expert_layer is not None: + expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1]) + + expert_input_shape = expert_hidden_states.shape[:-1] + expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim) + + expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) + expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) + + _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( + *key_states.shape[:2], -1 + ) + expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( + *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) # k_proj should have same dim as kv + + _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( + *value_states.shape[:2], -1 + ) + expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( + *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) + + expert_position_id = ( + expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values + ) # start from 0 + expert_attention_mask = attention_mask[ + :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : + ] # take into account kv + + expert_query_states = apply_rope(expert_query_state, expert_position_id) + + att_output = attention_interface( + expert_attention_mask, + batch_size, + head_dim, + expert_query_states, + expert_key_states, + expert_value_states, + ) + att_outputs.append(att_output) + else: + att_outputs.append(None) + + # att_output = att_output.to(dtype=models[i].dtype) + return att_outputs, past_key_values + + def get_model_layers(self, models: list) -> list: + vlm_layers = [] + expert_layers = [] + multiple_of = self.num_vlm_layers // self.num_expert_layers + for i in range(self.num_vlm_layers): + if multiple_of > 0 and i > 0 and i % multiple_of != 0: + expert_layer = None + else: + expert_layer_index = i // multiple_of if multiple_of > 0 else i + expert_layer = models[1].layers[expert_layer_index] + vlm_layers.append(models[0].layers[i]) + expert_layers.append(expert_layer) + return [vlm_layers, expert_layers] + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: List[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + fill_kv_cache: Optional[bool] = None, + ): + models = [self.get_vlm_model().text_model, self.lm_expert] + model_layers = self.get_model_layers(models) + for hidden_states in inputs_embeds: + # TODO this is very inefficient + # dtype is always the same, batch size too (if > 1 len) + # device could be trickier in multi gpu edge cases but that's it + if hidden_states is None: + continue + batch_size = hidden_states.shape[0] + + # RMSNorm + num_layers = self.num_vlm_layers + head_dim = self.vlm.config.text_config.head_dim + for layer_idx in range(num_layers): + if ( + fill_kv_cache + or "cross" not in self.attention_mode + or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) + ): + att_outputs, past_key_values = self.forward_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + else: + att_outputs, past_key_values = self.forward_cross_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + att_output = ( + att_outputs[i] if i < len(att_outputs) else att_outputs[0] + ) # in case of self_attn + if hidden_states is not None: + if layer is None: + outputs_embeds.append(hidden_states) + continue + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + att_out = att_output[:, start:end] + out_emb = layer.self_attn.o_proj(att_out) + + out_emb += hidden_states + after_first_residual = out_emb.clone() + + out_emb = layer.post_attention_layernorm(out_emb) + out_emb = layer.mlp(out_emb) + + out_emb += after_first_residual + + outputs_embeds.append(out_emb) + + start = end if len(att_outputs) == 1 else 0 + else: + outputs_embeds.append(None) + + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is not None: + out_emb = models[i].norm(hidden_states) + outputs_embeds.append(out_emb) + else: + outputs_embeds.append(None) + return outputs_embeds, past_key_values + + def get_attention_interface(self): + attention_interface = self.eager_attention_forward + return attention_interface + + def eager_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + num_att_heads = self.num_attention_heads + num_key_value_heads = self.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + # Attention here is upcasted to float32 to match the original eager implementation. + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= head_dim**-0.5 + + att_weights = att_weights.to(dtype=torch.float32) + big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + + att_output = att_output.permute(0, 2, 1, 3) + # we use -1 because sequence length can change + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + + return att_output diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 68f69487b..de10395c2 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -109,6 +109,10 @@ def predict_action(observation, policy, device, use_amp): ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: + # Skip all observations that are not tensors (e.g. text) + if not isinstance(observation[name], torch.Tensor): + continue + if "image" in name: observation[name] = observation[name].type(torch.float32) / 255 observation[name] = observation[name].permute(2, 0, 1).contiguous() @@ -256,7 +260,8 @@ def control_loop( else: observation = robot.capture_observation() action = None - + observation["task"] = [single_task] + observation["robot_type"] = [policy.robot_type] if hasattr(policy, "robot_type") else [""] if policy is not None: pred_action = predict_action( observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp @@ -267,6 +272,7 @@ def control_loop( action = {"action": action} if dataset is not None: + observation = {k: v for k, v in observation.items() if k not in ["task", "robot_type"]} frame = {**observation, **action, "task": single_task} dataset.add_frame(frame) diff --git a/pyproject.toml b/pyproject.toml index 64a75c0c1..568e58b8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] pi0 = ["transformers>=4.48.0"] +smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", diff --git a/tests/test_available.py b/tests/test_available.py index f4f9d4de6..a18b95ffa 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -45,12 +45,7 @@ def test_available_policies(): This test verifies that the class attribute `name` for all policies is consistent with those listed in `lerobot/__init__.py`. """ - policy_classes = [ - ACTPolicy, - DiffusionPolicy, - TDMPCPolicy, - VQBeTPolicy, - ] + policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy] policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies