This commit is contained in:
Alexander Soare
2024-04-12 16:55:32 +01:00
parent 5bd953e8e7
commit 55e484124a
5 changed files with 758 additions and 51 deletions

View File

@@ -1,3 +1,8 @@
import inspect
from lerobot.common.utils import get_safe_torch_device
def make_policy(cfg):
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
@@ -21,10 +26,16 @@ def make_policy(cfg):
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
from lerobot.common.policies.act.configuration_act import ActConfig
from lerobot.common.policies.act.modeling_act import ActPolicy
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
policy.to(cfg.device)
expected_kwargs = set(inspect.signature(ActConfig).parameters)
assert set(cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
policy_cfg = ActConfig(**{k: v for k, v in cfg.policy.items() if k in expected_kwargs})
policy = ActPolicy(policy_cfg)
policy.to(get_safe_torch_device(cfg.device))
else:
raise ValueError(cfg.policy.name)