Add pusht dataset (TODO: verify reward is aligned); refactor visualize_dataset; add video_dir, fps, state_dim, action_dim to config (training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -20,6 +20,7 @@ def eval_policy(
max_steps: int = 30,
save_video: bool = False,
video_dir: Path = None,
fps: int = 15,
):
rewards = []
successes = []
@@ -55,7 +56,7 @@ def eval_policy(
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
metrics = {
"avg_reward": np.nanmean(rewards),
@@ -74,16 +75,13 @@ def eval(cfg: dict):
if cfg.pretrained_model_path:
policy = TDMPC(cfg)
ckpt_path = (
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path)
policy.load(cfg.pretrained_model_path)
policy = TensorDictModule(
policy,
@@ -99,7 +97,8 @@ def eval(cfg: dict):
policy=policy,
num_episodes=20,
save_video=True,
video_dir=Path("tmp/2023_02_19_pusht"),
video_dir=Path(cfg.video_dir),
fps=cfg.fps,
)
print(metrics)