forked from tangger/lerobot
Improve visualize_dataset, Improve AbstractReplayBuffer, Small improvements
This commit is contained in:
@@ -9,13 +9,13 @@ import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.envs import EnvBase
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import set_seed
|
||||
from lerobot.common.utils import init_logging, set_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
@@ -109,10 +109,18 @@ def eval(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
init_logging()
|
||||
|
||||
if cfg.device == "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
else:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_seed(cfg.seed)
|
||||
print(colored("Log dir:", "yellow", attrs=["bold"]), out_dir)
|
||||
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
@@ -142,6 +150,8 @@ def eval(cfg: dict, out_dir=None):
|
||||
)
|
||||
print(metrics)
|
||||
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
eval_cli()
|
||||
|
||||
@@ -4,13 +4,12 @@ import hydra
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict.nn import TensorDictModule
|
||||
from termcolor import colored
|
||||
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
|
||||
from torchrl.data.replay_buffers import PrioritizedSliceSampler
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import format_big_number, init_logging, set_seed
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
@@ -164,7 +163,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(out_dir, job_name, cfg)
|
||||
|
||||
logging.info(colored("Work dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.offline_steps=} ({format_big_number(cfg.offline_steps)})")
|
||||
logging.info(f"{cfg.online_steps=}")
|
||||
@@ -212,7 +211,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
for env_step in range(cfg.online_steps):
|
||||
if env_step == 0:
|
||||
logging.info("Start online training by interacting with environment")
|
||||
# TODO: use SyncDataCollector for that?
|
||||
# TODO: add configurable number of rollout? (default=1)
|
||||
with torch.no_grad():
|
||||
rollout = env.rollout(
|
||||
@@ -268,6 +266,8 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
step += 1
|
||||
online_step += 1
|
||||
|
||||
logging.info("End of training")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
|
||||
@@ -1,13 +1,20 @@
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import torch
|
||||
from torchrl.data.replay_buffers import SliceSamplerWithoutReplacement
|
||||
from torchrl.data.replay_buffers import (
|
||||
SamplerWithoutReplacement,
|
||||
)
|
||||
|
||||
from lerobot.common.datasets.factory import make_offline_buffer
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.utils import init_logging
|
||||
|
||||
NUM_EPISODES_TO_RENDER = 10
|
||||
NUM_EPISODES_TO_RENDER = 50
|
||||
MAX_NUM_STEPS = 1000
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -17,45 +24,88 @@ def visualize_dataset_cli(cfg: dict):
|
||||
visualize_dataset(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
|
||||
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
    """Concatenate a list of frame tensors and save them as a video file.

    Args:
        video_path: Destination passed to ``imageio.mimsave``.
        frames: Sequence of tensors joined with ``torch.cat`` along the
            first dimension; the joined result is read as
            ``(b, c, h, w)`` and must have dtype ``torch.uint8``.
        fps: Frame rate forwarded to ``imageio.mimsave``.

    Raises:
        AssertionError: If the concatenated tensor is not ``torch.uint8``.
    """
    stacked = torch.cat(frames)
    assert stacked.dtype == torch.uint8
    # Move the channel axis last before handing the array to imageio.
    channels_last = einops.rearrange(stacked, "b c h w -> b h w c")
    imageio.mimsave(video_path, channels_last.numpy(), fps=fps)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
sampler = SliceSamplerWithoutReplacement(
|
||||
num_slices=1,
|
||||
strict_length=False,
|
||||
init_logging()
|
||||
log_output_dir(out_dir)
|
||||
|
||||
# we expect the frames of each episode to be stored next to each other sequentially
|
||||
sampler = SamplerWithoutReplacement(
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
offline_buffer = make_offline_buffer(cfg, sampler)
|
||||
logging.info("make_offline_buffer")
|
||||
offline_buffer = make_offline_buffer(
|
||||
cfg, overwrite_sampler=sampler, normalize=False, overwrite_batch_size=1, overwrite_prefetch=12
|
||||
)
|
||||
|
||||
for _ in range(NUM_EPISODES_TO_RENDER):
|
||||
episode = offline_buffer.sample(MAX_NUM_STEPS)
|
||||
logging.info("Start rendering episodes from offline buffer")
|
||||
|
||||
ep_idx = episode["episode"][FIRST_FRAME].item()
|
||||
ep_frames = torch.cat(
|
||||
[
|
||||
episode["observation"]["image"][FIRST_FRAME][None, ...],
|
||||
episode["next", "observation"]["image"],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
threads = []
|
||||
frames = {}
|
||||
current_ep_idx = 0
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
for _ in range(MAX_NUM_STEPS * NUM_EPISODES_TO_RENDER):
|
||||
# TODO(rcadene): make it work with bsize > 1
|
||||
ep_td = offline_buffer.sample(1)
|
||||
ep_idx = ep_td["episode"][FIRST_FRAME].item()
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
# TODO(rcadene): make fps configurable
|
||||
video_path = video_dir / f"episode_{ep_idx}.mp4"
|
||||
# TODO(rcadene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
|
||||
assert ep_frames.min().item() >= 0
|
||||
assert ep_frames.max().item() > 1, "Not mendatory, but sanity check"
|
||||
assert ep_frames.max().item() <= 255
|
||||
ep_frames = ep_frames.type(torch.uint8)
|
||||
imageio.mimsave(video_path, ep_frames.numpy().transpose(0, 2, 3, 1), fps=cfg.fps)
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
|
||||
# ran out of episodes
|
||||
if offline_buffer._sampler._sample_list.numel() == 0:
|
||||
for im_key in offline_buffer.image_keys:
|
||||
if new_episode or no_more_frames:
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(ep_td[("next", *im_key)])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
camera = im_key[-1]
|
||||
video_path = video_dir / f"episode_{current_ep_idx}_{camera}.mp4"
|
||||
else:
|
||||
video_path = video_dir / f"episode_{current_ep_idx}.mp4"
|
||||
|
||||
thread = threading.Thread(
|
||||
target=cat_and_write_video,
|
||||
args=(str(video_path), frames[im_key], cfg.fps),
|
||||
)
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
|
||||
current_ep_idx = ep_idx
|
||||
|
||||
# reset list of frames
|
||||
del frames[im_key]
|
||||
|
||||
# append the current camera's images to the list of frames
|
||||
if im_key not in frames:
|
||||
frames[im_key] = []
|
||||
frames[im_key].append(ep_td[im_key])
|
||||
|
||||
if no_more_frames:
|
||||
logging.info("Ran out of frames")
|
||||
break
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
logging.info("End of visualize_dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
visualize_dataset_cli()
|
||||
|
||||
Reference in New Issue
Block a user