117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
|
|
import einops
|
|
import hydra
|
|
import imageio
|
|
import torch
|
|
from torchrl.data.replay_buffers import (
|
|
SamplerWithoutReplacement,
|
|
)
|
|
|
|
from lerobot.common.datasets.factory import make_offline_buffer
|
|
from lerobot.common.logger import log_output_dir
|
|
from lerobot.common.utils import init_logging
|
|
|
|
# Episode index at which rendering stops (compared against the sampled episode index).
NUM_EPISODES_TO_RENDER: int = 50
# Per-episode step allowance; the sampling loop draws at most
# MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER samples in total.
MAX_NUM_STEPS: int = 1000
# Position of the single item returned by `offline_buffer.sample(1)`.
FIRST_FRAME: int = 0
|
|
|
|
|
|
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def visualize_dataset_cli(cfg: dict):
    """Hydra entry point: render the dataset into the current Hydra run's output directory.

    Args:
        cfg: Configuration composed by Hydra from ``../configs/default``.
    """
    hydra_runtime = hydra.core.hydra_config.HydraConfig.get().runtime
    visualize_dataset(cfg, out_dir=hydra_runtime.output_dir)
|
|
|
|
|
|
def cat_and_write_video(video_path, frames, fps):
    """Concatenate a list of image batches and encode them as a video file.

    Args:
        video_path: Destination path handed to ``imageio.mimsave``.
        frames: Non-empty list of image tensors with values in [0, 1]. They are
            concatenated along dim 0 and must be channel-first (``b c h w``), as
            required by the rearrange below.
        fps: Frame rate of the written video.

    Raises:
        ValueError: if ``frames`` is empty, or if any value lies outside [0, 1] (NaN included).
    """
    if len(frames) == 0:
        raise ValueError(f"No frames to write to {video_path}.")
    frames = torch.cat(frames)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # The positive form is kept on purpose: NaN makes both comparisons False, so it is rejected.
    if not (frames.max() <= 1 and frames.min() >= 0):
        raise ValueError(
            f"Expected frame values in [0, 1], got min={frames.min().item()} max={frames.max().item()}."
        )
    frames = (255 * frames).to(dtype=torch.uint8)
    # imageio wants channel-last numpy arrays; `.numpy()` only works on CPU tensors,
    # and `.cpu()` is a no-op when the tensor already lives there.
    frames = einops.rearrange(frames, "b c h w -> b h w c").cpu().numpy()
    imageio.mimsave(video_path, frames, fps=fps)
|
|
|
|
|
|
def _write_episode_videos(offline_buffer, last_td, frames, video_dir, ep_idx, fps, threads):
    """Start one background encoding thread per image key for the finished episode ``ep_idx``.

    ``frames[im_key]`` holds the observations collected for the episode. The observation
    that follows the episode's last step (``last_td["next"]``) is appended before encoding.
    Each key is popped from ``frames`` so the next episode starts from an empty list.
    Started threads are appended to ``threads`` so the caller can join them.
    """
    # Apply the transform once per episode rather than once per camera. The original code
    # re-applied it for every image key on the same nested TensorDict.
    # NOTE(review): "next" appears not to be transformed by `sample()` (hence this explicit
    # call), and the transform may operate in place, in which case repeated calls would
    # double-transform — confirm.
    next_obs = offline_buffer.transform(last_td["next"])
    multi_camera = len(offline_buffer.image_keys) > 1
    for im_key in offline_buffer.image_keys:
        episode_frames = frames.pop(im_key)
        # append last observed frames (the ones after last action taken)
        episode_frames.append(next_obs[im_key])

        if multi_camera:
            camera = im_key[-1]
            video_path = video_dir / f"episode_{ep_idx}_{camera}.mp4"
        else:
            video_path = video_dir / f"episode_{ep_idx}.mp4"

        thread = threading.Thread(
            target=cat_and_write_video,
            args=(str(video_path), episode_frames, fps),
        )
        thread.start()
        threads.append(thread)


def visualize_dataset(cfg: dict, out_dir=None):
    """Render episodes of the offline dataset to mp4 files under ``out_dir/visualize_dataset``.

    Frames are drawn one at a time, in storage order, with a non-shuffling sampler without
    replacement. Each episode is written once its last frame has been seen. The file is
    ``episode_{idx}.mp4``, or ``episode_{idx}_{camera}.mp4`` when there are several image keys.
    Encoding runs on background threads, which are all joined before returning.

    Rendering stops when the buffer has no samples left, when the episode index reaches
    ``NUM_EPISODES_TO_RENDER``, or after ``MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER`` samples.
    In the last case, the episode in progress is not written.

    Args:
        cfg: Hydra config passed to ``make_offline_buffer``; ``cfg.fps`` sets the video frame rate.
        out_dir: Output directory; required.

    Raises:
        NotImplementedError: if ``out_dir`` is None.
    """
    if out_dir is None:
        raise NotImplementedError()

    init_logging()
    log_output_dir(out_dir)

    # we expect frames of each episode to be stored next to each others sequentially
    sampler = SamplerWithoutReplacement(
        shuffle=False,
    )

    logging.info("make_offline_buffer")
    offline_buffer = make_offline_buffer(
        cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
    )

    logging.info("Start rendering episodes from offline buffer")

    # Loop-invariant: create the output directory once instead of once per video.
    video_dir = Path(out_dir) / "visualize_dataset"
    video_dir.mkdir(parents=True, exist_ok=True)

    threads = []
    frames = {}
    current_ep_idx = 0
    # Last sample belonging to `current_ep_idx`; its "next" frame closes that episode's video.
    prev_td = None
    logging.info(f"Visualizing episode {current_ep_idx}")
    for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
        # TODO(rcadene): make it work with bsize > 1
        ep_td = offline_buffer.sample(1)
        ep_idx = ep_td["episode"][FIRST_FRAME].item()

        # TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
        no_more_frames = offline_buffer._sampler._sample_list.numel() == 0

        if ep_idx != current_ep_idx:
            # `ep_td` is the first frame of a new episode. Close the previous episode with its
            # own last sample, and update the index only after every camera has been written.
            if prev_td is not None:
                _write_episode_videos(offline_buffer, prev_td, frames, video_dir, current_ep_idx, cfg.fps, threads)
            current_ep_idx = ep_idx
            if current_ep_idx >= NUM_EPISODES_TO_RENDER:
                break
            logging.info(f"Visualizing episode {current_ep_idx}")

        # append current cameras images to list of frames
        for im_key in offline_buffer.image_keys:
            frames.setdefault(im_key, []).append(ep_td[im_key])

        if no_more_frames:
            # `ep_td` is the last sample: its observation is already in `frames`, so close the episode.
            _write_episode_videos(offline_buffer, ep_td, frames, video_dir, current_ep_idx, cfg.fps, threads)
            logging.info("Ran out of frames")
            break

        prev_td = ep_td

    for thread in threads:
        thread.join()

    logging.info("End of visualize_dataset")
|
|
|
|
|
|
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on visualize_dataset_cli builds `cfg`.
    visualize_dataset_cli()
|