Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act

This commit is contained in:
Alexander Soare
2024-04-09 08:36:28 +01:00
13 changed files with 109 additions and 247 deletions

View File

@@ -16,18 +16,15 @@ def make_policy(cfg):
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_obs_steps=cfg.n_obs_steps,
n_action_steps=cfg.n_action_steps,
# n_obs_steps=cfg.n_obs_steps,
# n_action_steps=cfg.n_action_steps,
**cfg.policy,
)
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy,
cfg.device,
n_action_steps=cfg.n_action_steps,
)
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
policy.to(cfg.device)
else:
raise ValueError(cfg.policy.name)