From 812feac7d7bcb3a6d8123fa69a9ac2e92d82bd04 Mon Sep 17 00:00:00 2001 From: Dana Date: Tue, 3 Jun 2025 14:29:56 +0200 Subject: [PATCH] smolvla: make timestep-embedding periods configurable, add paper link, and register SmolVLAPolicy in the availability test --- lerobot/common/policies/smolvla/configuration_smolvla.py | 3 +++ lerobot/common/policies/smolvla/modeling_smolvla.py | 4 ++-- tests/test_available.py | 2 ++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/smolvla/configuration_smolvla.py b/lerobot/common/policies/smolvla/configuration_smolvla.py index 822c0f40..a5de7264 100644 --- a/lerobot/common/policies/smolvla/configuration_smolvla.py +++ b/lerobot/common/policies/smolvla/configuration_smolvla.py @@ -98,6 +98,9 @@ class SmolVLAConfig(PreTrainedConfig): self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM) + min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding + max_period: float = 4.0 + def __post_init__(self): super().__post_init__() diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py index ac56205a..23042d5b 100644 --- a/lerobot/common/policies/smolvla/modeling_smolvla.py +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -17,7 +17,7 @@ """ SmolVLA: -[Paper]() +[Paper](https://huggingface.co/papers/2506.01844) Designed by Hugging Face. 
@@ -656,7 +656,7 @@ class VLAFlowMatching(nn.Module): dtype = action_emb.dtype # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] time_emb = create_sinusoidal_pos_embedding( - timestep, self.vlm_with_expert.expert_hidden_size, min_period=4e-3, max_period=4.0, device=device + timestep, self.vlm_with_expert.expert_hidden_size, self.config.min_period, self.config.max_period, device=device ) time_emb = time_emb.type(dtype=dtype) diff --git a/tests/test_available.py b/tests/test_available.py index f4f9d4de..a207dea4 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -23,6 +23,7 @@ from lerobot.common.policies.act.modeling_act import ACTPolicy from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy +from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy from tests.utils import require_env @@ -50,6 +51,7 @@ def test_available_policies(): DiffusionPolicy, TDMPCPolicy, VQBeTPolicy, + SmolVLAPolicy ] policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies