forked from tangger/lerobot
test_examples are passing
This commit is contained in:
@@ -6,9 +6,6 @@ import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import (
|
||||
SamplerWithoutReplacement,
|
||||
)
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.logger import log_output_dir
|
||||
@@ -39,19 +36,11 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
init_logging()
|
||||
log_output_dir(out_dir)
|
||||
|
||||
# we expect frames of each episode to be stored next to each others sequentially
|
||||
sampler = SamplerWithoutReplacement(
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
logging.info("make_dataset")
|
||||
dataset = make_dataset(
|
||||
cfg,
|
||||
overwrite_sampler=sampler,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
overwrite_batch_size=1,
|
||||
overwrite_prefetch=12,
|
||||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
@@ -60,64 +49,49 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
logging.info(video_path)
|
||||
|
||||
|
||||
def render_dataset(dataset, out_dir, max_num_samples, fps):
|
||||
def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
out_dir = Path(out_dir)
|
||||
video_paths = []
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for i in range(max_num_samples):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = dataset.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
num_frames_left = dataset._sampler._sample_list.numel()
|
||||
episode_is_done = ep_idx != current_ep_idx
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
)
|
||||
dl_iter = iter(dataloader)
|
||||
|
||||
if episode_is_done:
|
||||
logging.info(f"Rendering episode {current_ep_idx}")
|
||||
num_episodes = len(dataset.data_ids_per_episode)
|
||||
for ep_id in range(min(max_num_episodes, num_episodes)):
|
||||
logging.info(f"Rendering episode {ep_id}")
|
||||
|
||||
for im_key in dataset.image_keys:
|
||||
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
||||
frames = {}
|
||||
for _ in dataset.data_ids_per_episode[ep_id]:
|
||||
item = next(dl_iter)
|
||||
|
||||
for im_key in dataset.image_keys:
|
||||
# when first frame of episode, initialize frames dict
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
if len(dataset.image_keys) > 0:
|
||||
im_name = im_key.replace("observation.images.", "")
|
||||
video_path = out_dir / f"episode_{ep_id}_{im_name}.mp4"
|
||||
else:
|
||||
# When episode has no more frame in its list of observation,
|
||||
# one frame still remains. It is the result of the last action taken.
|
||||
# It is stored in `"next"`, so we add it to the list of frames to render.
|
||||
frames[im_key].append(ep_td["next"][im_key])
|
||||
video_path = out_dir / f"episode_{ep_id}.mp4"
|
||||
video_paths.append(video_path)
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
if len(dataset.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
current_ep_idx = ep_idx
|
||||
|
||||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
if num_frames_left == 0:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
if current_ep_idx == NUM_EPISODES_TO_RENDER:
|
||||
break
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], dataset.fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
Reference in New Issue
Block a user