revision 1

This commit is contained in:
Alexander Soare
2024-04-15 10:56:43 +01:00
parent 40d417ef60
commit 30023535f9
6 changed files with 24 additions and 20 deletions

View File

@@ -28,21 +28,21 @@ def make_policy(cfg):
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.configuration_act import ActConfig
from lerobot.common.policies.act.modeling_act import ActPolicy
from lerobot.common.policies.act.configuration_act import ActionChunkingTransformerConfig
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
expected_kwargs = set(inspect.signature(ActConfig).parameters)
expected_kwargs = set(inspect.signature(ActionChunkingTransformerConfig).parameters)
assert set(cfg.policy).issuperset(
expected_kwargs
), f"Hydra config is missing arguments: {set(cfg.policy).difference(expected_kwargs)}"
policy_cfg = ActConfig(
policy_cfg = ActionChunkingTransformerConfig(
**{
k: v
for k, v in OmegaConf.to_container(cfg.policy, resolve=True).items()
if k in expected_kwargs
}
)
policy = ActPolicy(policy_cfg)
policy = ActionChunkingTransformerPolicy(policy_cfg)
policy.to(get_safe_torch_device(cfg.device))
else:
raise ValueError(cfg.policy.name)