Merge remote-tracking branch 'origin/user/rcadene/2024_03_31_remove_torchrl' into user/rcadene/2024_03_31_remove_torchrl

This commit is contained in:
Cadene
2024-04-10 11:34:51 +00:00
19 changed files with 1082 additions and 1805 deletions

View File

@@ -23,11 +23,7 @@ def make_policy(cfg):
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy,
n_obs_steps=cfg.policy.n_obs_steps,
n_action_steps=cfg.policy.n_action_steps,
)
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
policy.to(cfg.device)
else:
raise ValueError(cfg.policy.name)