Merge remote-tracking branch 'upstream/user/rcadene/2024_03_31_remove_torchrl' into refactor_act
This commit is contained in:
@@ -16,18 +16,15 @@ def make_policy(cfg):
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_obs_steps=cfg.n_obs_steps,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
# n_obs_steps=cfg.n_obs_steps,
|
||||
# n_action_steps=cfg.n_action_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy,
|
||||
cfg.device,
|
||||
n_action_steps=cfg.n_action_steps,
|
||||
)
|
||||
policy = ActionChunkingTransformerPolicy(cfg.policy, cfg.device)
|
||||
policy.to(cfg.device)
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user