Add policies/factory, Add test, Add _self_ in config

This commit is contained in:
Cadene
2024-02-25 10:50:23 +00:00
parent 64b5920e94
commit 598bb496b0
13 changed files with 61 additions and 38 deletions

View File

@@ -10,7 +10,7 @@ from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed
@@ -111,15 +111,7 @@ def eval(cfg: dict, out_dir=None):
env = make_env(cfg)
if cfg.pretrained_model_path:
policy = TDMPC(cfg)
if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
policy = make_policy(cfg)
policy = TensorDictModule(
policy,
in_keys=["observation", "step_count"],

View File

@@ -11,10 +11,9 @@ from torchrl.data.datasets.openx import OpenXExperienceReplay
from torchrl.data.replay_buffers import PrioritizedSliceSampler
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger
from lerobot.common.tdmpc import TDMPC
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import set_seed
from lerobot.scripts.eval import eval_policy
@@ -51,17 +50,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
print(colored("Work dir:", "yellow", attrs=["bold"]), out_dir)
env = make_env(cfg)
policy = TDMPC(cfg)
if cfg.pretrained_model_path:
# TODO(rcadene): hack for old pretrained models from fowm
if "fowm" in cfg.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
policy.step[0] = 25000
elif "final" in cfg.pretrained_model_path:
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.pretrained_model_path)
policy = make_policy(cfg)
td_policy = TensorDictModule(
policy,