Files
lerobot_piper/lerobot/common/policies/factory.py

22 lines
669 B
Python

from lerobot.common.policies.tdmpc import TDMPC
def make_policy(cfg):
    """Build the policy named by ``cfg.policy`` and optionally load pretrained weights.

    Args:
        cfg: Config object. This function reads ``cfg.policy`` (the policy
            name) and ``cfg.pretrained_model_path`` (a checkpoint path; a
            falsy value skips loading).

    Returns:
        The constructed policy instance (currently only ``TDMPC`` is supported).

    Raises:
        ValueError: If ``cfg.policy`` is not a supported policy name.
        NotImplementedError: If a TD-MPC checkpoint path contains "fowm" but
            neither "offline" nor "final", so the step counter to restore is
            unknown.
    """
    if cfg.policy == "tdmpc":
        policy = TDMPC(cfg)
    else:
        raise ValueError(f"Unsupported policy: {cfg.policy!r} (expected 'tdmpc')")

    pretrained_path = cfg.pretrained_model_path
    if pretrained_path:
        # TODO(rcadene): hack for old pretrained models from fowm
        if cfg.policy == "tdmpc" and "fowm" in pretrained_path:
            # Checkpoint-variant marker -> value written into policy.step[0]
            # (presumably the training step counter; confirm against TDMPC).
            # Order matters: "offline" is checked before "final", so a path
            # containing both resolves to the "offline" value.
            for marker, step in (("offline", 25000), ("final", 100000)):
                if marker in pretrained_path:
                    policy.step[0] = step
                    break
            else:
                # Loop finished without `break`: no known marker in the path.
                raise NotImplementedError(
                    f"Unrecognized fowm checkpoint {pretrained_path!r}: expected the "
                    "path to contain 'offline' or 'final'"
                )
        policy.load(pretrained_path)
    return policy