add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
132 lines
4.4 KiB
Python
132 lines
4.4 KiB
Python
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
|
|
import einops
|
|
import hydra
|
|
import imageio
|
|
import torch
|
|
from torchrl.data.replay_buffers import (
|
|
SamplerWithoutReplacement,
|
|
)
|
|
|
|
from lerobot.common.datasets.factory import make_offline_buffer
|
|
from lerobot.common.logger import log_output_dir
|
|
from lerobot.common.utils import init_logging
|
|
|
|
NUM_EPISODES_TO_RENDER = 50
|
|
MAX_NUM_STEPS = 1000
|
|
FIRST_FRAME = 0
|
|
|
|
|
|
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
    """Hydra CLI entry point: render the dataset into Hydra's run output directory."""
    output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir
    visualize_dataset(cfg, out_dir=output_dir)
|
|
|
|
|
|
def cat_and_write_video(video_path, frames, fps):
    """Concatenate a list of frame batches and encode them as a video file.

    Args:
        video_path: Destination path of the video file (e.g. ``.../episode_0.mp4``).
        frames: List of float tensors of shape (b, c, h, w) with values in [0, 1].
        fps: Frames per second of the output video.
    """
    frames = torch.cat(frames)
    # Frames must arrive normalized to [0, 1] (the dataset transforms are disabled
    # upstream with `normalize=False`, so this guards against double-scaling).
    # NOTE: a previous comment claimed [0, 255] input, which contradicted this assert.
    assert frames.max() <= 1 and frames.min() >= 0
    # Rescale to uint8 [0, 255] for the encoder.
    frames = (255 * frames).to(dtype=torch.uint8)
    # imageio expects channel-last (b, h, w, c) numpy arrays.
    frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
    imageio.mimsave(video_path, frames, fps=fps)
|
|
|
|
|
|
def visualize_dataset(cfg: dict, out_dir=None):
    """Render episodes of an offline dataset to video files under `out_dir`.

    Args:
        cfg: Hydra config used to build the offline replay buffer.
        out_dir: Directory where rendered videos are written. Required.

    Raises:
        NotImplementedError: If `out_dir` is not provided.
    """
    if out_dir is None:
        raise NotImplementedError()

    init_logging()
    log_output_dir(out_dir)

    # Frames of each episode are expected to be stored contiguously, so sample
    # sequentially (no shuffle) to walk through episodes in order.
    sequential_sampler = SamplerWithoutReplacement(shuffle=False)

    logging.info("make_offline_buffer")
    offline_buffer = make_offline_buffer(
        cfg,
        overwrite_sampler=sequential_sampler,
        # remove all transformations such as rescale images from [0,255] to [0,1] or normalization
        normalize=False,
        overwrite_batch_size=1,
        overwrite_prefetch=12,
    )

    logging.info("Start rendering episodes from offline buffer")
    max_samples = MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER
    video_paths = render_dataset(offline_buffer, out_dir, max_samples, cfg.fps)
    for path in video_paths:
        logging.info(path)
|
|
|
|
|
|
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
    """Iterate sequentially over the buffer and write one video per episode (per camera).

    Args:
        offline_buffer: Replay buffer with a sequential (no-shuffle) sampler; frames of
            an episode are assumed stored contiguously. Must expose `image_keys`.
        out_dir: Directory where `.mp4` files are written (created if missing).
        max_num_samples: Upper bound on the number of frames sampled overall.
        fps: Frames per second passed to the video writer.

    Returns:
        List of written video file paths (as strings).
    """
    out_dir = Path(out_dir)
    video_paths = []
    threads = []
    # Maps image key -> list of frame tensors accumulated for the current episode.
    frames = {}
    current_ep_idx = 0
    logging.info(f"Visualizing episode {current_ep_idx}")
    for i in range(max_num_samples):
        # TODO(rcadene): make it work with bsize > 1
        ep_td = offline_buffer.sample(1)
        ep_idx = ep_td["episode"][FIRST_FRAME].item()

        # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
        # NOTE(review): relies on a private torchrl sampler attribute — number of
        # not-yet-sampled frames remaining in the without-replacement sampler.
        num_frames_left = offline_buffer._sampler._sample_list.numel()
        # Episode boundary detected when the sampled frame belongs to a new episode.
        episode_is_done = ep_idx != current_ep_idx

        if episode_is_done:
            logging.info(f"Rendering episode {current_ep_idx}")

        for im_key in offline_buffer.image_keys:
            if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
                # when first frame of episode, initialize frames dict
                if im_key not in frames:
                    frames[im_key] = []
                # add current frame to list of frames to render
                frames[im_key].append(ep_td[im_key])
            else:
                # When episode has no more frame in its list of observation,
                # one frame still remains. It is the result of the last action taken.
                # It is stored in `"next"`, so we add it to the list of frames to render.
                frames[im_key].append(ep_td["next"][im_key])

                out_dir.mkdir(parents=True, exist_ok=True)
                if len(offline_buffer.image_keys) > 1:
                    # Multi-camera dataset: suffix the file name with the camera id.
                    # NOTE(review): assumes im_key is a tuple/sequence whose last
                    # element names the camera — confirm against the buffer's keys.
                    camera = im_key[-1]
                    video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
                else:
                    video_path = out_dir / f"episode_{current_ep_idx}.mp4"
                video_paths.append(str(video_path))

                # Encode in a background thread so sampling is not blocked by I/O;
                # all threads are joined before returning.
                thread = threading.Thread(
                    target=cat_and_write_video,
                    args=(str(video_path), frames[im_key], fps),
                )
                thread.start()
                threads.append(thread)

                current_ep_idx = ep_idx

                # reset list of frames
                del frames[im_key]

        if num_frames_left == 0:
            logging.info("Ran out of frames")
            break

        if current_ep_idx == NUM_EPISODES_TO_RENDER:
            break

    # Wait for all pending video writes to finish before reporting paths.
    for thread in threads:
        thread.join()

    logging.info("End of visualize_dataset")
    return video_paths
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: Hydra parses CLI overrides and supplies `cfg`.
    visualize_dataset_cli()
|