lerobot_piper/lerobot/scripts/visualize_dataset.py

import logging
import threading
from pathlib import Path

import einops
import hydra
import imageio
import torch
from torchrl.data.replay_buffers import (
    SamplerWithoutReplacement,
)

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging

NUM_EPISODES_TO_RENDER = 50
MAX_NUM_STEPS = 1000
FIRST_FRAME = 0


@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
    visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)


def cat_and_write_video(video_path, frames, fps):
    # Expects a list of uint8 image tensors in [0, 255], each shaped (b, c, h, w).
    frames = torch.cat(frames)
    assert frames.dtype == torch.uint8
    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
    imageio.mimsave(video_path, frames, fps=fps)
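

def _demo_cat_and_write_video():
    # Illustrative sketch only (not part of the original script): exercises
    # `cat_and_write_video` with dummy data. The frame list mirrors what the
    # function expects per the assert and rearrange above: uint8 tensors in
    # [0, 255], each shaped (b, c, h, w). The 96x96 resolution, frame count,
    # output path, and fps below are arbitrary assumptions.
    dummy_frames = [torch.randint(0, 256, (1, 3, 96, 96), dtype=torch.uint8) for _ in range(30)]
    cat_and_write_video("/tmp/demo_episode.mp4", dummy_frames, fps=15)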


def visualize_dataset(cfg: dict, out_dir=None):
    if out_dir is None:
        raise NotImplementedError()

    init_logging()
    log_output_dir(out_dir)

    # We expect the frames of each episode to be stored sequentially, next to each other.
    sampler = SamplerWithoutReplacement(
        shuffle=False,
    )

    logging.info("make_dataset")
    dataset = make_dataset(
        cfg,
        overwrite_sampler=sampler,
        # Remove all transformations, such as rescaling images from [0, 255] to [0, 1] or normalization.
        normalize=False,
        overwrite_batch_size=1,
        overwrite_prefetch=12,
    )

    logging.info("Start rendering episodes from offline buffer")
    video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
    for video_path in video_paths:
        logging.info(video_path)
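

def _demo_visualize_dataset_without_cli():
    # Illustrative sketch only (not part of the original script): drives
    # `visualize_dataset` programmatically via Hydra's compose API instead of
    # the `@hydra.main` CLI entry point. The config name and path mirror the
    # decorator on `visualize_dataset_cli`; the output directory is an
    # arbitrary assumption.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="../configs"):
        cfg = compose(config_name="default")
    visualize_dataset(cfg, out_dir="outputs/visualize_dataset_demo")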


def render_dataset(dataset, out_dir, max_num_samples, fps):
    out_dir = Path(out_dir)
    video_paths = []
    threads = []
    frames = {}
    current_ep_idx = 0
    logging.info(f"Visualizing episode {current_ep_idx}")
    for i in range(max_num_samples):
        # TODO(rcadene): make it work with bsize > 1
        ep_td = dataset.sample(1)
        ep_idx = ep_td["episode"][FIRST_FRAME].item()

        # TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample
        # an episode, but sequentially sample frames
        num_frames_left = dataset._sampler._sample_list.numel()
        episode_is_done = ep_idx != current_ep_idx

        if episode_is_done:
            logging.info(f"Rendering episode {current_ep_idx}")

        for im_key in dataset.image_keys:
            if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
                # On the first frame of an episode, initialize the frames dict.
                if im_key not in frames:
                    frames[im_key] = []

                # Add the current frame to the list of frames to render.
                frames[im_key].append(ep_td[im_key])
            else:
                # When the episode has no more frames in its list of observations,
                # one frame still remains: the result of the last action taken.
                # It is stored under `"next"`, so we add it to the list of frames to render.
                frames[im_key].append(ep_td["next"][im_key])

                out_dir.mkdir(parents=True, exist_ok=True)
                if len(dataset.image_keys) > 1:
                    camera = im_key[-1]
                    video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
                else:
                    video_path = out_dir / f"episode_{current_ep_idx}.mp4"
                video_paths.append(str(video_path))

                # Write the video in a background thread so encoding does not
                # block the sampling loop.
                thread = threading.Thread(
                    target=cat_and_write_video,
                    args=(str(video_path), frames[im_key], fps),
                )
                thread.start()
                threads.append(thread)

                current_ep_idx = ep_idx

                # Reset the list of frames for the next episode.
                del frames[im_key]

        if num_frames_left == 0:
            logging.info("Ran out of frames")
            break

        if current_ep_idx == NUM_EPISODES_TO_RENDER:
            break

    for thread in threads:
        thread.join()

    logging.info("End of visualize_dataset")
    return video_paths
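

def _demo_episode_boundary_detection():
    # Illustrative sketch only (not part of the original script): isolates the
    # episode-boundary logic that `render_dataset` relies on. Because frames
    # are sampled sequentially without replacement, an episode is considered
    # done as soon as the sampled episode index differs from the one being
    # accumulated. The toy index stream below is an arbitrary assumption.
    episode_ids = [0, 0, 0, 1, 1, 2]
    current_ep_idx = 0
    for ep_idx in episode_ids:
        if ep_idx != current_ep_idx:
            print(f"episode {current_ep_idx} is done, flush its frames to a video")
            current_ep_idx = ep_idx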


if __name__ == "__main__":
    visualize_dataset_cli()
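
# Example invocation, assuming a Hydra config layout matching the decorator on
# `visualize_dataset_cli` above (the `env=pusht` override is an assumption
# about the available config groups, not something this file guarantees):
#
#   python lerobot/scripts/visualize_dataset.py env=pusht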