22 lines
669 B
Python
22 lines
669 B
Python
from lerobot.common.policies.tdmpc import TDMPC
|
|
|
|
|
|
def make_policy(cfg):
    """Build the policy selected by ``cfg.policy`` and optionally load weights.

    Args:
        cfg: Configuration object. This function reads ``cfg.policy`` (the
            policy name; only ``"tdmpc"`` is handled here) and
            ``cfg.pretrained_model_path`` (falsy to skip loading). ``cfg`` is
            also passed as-is to the ``TDMPC`` constructor.

    Returns:
        The constructed policy. When ``cfg.pretrained_model_path`` is truthy,
        ``policy.load`` has been called with that path before returning.

    Raises:
        ValueError: If ``cfg.policy`` is not a supported policy name.
        NotImplementedError: If the policy is ``"tdmpc"``, the pretrained path
            contains ``"fowm"``, and it contains neither ``"offline"`` nor
            ``"final"``.
    """
    if cfg.policy == "tdmpc":
        policy = TDMPC(cfg)
    else:
        raise ValueError(
            f"Unsupported policy: {cfg.policy!r} (expected 'tdmpc')"
        )

    pretrained_path = cfg.pretrained_model_path
    if not pretrained_path:
        # Nothing to load: return the freshly constructed policy.
        return policy

    # The substring checks below need text; coercing here also lets callers
    # pass a path-like object (e.g. pathlib.Path) without a TypeError.
    path_text = str(pretrained_path)

    # TODO(rcadene): hack for old pretrained models from fowm
    if cfg.policy == "tdmpc" and "fowm" in path_text:
        # The step counter is inferred from the checkpoint's file name.
        # NOTE(review): presumably these are the training steps at which the
        # legacy checkpoints were saved -- confirm against the fowm release.
        if "offline" in path_text:
            policy.step[0] = 25000
        elif "final" in path_text:
            policy.step[0] = 100000
        else:
            raise NotImplementedError(
                "Cannot infer the step of legacy 'fowm' checkpoint "
                f"{path_text!r}: expected 'offline' or 'final' in its path"
            )

    # Pass the original value (not the coerced text) through to the loader.
    policy.load(pretrained_path)
    return policy
|