WIP

WIP train.py works, loss going down

WIP eval.py

Fix

WIP (eval running; TODO: verify that results are reproduced)

Eval works! (testing reproducibility)

WIP

Pretrained pusht model reproduces the same results as torchrl

Pretrained pusht model reproduces the same results as torchrl

Remove AbstractPolicy; move all queues into select_action

WIP: test_datasets passes (TODO: re-enable NormalizeTransform)
This commit is contained in:
Cadene
2024-03-31 15:05:25 +00:00
parent 920e0d118b
commit 1cdfbc8b52
17 changed files with 826 additions and 621 deletions

View File

@@ -10,7 +10,7 @@ from torchrl.data.replay_buffers import (
SamplerWithoutReplacement,
)
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import log_output_dir
from lerobot.common.utils import init_logging
@@ -44,8 +44,8 @@ def visualize_dataset(cfg: dict, out_dir=None):
shuffle=False,
)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(
logging.info("make_dataset")
dataset = make_dataset(
cfg,
overwrite_sampler=sampler,
# remove all transformations such as rescale images from [0,255] to [0,1] or normalization
@@ -55,12 +55,12 @@ def visualize_dataset(cfg: dict, out_dir=None):
)
logging.info("Start rendering episodes from offline buffer")
video_paths = render_dataset(offline_buffer, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
video_paths = render_dataset(dataset, out_dir, MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER, cfg.fps)
for video_path in video_paths:
logging.info(video_path)
def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
def render_dataset(dataset, out_dir, max_num_samples, fps):
out_dir = Path(out_dir)
video_paths = []
threads = []
@@ -69,17 +69,17 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
logging.info(f"Visualizing episode {current_ep_idx}")
for i in range(max_num_samples):
# TODO(rcadene): make it work with bsize > 1
ep_td = offline_buffer.sample(1)
ep_td = dataset.sample(1)
ep_idx = ep_td["episode"][FIRST_FRAME].item()
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = offline_buffer._sampler._sample_list.numel()
# TODO(rcadene): modify dataset._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
num_frames_left = dataset._sampler._sample_list.numel()
episode_is_done = ep_idx != current_ep_idx
if episode_is_done:
logging.info(f"Rendering episode {current_ep_idx}")
for im_key in offline_buffer.image_keys:
for im_key in dataset.image_keys:
if not episode_is_done and num_frames_left > 0 and i < (max_num_samples - 1):
# when first frame of episode, initialize frames dict
if im_key not in frames:
@@ -93,7 +93,7 @@ def render_dataset(offline_buffer, out_dir, max_num_samples, fps):
frames[im_key].append(ep_td["next"][im_key])
out_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:
if len(dataset.image_keys) > 1:
camera = im_key[-1]
video_path = out_dir / f"episode_{current_ep_idx}_{camera}.mp4"
else: