Add pusht dataset (TODO verify reward is aligned), Refactor visualize_dataset, Add video_dir, fps, state_dim, action_dim to config (Training works)

This commit is contained in:
Cadene
2024-02-21 00:49:40 +00:00
parent 3dc14b5576
commit ece89730e6
8 changed files with 222 additions and 111 deletions

View File

@@ -1,6 +1,7 @@
import pickle
from pathlib import Path
import hydra
import imageio
import simxarm
import torch
@@ -10,30 +11,25 @@ from torchrl.data.replay_buffers import (
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
from lerobot.common.datasets.factory import make_offline_buffer
def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset(cfg: dict):
sampler = SliceSamplerWithoutReplacement(
num_slices=1,
strict_length=False,
shuffle=False,
)
dataset = SimxarmExperienceReplay(
dataset_id,
# download="force",
download=True,
streaming=False,
root="data",
sampler=sampler,
)
offline_buffer = make_offline_buffer(cfg, sampler)
NUM_EPISODES_TO_RENDER = 10
MAX_NUM_STEPS = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0
for _ in range(NUM_EPISODES_TO_RENDER):
episode = dataset.sample(MAX_NUM_STEPS)
episode = offline_buffer.sample(MAX_NUM_STEPS)
ep_idx = episode["episode"][FIRST_FRAME].item()
ep_frames = torch.cat(
@@ -44,16 +40,23 @@ def visualize_simxarm_dataset(dataset_id="xarm_lift_medium"):
dim=0,
)
video_dir = Path("tmp/2024_02_03_xarm_lift_medium")
video_dir = Path(cfg.video_dir)
video_dir.mkdir(parents=True, exist_ok=True)
# TODO(rcadene): make fps configurable
video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=15)
assert ep_frames.min().item() >= 0
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
assert ep_frames.max().item() <= 255
ep_frames = ep_frames.type(torch.uint8)
imageio.mimsave(
video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps
)
# ran out of episodes
if dataset._sampler._sample_list.numel() == 0:
if offline_buffer._sampler._sample_list.numel() == 0:
break
if __name__ == "__main__":
visualize_simxarm_dataset()
visualize_dataset()