Add pusht dataset (TODO: verify that the reward is aligned); refactor visualize_dataset; add video_dir, fps, state_dim, and action_dim to the config (training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -20,6 +20,7 @@ def eval_policy(
max_steps: int = 30,
save_video: bool = False,
video_dir: Path = None,
fps: int = 15,
):
rewards = []
successes = []
@@ -55,7 +56,7 @@ def eval_policy(
video_dir.mkdir(parents=True, exist_ok=True)
# fps is supplied by the caller through the `fps` argument
video_path = video_dir / f"eval_episode_{i}.mp4"
imageio.mimsave(video_path, np.stack(ep_frames), fps=15)
imageio.mimsave(video_path, np.stack(ep_frames), fps=fps)
metrics = {
"avg_reward": np.nanmean(rewards),
@@ -74,16 +75,13 @@ def eval(cfg: dict):
if cfg.pretrained_model_path:
policy = TDMPC(cfg)
ckpt_path = (
"/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/offline.pt"
)
if "offline" in cfg.pretrained_model_path:
policy.step = 25000
elif "final" in cfg.pretrained_model_path:
policy.step = 100000
else:
raise NotImplementedError()
policy.load(ckpt_path)
policy.load(cfg.pretrained_model_path)
policy = TensorDictModule(
policy,
@@ -99,7 +97,8 @@ def eval(cfg: dict):
policy=policy,
num_episodes=20,
save_video=True,
video_dir=Path("tmp/2023_02_19_pusht"),
video_dir=Path(cfg.video_dir),
fps=cfg.fps,
)
print(metrics)

View File

@@ -1,6 +1,7 @@
import pickle
from pathlib import Path
import hydra
import imageio
import simxarm
import torch
@@ -10,30 +11,25 @@ from torchrl.data.replay_buffers import (
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.datasets.factory import make_offline_buffer
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset(cfg: dict):
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
strict_length=False,
shuffle=False,
)
dataset = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = dataset.sample(MAX_NUM_STEPS)
episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat(
@@ -44,16 +40,23 @@ def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
dim=0,
)
video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
video_dir = Path(cfg.video_dir)
video_dir.mkdir(parents=True, exist_ok=True)
# fps is read from the config (cfg.fps)
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
# ran out of episodes
if dataset._sampler._sample_list.numel() == 0:
if offline_buffer._sampler._sample_list.numel() == 0:
break
if __name__ == "__main__":
visualize_simxarm_dataset()
visualize_dataset()