forked from tangger/lerobot
Fix unit tests, Refactor, Add pusht env, (TODO pusht replay buffer, image preprocessing)
This commit is contained in:
@@ -71,24 +71,35 @@ def eval(cfg: dict):
|
||||
print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)
|
||||
|
||||
env = make_env(cfg)
|
||||
policy = TDMPC(cfg)
|
||||
# ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
|
||||
policy.load(ckpt_path)
|
||||
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=["observation", "step_count"],
|
||||
out_keys=["action"],
|
||||
)
|
||||
if cfg.pretrained_model_path:
|
||||
policy = TDMPC(cfg)
|
||||
ckpt_path = (
|
||||
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
|
||||
)
|
||||
if "offline" in cfg.pretrained_model_path:
|
||||
policy.step = 25000
|
||||
elif "final" in cfg.pretrained_model_path:
|
||||
policy.step = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(ckpt_path)
|
||||
|
||||
policy = TensorDictModule(
|
||||
policy,
|
||||
in_keys=["observation", "step_count"],
|
||||
out_keys=["action"],
|
||||
)
|
||||
else:
|
||||
# when policy is None, rollout a random policy
|
||||
policy = None
|
||||
|
||||
# policy can be None to rollout a random policy
|
||||
metrics = eval_policy(
|
||||
env,
|
||||
policy=policy,
|
||||
num_episodes=20,
|
||||
save_video=False,
|
||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||
save_video=True,
|
||||
video_dir=Path("tmp/2023_02_19_pusht"),
|
||||
)
|
||||
print(metrics)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user