load model
This commit is contained in:
@@ -15,5 +15,6 @@
|
|||||||
from .act.configuration_act import ACTConfig as ACTConfig
|
from .act.configuration_act import ACTConfig as ACTConfig
|
||||||
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
|
||||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||||
|
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
|
||||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC
|
|||||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||||
|
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
@@ -59,6 +60,10 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
|||||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||||
|
|
||||||
return PI0FASTPolicy
|
return PI0FASTPolicy
|
||||||
|
elif name == "smolvla":
|
||||||
|
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||||
|
|
||||||
|
return SmolVLAPolicy
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||||
|
|
||||||
@@ -76,6 +81,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
return PI0Config(**kwargs)
|
return PI0Config(**kwargs)
|
||||||
elif policy_type == "pi0fast":
|
elif policy_type == "pi0fast":
|
||||||
return PI0FASTConfig(**kwargs)
|
return PI0FASTConfig(**kwargs)
|
||||||
|
elif policy_type == "smolvla":
|
||||||
|
return SmolVLAConfig(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user