forked from tangger/lerobot
fix environment seeding
add fixes for reproducibility only try to start env if it is closed revision fix normalization and data type Improve README Improve README Tests are passing, Eval pretrained model works, Add gif Update gif Update gif Update gif Update gif Update README Update README update minor Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Update README.md Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Address suggestions Update thumbnail + stats Update thumbnail + stats Update README.md Co-authored-by: Alexander Soare <alexander.soare159@gmail.com> Add more comments Add test_examples.py
This commit is contained in:
@@ -25,7 +25,7 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
# Expects images in [0, 1].
|
||||
# Expects images in [0, 255].
|
||||
frames = torch.cat(frames)
|
||||
assert frames.max() <= 1 and frames.min() >= 0
|
||||
frames = (255 * frames).to(dtype=torch.uint8)
|
||||
@@ -47,44 +47,63 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(
|
||||
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||
cfg,
|
||||
overwrite_sampler=sampler,
|
||||
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
|
||||
normalize=False,
|
||||
overwrite_batch_size=1,
|
||||
overwrite_prefetch=12,
|
||||
)
|
||||
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
|
||||
for video_path in video_paths:
|
||||
logging.info(video_path)
|
||||
|
||||
|
||||
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
|
||||
out_dir = Path(out_dir)
|
||||
video_paths = []
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||
for i in range(max_num_samples):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = offline_buffer.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
num_frames_left = offline_buffer._sampler._sample_list.numel()
|
||||
episode_is_done = ep_idx != current_ep_idx
|
||||
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
if episode_is_done:
|
||||
logging.info(f"Rendering episode {current_ep_idx}")
|
||||
|
||||
for im_key in offline_buffer.image_keys:
|
||||
if new_episode or no_more_frames:
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
|
||||
# when first frame of episode, initialize frames dict
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
else:
|
||||
# When episode has no more frame in its list of observation,
|
||||
# one frame still remains. It is the result of the last action taken.
|
||||
# It is stored in `"next"`, so we add it to the list of frames to render.
|
||||
frames[im_key].append(ep_td["next"][im_key])
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
|
||||
video_paths.append(str(video_path))
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], cfg.fps),
|
||||
args=(str(video_path), frames[im_key], fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
@@ -94,12 +113,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
# append current cameras images to list of frames
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
|
||||
if no_more_frames:
|
||||
if num_frames_left == 0:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
@@ -110,6 +124,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
thread.join()
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
return video_paths
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user