Sanitize cfg.policy, Fix skip_frame pusht.yaml

This commit is contained in:
Cadene
2024-02-25 11:09:02 +00:00
parent fc4b98544b
commit e765e26b0b
3 changed files with 78 additions and 72 deletions

View File

@@ -2,20 +2,20 @@ from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg):
if cfg.policy == "tdmpc":
policy = TDMPC(cfg)
if cfg.policy.name == "tdmpc":
policy = TDMPC(cfg.policy)
else:
raise ValueError(cfg.policy)
raise ValueError(cfg.policy.name)
if cfg.pretrained_model_path:
if cfg.policy.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy == "tdmpc" and "fowm" in cfg.pretrained_model_path:
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
policy.load(cfg.policy.pretrained_model_path)
return policy