Add policies/factory, Add test, Add _self_ in config

This commit is contained in:
Cadene
2024-02-25 10:50:23 +00:00
parent 64b5920e94
commit 598bb496b0
13 changed files with 61 additions and 38 deletions

View File

@@ -10,7 +10,7 @@ from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed
@@ -111,15 +111,7 @@ def eval(cfg: dict, out_dir=None):
env = make_env(cfg)
if cfg.pretrained_model_path:
policy = TDMPC(cfg)
if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
policy = make_policy(cfg)
policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],