fix environment seeding

add fixes for reproducibility

only try to start env if it is closed

revision

fix normalization and data type

Improve README

Improve README

Tests are passing, Eval pretrained model works, Add gif

Update gif

Update gif

Update gif

Update gif

Update README

Update README

update minor

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Update README.md

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>

Address suggestions

Update thumbnail + stats

Update thumbnail + stats

Update README.md

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>

Add more comments

Add test_examples.py
This commit is contained in:
Alexander Soare
2024-03-22 13:25:23 +00:00
committed by Cadene
parent 203bcd7ca5
commit 1a1308d62f
32 changed files with 686 additions and 282 deletions

View File

@@ -25,7 +25,7 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
# Expects images in [0, 1].
# Expects images in [0, 255].
frames = torch.cat(frames)
assert frames.max() <= 1 and frames.min() >= 0
frames = (255 * frames).to(dtype=torch.uint8)
@@ -47,44 +47,63 @@ def visualize_dataset(cfg: dict, out_dir=None):
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
normalize=False,
overwrite_batch_size=1,
overwrite_prefetch=12,
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
for video_path in video_paths:
logging.info(video_path)
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
out_dir = Path(out_dir)
video_paths = []
threads = []
frames = {}
current_ep_idx = 0
logging.info(f"Visualizing episode {current_ep_idx}")
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = offline_buffer.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
new_episode = ep_idx != current_ep_idx
num_frames_left = offline_buffer._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
if new_episode:
logging.info(f"Visualizing episode {current_ep_idx}")
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
for im_key in offline_buffer.image_keys:
if new_episode or no_more_frames:
# append last observed frames (the ones after last action taken)
frames[im_key].append(offline_buffer.transform(ep_td["next"])[im_key])
video_dir = Path(out_dir) / "visualize_dataset"
video_dir.mkdir(parents=True, exist_ok=True)
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
# when first frame of episode, initialize frames dict
if im_key not in frames:
frames[im_key] = []
# add current frame to list of frames to render
frames[im_key].append(ep_td[im_key])
else:
# When episode has no more frame in its list of observation,
# one frame still remains. It is the result of the last action taken.
# It is stored in `"next"`, so we add it to the list of frames to render.
frames[im_key].append(ep_td["next"][im_key])
out_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:
camera = im_key[-1]
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else:
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
video_path = out_dir / f"episode_{current_ep_idx}.mp4"
video_paths.append(str(video_path))
thread = threading.Thread(
target=cat_and_write_video,
args=(str(video_path), frames[im_key], cfg.fps),
args=(str(video_path), frames[im_key], fps),
)
thread.start()
threads.append(thread)
@@ -94,12 +113,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
# reset list of frames
del frames[im_key]
# append current cameras images to list of frames
if im_key not in frames:
frames[im_key] = []
frames[im_key].append(ep_td[im_key])
if no_more_frames:
if num_frames_left == 0:
logging.info("Ran out of frames")
break
@@ -110,6 +124,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
thread.join()
logging.info("End of visualize_dataset")
return video_paths
if __name__ == "__main__":