From 8d6acb3a4f72c5dd1fd0c7fe7ea9cc7ba9133024 Mon Sep 17 00:00:00 2001 From: mshukor Date: Wed, 28 May 2025 14:14:51 +0200 Subject: [PATCH] load model --- lerobot/common/policies/__init__.py | 1 + lerobot/common/policies/factory.py | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index b73ba5f4..9cb0f623 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -15,5 +15,6 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a3..3aade066 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -27,6 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.configs.policies import PreTrainedConfig @@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "smolvla": + from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + return SmolVLAPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "smolvla": + return SmolVLAConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.")