# Files
# lerobot_piper/lerobot/scripts/visualize_dataset.py
#
# 60 lines
# 1.5 KiB
# Python

import pickle
from pathlib import Path
import imageio
import simxarm
import torch
from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
SliceSampler,
SliceSamplerWithoutReplacement,
)
from lerobot.common.datasets.simxarm import SimxarmExperienceReplay
def visualize_simxarm_dataset(
    dataset_id="xarm_lift_medium",
    *,
    num_episodes=10,
    max_num_steps=50,
    fps=15,
    video_dir=None,
):
    """Render episodes of a Simxarm dataset to mp4 files for visual inspection.

    Episodes are drawn one slice at a time (``num_slices=1``) from a
    without-replacement sampler with ``shuffle=False``. Each one is written to
    ``<video_dir>/eval_episode_<episode index>.mp4``.

    Args:
        dataset_id: Name of the Simxarm dataset, passed to
            ``SimxarmExperienceReplay`` with ``root="data"`` and
            ``download=True``.
        num_episodes: Maximum number of episodes to render. The loop stops
            earlier once the sampler has no episodes left.
        max_num_steps: Batch size requested per ``dataset.sample`` call.
            With ``strict_length=False`` shorter episodes are presumably
            returned as-is (TODO confirm against torchrl's SliceSampler).
        fps: Frame rate of the written videos.
        video_dir: Output directory. Defaults to
            ``tmp/2024_02_03_<dataset_id>``, which for the default dataset is
            the directory this script has always written to.
    """
    sampler = SliceSamplerWithoutReplacement(
        num_slices=1,
        strict_length=False,
        shuffle=False,
    )

    dataset = SimxarmExperienceReplay(
        dataset_id,
        # download="force",
        download=True,
        streaming=False,
        root="data",
        sampler=sampler,
    )

    # The directory used to be hard-coded to the xarm_lift_medium one for every
    # dataset_id, so videos of different datasets overwrote each other.
    if video_dir is None:
        video_dir = f"tmp/2024_02_03_{dataset_id}"
    video_dir = Path(video_dir)
    # Loop-invariant: create the output directory once, not per episode.
    video_dir.mkdir(parents=True, exist_ok=True)

    first_frame = 0
    for _ in range(num_episodes):
        episode = dataset.sample(max_num_steps)

        ep_idx = episode["episode"][first_frame].item()
        # Prepend the first observation to all "next" observations so the
        # video has one more frame than the number of sampled steps.
        ep_frames = torch.cat(
            [
                episode["observation"]["image"][first_frame][None, ...],
                episode["next", "observation"]["image"],
            ],
            dim=0,
        )

        video_path = video_dir / f"eval_episode_{ep_idx}.mp4"
        # transpose(0, 2, 3, 1) moves axis 1 to the end. Assumes frames are
        # stored channel-first (T, C, H, W) and written as (T, H, W, C);
        # TODO confirm against the dataset layout.
        imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=fps)

        # Ran out of episodes: the without-replacement sampler has handed out
        # everything. NOTE(review): reads torchrl's private `_sample_list`
        # attribute, which may change between torchrl versions.
        if dataset._sampler._sample_list.numel() == 0:
            break


if __name__ == "__main__":
    visualize_simxarm_dataset()