Compare commits

...

6 Commits

Author  SHA1        Message                                                          Date
Cadene  9cdc24bc0e  WIP                                                              2024-03-19 13:41:49 +00:00
Cadene  a346469a5a  Update time to monotonic                                         2024-03-18 16:26:07 +00:00
Cadene  2bef00c317  Add video decoding in dataset (WIP: issue with gray background)  2024-03-18 16:25:33 +00:00
Cadene  9954994a4b  Add video decoding in dataset (WIP: issue with gray background)  2024-03-18 16:24:32 +00:00
Cadene  0fc94b81b3  Add video decoding in dataset (WIP: issue with gray background)  2024-03-18 16:24:05 +00:00
Cadene  d32a279435  Add test.py for gpu decoder                                      2024-03-17 19:32:10 +00:00
                    Works with SliceSampler
                    WIP
                    Add video_id_to_path as a meta_data
                    plot frame
14 changed files with 953 additions and 47 deletions
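Several hunks below replace time.time() with time.monotonic() when timing training and evaluation steps. The monotonic clock never goes backwards, so the difference between two readings is safe to use as an elapsed duration even if the wall clock is adjusted (e.g. by NTP) in between. A minimal illustrative sketch, not taken from this diff (do_work is a hypothetical placeholder):

import time

start = time.monotonic()
do_work()  # hypothetical placeholder for the timed section
elapsed_s = time.monotonic() - start  # robust to wall-clock adjustments, unlike time.time()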

View File

@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
from lerobot.common.datasets.transforms import DecodeVideoTransform
class AbstractExperienceReplay(TensorDictReplayBuffer):
def __init__(
@@ -33,7 +35,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
self.dataset_id = dataset_id
self.shuffle = shuffle
self.root = root
storage = self._download_or_load_dataset()
storage, meta_data = self._download_or_load_dataset()
if transform is not None and "video_id_to_path" in meta_data:
# hack to access video paths
assert isinstance(transform, Compose)
for tf in transform:
if isinstance(tf, DecodeVideoTransform):
tf.set_video_id_to_path(meta_data["video_id_to_path"])
super().__init__(
storage=storage,
@@ -99,7 +108,13 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
else:
self.data_dir = self.root / self.dataset_id
return TensorStorage(TensorDict.load_memmap(self.data_dir))
storage = TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
# required to not send cuda frames to cpu by default
storage._storage.clear_device_()
meta_data = torch.load(self.data_dir / "meta_data.pth")
return storage, meta_data
def _compute_stats(self, num_batch=100, batch_size=32):
rb = TensorDictReplayBuffer(

View File

@@ -87,17 +87,30 @@ def make_offline_buffer(
prefetch=prefetch if isinstance(prefetch, int) else None,
)
if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys
transforms = []
transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
# transforms = [
# ViewSliceHorizonTransform(num_slices, cfg.policy.horizon),
# KeepFrames(positions=[0], in_keys=[("observation")]),
# DecodeVideoTransform(
# data_dir=offline_buffer.data_dir,
# device=cfg.device,
# frame_rate=None,
# in_keys=[("observation", "frame")],
# out_keys=[("observation", "frame", "data")],
# ),
# ]
if normalize:
if cfg.policy.name == "tdmpc":
img_keys = []
for key in offline_buffer.image_keys:
img_keys.append(("next", *key))
img_keys += offline_buffer.image_keys
else:
img_keys = offline_buffer.image_keys
transforms.append(Prod(in_keys=img_keys, prod=1 / 255))
# TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
stats = offline_buffer.compute_or_load_stats()

View File

@@ -0,0 +1,310 @@
from pathlib import Path
from typing import Sequence
import einops
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchaudio.io import StreamReader
from torchrl.envs.transforms import Transform
def yuv_to_rgb(frames):
assert frames.dtype == torch.uint8
assert frames.ndim == 4
assert frames.shape[1] == 3
frames = frames.cpu().to(torch.float)
y = frames[..., 0, :, :]
u = frames[..., 1, :, :]
v = frames[..., 2, :, :]
y /= 255
u = u / 255 - 0.5
v = v / 255 - 0.5
r = y + 1.13983 * v
g = y + -0.39465 * u - 0.58060 * v
b = y + 2.03211 * u
rgb = torch.stack([r, g, b], 1)
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
return rgb
def yuv_to_rgb_cv2(frames, return_hwc=True):
assert frames.dtype == torch.uint8
assert frames.ndim == 4
assert frames.shape[1] == 3
frames = frames.cpu()
import cv2
frames = einops.rearrange(frames, "b c h w -> b h w c")
frames = frames.numpy()
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
frames = [torch.from_numpy(frame) for frame in frames]
frames = torch.stack(frames)
if not return_hwc:
frames = einops.rearrange(frames, "b h w c -> b c h w")
return frames
class ViewSliceHorizonTransform(Transform):
invertible = False
def __init__(self, num_slices, horizon):
super().__init__()
self.num_slices = num_slices
self.horizon = horizon
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment is reset, to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
td = td.view(self.num_slices, self.horizon)
return td
class KeepFrames(Transform):
invertible = False
def __init__(
self,
positions,
in_keys: Sequence[NestedKey],
out_keys: Sequence[NestedKey] = None,
):
if isinstance(positions, list):
assert isinstance(positions[0], int)
# TODO(rcadene): add support for `isinstance(positions, int)`?
self.positions = positions
if out_keys is None:
out_keys = in_keys
super().__init__(in_keys=in_keys, out_keys=out_keys)
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment is reset, to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
# we need to set batch_size=[] before assigning a different shape to td[outkey]
td.batch_size = []
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
td[outkey] = td[inkey][:, self.positions]
return td
class DecodeVideoTransform(Transform):
invertible = False
def __init__(
self,
data_dir: Path | str,
device="cpu",
decoding_lib: str = "torchaudio",
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
format: str | None = None,
frame_rate: int | None = None,
width: int | None = None,
height: int | None = None,
in_keys: Sequence[NestedKey] = None,
out_keys: Sequence[NestedKey] = None,
in_keys_inv: Sequence[NestedKey] | None = None,
out_keys_inv: Sequence[NestedKey] | None = None,
):
self.data_dir = Path(data_dir)
self.device = device
self.decoding_lib = decoding_lib
self.format = format
self.frame_rate = frame_rate
self.width = width
self.height = height
self.video_id_to_path = None
if out_keys is None:
out_keys = in_keys
if in_keys_inv is None:
in_keys_inv = out_keys
if out_keys_inv is None:
out_keys_inv = in_keys
super().__init__(
in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
)
def set_video_id_to_path(self, video_id_to_path):
self.video_id_to_path = video_id_to_path
def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# _reset is called once when the environment is reset, to normalize the first observation
tensordict_reset = self._call(tensordict_reset)
return tensordict_reset
@dispatch(source="in_keys", dest="out_keys")
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
return self._call(tensordict)
def _call(self, td: TensorDictBase) -> TensorDictBase:
assert (
self.video_id_to_path is not None
), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."
for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
# TODO(rcadene): don't know how to do `inkey not in td`
if td.get(inkey, None) is None:
continue
bsize = len(td[inkey]) # num episodes in the batch
b_frames = []
for i in range(bsize):
assert (
td["observation", "frame", "video_id"].ndim == 3
and td["observation", "frame", "video_id"].shape[2] == 1
), "We expect 2 dims. Respectively, number of episodes in the batch, number of observations, 1"
ep_video_ids = td[inkey]["video_id"][i]
timestamps = td[inkey]["timestamp"][i]
frame_ids = td["frame_id"][i]
unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
assert unique_video_id
is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
assert is_ascending
is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
assert is_contiguous
FIRST_FRAME = 0 # noqa: N806
video_id = ep_video_ids[FIRST_FRAME].squeeze(0).item()
video_path = self.data_dir / self.video_id_to_path[video_id]
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
num_contiguous_frames = len(timestamps)
if self.decoding_lib == "torchaudio":
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "ffmpegio":
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "decord":
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
else:
raise ValueError(self.decoding_lib)
assert frames.ndim == 4
assert frames.shape[1] == 3
b_frames.append(frames)
td[outkey] = torch.stack(b_frames)
if self.device == "cuda":
# make sure we return a cuda tensor, since the frames can inadvertently be sent back to cpu
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
return td
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
filter_desc = []
video_stream_kwgs = {
"frames_per_chunk": num_contiguous_frames,
"buffer_chunk_size": num_contiguous_frames,
}
# choice of decoder
if self.device == "cuda":
video_stream_kwgs["hw_accel"] = "cuda"
video_stream_kwgs["decoder"] = "h264_cuvid"
# video_stream_kwgs["decoder"] = "hevc_cuvid"
# video_stream_kwgs["decoder"] = "av1_cuvid"
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
else:
video_stream_kwgs["decoder"] = "h264"
# video_stream_kwgs["decoder"] = "hevc"
# video_stream_kwgs["decoder"] = "av1"
# video_stream_kwgs["decoder"] = "ffv1"
# resize
resize_width = self.width is not None
resize_height = self.height is not None
if resize_width or resize_height:
if self.device == "cuda":
assert resize_width and resize_height
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
else:
scales = []
if resize_width:
scales.append(f"width={self.width}")
if resize_height:
scales.append(f"height={self.height}")
filter_desc.append(f"scale={':'.join(scales)}")
# choice of format
if self.format is not None:
if self.device == "cuda":
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
raise NotImplementedError()
# filter_desc = f"scale=format={self.format}"
# filter_desc = f"scale_cuda=format={self.format}"
# filter_desc = f"scale_npp=format={self.format}"
else:
filter_desc.append(f"format=pix_fmts={self.format}")
# choice of frame rate
if self.frame_rate is not None:
filter_desc.append(f"fps={self.frame_rate}")
filter_desc.append("scale=in_range=limited:out_range=full")
if len(filter_desc) > 0:
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
# create a stream and load a certain number of frames at a certain frame rate
# TODO(rcadene): make sure it's the most optimal way to do it
s = StreamReader(str(video_path))
s.seek(first_frame_ts)
s.add_video_stream(**video_stream_kwgs)
s.fill_buffer()
(frames,) = s.pop_chunks()
if "yuv" in self.format:
frames = yuv_to_rgb(frames)
return frames
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
import ffmpegio
fs, frames = ffmpegio.video.read(
str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
)
frames = torch.from_numpy(frames)
frames = einops.rearrange(frames, "b h w c -> b c h w")
if self.device == "cuda":
frames = frames.to(self.device)
return frames
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
from decord import VideoReader, cpu, gpu
with open(str(video_path), "rb") as f:
ctx = gpu if self.device == "cuda" else cpu
vr = VideoReader(f, ctx=ctx(0)) # noqa: F841
raise NotImplementedError("Convert `first_frame_ts` into frame_id")
# frame_id = frame_ids[0].item()
# frames = vr.get_batch([frame_id])
# frames = torch.from_numpy(frames.asnumpy())
# frames = einops.rearrange(frames, "b h w c -> b c h w")
# return frames
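For context on the torchaudio code path above, here is a minimal standalone sketch of chunked decoding with StreamReader, assuming an H.264 file at video.mp4 and, for the GPU branch, an FFmpeg build with NVDEC support (path, timestamp and chunk size are illustrative):

from torchaudio.io import StreamReader

s = StreamReader("video.mp4")  # illustrative path
s.seek(1.5)  # start decoding near t = 1.5 s
s.add_video_stream(
    frames_per_chunk=8,  # number of frames returned per chunk
    buffer_chunk_size=8,
    decoder="h264_cuvid",  # NVDEC decoder; use "h264" for CPU decoding
    hw_accel="cuda",  # keep decoded frames on the GPU (omit for CPU decoding)
)
s.fill_buffer()
(frames,) = s.pop_chunks()  # decoded frames for the single configured video stream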

View File

@@ -3,6 +3,7 @@ import zipfile
from pathlib import Path
import requests
import torch
import tqdm
@@ -28,3 +29,26 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return True
else:
return False
def yuv_to_rgb(frames):
assert frames.dtype == torch.uint8
assert frames.ndim == 4
assert frames.shape[1] == 3
frames = frames.cpu().to(torch.float)
y = frames[..., 0, :, :]
u = frames[..., 1, :, :]
v = frames[..., 2, :, :]
y /= 255
u = u / 255 - 0.5
v = v / 255 - 0.5
r = y + 1.14 * v
g = y + -0.396 * u - 0.581 * v
b = y + 2.029 * u
rgb = torch.stack([r, g, b], 1)
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
return rgb

View File

@@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
def update(self, replay_buffer, step):
del step
start_time = time.time()
start_time = time.monotonic()
self.train()
@@ -104,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
data_s = time.monotonic() - start_time
loss = self.compute_loss(batch)
loss.backward()
@@ -125,7 +125,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
# "lr": self.lr_scheduler.get_last_lr()[0],
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
"update_s": time.monotonic() - start_time,
}
return info

View File

@@ -188,8 +188,8 @@ class MetricLogger:
def log_every(self, iterable, print_freq, header=None):
if not header:
header = ""
start_time = time.time()
end = time.time()
start_time = time.monotonic()
end = time.monotonic()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
@@ -218,9 +218,9 @@ class MetricLogger:
)
mega_b = 1024.0 * 1024.0
for i, obj in enumerate(iterable):
data_time.update(time.time() - end)
data_time.update(time.monotonic() - end)
yield obj
iter_time.update(time.time() - end)
iter_time.update(time.monotonic() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
@@ -247,8 +247,8 @@ class MetricLogger:
data=str(data_time),
)
)
end = time.time()
total_time = time.time() - start_time
end = time.monotonic()
total_time = time.monotonic() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))

View File

@@ -111,7 +111,7 @@ class DiffusionPolicy(nn.Module):
return action
def update(self, replay_buffer, step):
start_time = time.time()
start_time = time.monotonic()
self.diffusion.train()
@@ -158,7 +158,7 @@ class DiffusionPolicy(nn.Module):
batch = replay_buffer.sample(batch_size)
batch = process_batch(batch, self.cfg.horizon, num_slices)
data_s = time.time() - start_time
data_s = time.monotonic() - start_time
loss = self.diffusion.compute_loss(batch)
loss.backward()
@@ -181,7 +181,7 @@ class DiffusionPolicy(nn.Module):
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"data_s": data_s,
"update_s": time.time() - start_time,
"update_s": time.monotonic() - start_time,
}
# TODO(rcadene): remove hardcoding

View File

@@ -291,7 +291,7 @@ class TDMPC(nn.Module):
def update(self, replay_buffer, step, demo_buffer=None):
"""Main update function. Corresponds to one iteration of the model learning."""
start_time = time.time()
start_time = time.monotonic()
num_slices = self.cfg.batch_size
batch_size = self.cfg.horizon * num_slices
@@ -405,7 +405,7 @@ class TDMPC(nn.Module):
self.std = h.linear_schedule(self.cfg.std_schedule, step)
self.model.train()
data_s = time.time() - start_time
data_s = time.monotonic() - start_time
# Compute targets
with torch.no_grad():
@@ -501,7 +501,7 @@ class TDMPC(nn.Module):
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"data_s": data_s,
"update_s": time.time() - start_time,
"update_s": time.monotonic() - start_time,
}
info["demo_batch_size"] = demo_batch_size
info["expectile"] = expectile

View File

@@ -0,0 +1,129 @@
"""
usage: `python lerobot/scripts/convert_dataset_uint8_to_mp4.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
"""
import argparse
import shutil
from pathlib import Path
import torch
from tensordict import TensorDict
def convert_dataset_uint8_to_mp4(in_data_dir, out_data_dir, fps, overwrite_num_frames=None):
assert fps is not None and isinstance(fps, float)
# load full dataset as a tensor dict
in_td_data = TensorDict.load_memmap(in_data_dir)
out_data_dir = Path(out_data_dir)
# use 1 frame to know the specification of the dataset
# and copy it over `n` frames in the test artifact directory
out_rb_dir = out_data_dir / "replay_buffer"
if out_rb_dir.exists():
shutil.rmtree(out_rb_dir)
num_frames = len(in_td_data) if overwrite_num_frames is None else overwrite_num_frames
# del in_td_data["observation", "image"]
# del in_td_data["next", "observation", "image"]
out_td_data = in_td_data[0].memmap_().clone()
out_td_data["observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
out_td_data["observation", "frame", "timestamp"] = torch.zeros(1)
out_td_data["next", "observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
out_td_data["next", "observation", "frame", "timestamp"] = torch.zeros(1)
out_td_data = out_td_data.expand(num_frames)
out_td_data = out_td_data.memmap_like(out_rb_dir)
out_vid_dir = out_data_dir / "videos"
out_vid_dir.mkdir(parents=True, exist_ok=True)
video_id_to_path = {}
for key in out_td_data.keys(include_nested=True, leaves_only=True):
if in_td_data.get(key, None) is None:
continue
if overwrite_num_frames is None:
out_td_data[key].copy_(in_td_data[key].clone())
else:
out_td_data[key][:num_frames].copy_(in_td_data[key][:num_frames].clone())
for i in range(num_frames):
video_id = in_td_data["episode"][i]
frame_id = in_td_data["frame_id"][i]
out_td_data["observation", "frame", "video_id"][i] = video_id
out_td_data["observation", "frame", "timestamp"][i] = frame_id / fps
out_td_data["next", "observation", "frame", "video_id"][i] = video_id
out_td_data["next", "observation", "frame", "timestamp"][i] = (frame_id + 1) / fps
video_id = video_id.item()
if video_id not in video_id_to_path:
video_id_to_path[video_id] = f"videos/episode_{video_id}.mp4"
# copy the first `n` frames so that we have real data
# make sure everything has been properly written
out_td_data.lock_()
# copy the full statistics of dataset since it's pretty small
in_stats_path = Path(in_data_dir) / "stats.pth"
out_stats_path = Path(out_data_dir) / "stats.pth"
shutil.copy(in_stats_path, out_stats_path)
meta_data = {
"video_id_to_path": video_id_to_path,
}
torch.save(meta_data, out_data_dir / "meta_data.pth")
# def write_to_mp4():
# buffer = io.BytesIO()
# swriter = StreamWriter(buffer, format="mp4")
# device = "cuda"
# c,h,w = in_td_data[0]["observation", "image"].shape
# swriter.add_video_stream(
# frame_rate=fps,
# width=w,
# height=h,
# # frame_rate=30000 / 1001,
# format="yuv444p",
# encoder="h264_nvenc",
# encoder_format="yuv444p",
# hw_accel=device,
# )
# for i in range(num_frames):
# ep_id = in_td_data[i]["episode"]
# data = in_td_data[i]["observation", "image"]
# with swriter.open():
# t0 = time.monotonic()
# data = data.to(device)
# swriter.write_video_chunk(0, data)
# elapsed = time.monotonic() - t0
# size = buffer.tell()
# print(f"{elapsed=}")
# print(f"{size=}")
# buffer.seek(0)
# video = buffer.read()
# vid_path = out_vid_dir / f"episode_{ep_id}.mp4"
# with open(vid_path, 'wb+') as f:
# f.write(video)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Create dataset")
parser.add_argument("--in-data-dir", type=str, help="Path to input data")
parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
parser.add_argument("--fps", type=float)
args = parser.parse_args()
convert_dataset_uint8_to_mp4(args.in_data_dir, args.out_data_dir, args.fps)
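For orientation, the on-disk layout this script appears to produce, inferred from the code above (episode numbering is illustrative; the actual mp4 encoding step, write_to_mp4, is still commented out, so videos/ is only created and referenced by video_id_to_path for now):

out_data_dir/
    replay_buffer/    memory-mapped TensorDict with the new ("observation", "frame", "video_id"/"timestamp") keys
    videos/           episode_{video_id}.mp4 files referenced by video_id_to_path (encoding still WIP)
    stats.pth         copied verbatim from the input dataset
    meta_data.pth     {"video_id_to_path": {0: "videos/episode_0.mp4", ...}}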

View File

@@ -32,7 +32,7 @@ def eval_policy(
fps: int = 15,
return_first_video: bool = False,
):
start = time.time()
start = time.monotonic()
sum_rewards = []
max_rewards = []
successes = []
@@ -85,8 +85,8 @@ def eval_policy(
"avg_sum_reward": np.nanmean(sum_rewards),
"avg_max_reward": np.nanmean(max_rewards),
"pc_success": np.nanmean(successes) * 100,
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
"eval_s": time.monotonic() - start,
"eval_ep_s": (time.monotonic() - start) / num_episodes,
}
if return_first_video:
return info, first_video

View File

@@ -1,5 +1,8 @@
import logging
import shutil
import subprocess
import threading
import time
from pathlib import Path
import einops
@@ -26,9 +29,49 @@ def visualize_dataset_cli(cfg: dict):
def cat_and_write_video(video_path, frames, fps):
frames = torch.cat(frames)
assert frames.dtype == torch.uint8
if frames.dtype != torch.uint8:
logging.warning(f"frames are expected to be uint8 to {frames.dtype}")
frames = frames.type(torch.uint8)
_, _, h, w = frames.shape
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
imageio.mimsave(video_path, frames, fps=fps)
img_dir = Path(video_path.split(".")[0])
if img_dir.exists():
shutil.rmtree(img_dir)
img_dir.mkdir(parents=True, exist_ok=True)
for i in range(len(frames)):
imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i])
ffmpeg_command = [
"ffmpeg",
"-r",
str(fps),
"-f",
"image2",
"-s",
f"{w}x{h}",
"-i",
str(img_dir / "frame_%04d.png"),
"-vcodec",
"libx264",
#'-vcodec', 'libx265',
#'-vcodec', 'libaom-av1',
"-crf",
"0", # Lossless option
"-pix_fmt",
# "yuv420p", # Specify pixel format
"yuv444p", # Specify pixel format
video_path,
# video_path.replace(".mp4", ".mkv")
]
subprocess.run(ffmpeg_command, check=True)
time.sleep(0.1)
# clean temporary image directory
# shutil.rmtree(img_dir)
def visualize_dataset(cfg: dict, out_dir=None):
@@ -61,7 +104,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
new_episode = ep_idx != current_ep_idx
new_episode = ep_idx > current_ep_idx
if ep_idx < current_ep_idx:
break
if new_episode:
logging.info(f"Visualizing episode {current_ep_idx}")
@@ -71,7 +117,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
# append last observed frames (the ones after last action taken)
frames[im_key].append(ep_td[("next", *im_key)])
video_dir = Path(out_dir) / "visualize_dataset"
video_dir = Path(out_dir) / "videos"
video_dir.mkdir(parents=True, exist_ok=True)
if len(offline_buffer.image_keys) > 1:

poetry.lock generated
View File

@@ -2684,13 +2684,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
[[package]]
name = "sentry-sdk"
version = "1.41.0"
version = "1.42.0"
description = "Python client for Sentry (https://sentry.io)"
optional = false
python-versions = "*"
files = [
{file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
{file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
{file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"},
{file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"},
]
[package.dependencies]
@@ -2714,6 +2714,7 @@ grpcio = ["grpcio (>=1.21.1)"]
httpx = ["httpx (>=0.16.0)"]
huey = ["huey (>=2)"]
loguru = ["loguru (>=0.5)"]
openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
pure-eval = ["asttokens", "executing", "pure-eval"]
@@ -2829,18 +2830,18 @@ test = ["pytest"]
[[package]]
name = "setuptools"
version = "69.1.1"
version = "69.2.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
files = [
{file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"},
{file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"},
{file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"},
{file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
@@ -2949,7 +2950,7 @@ mpmath = ">=0.19"
[[package]]
name = "tensordict"
version = "0.4.0+551331d"
version = "0.4.0+6a56ecd"
description = ""
optional = false
python-versions = "*"
@@ -2970,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
type = "git"
url = "https://github.com/pytorch/tensordict"
reference = "HEAD"
resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
resolved_reference = "6a56ecd728757feee387f946b7da66dd452b739b"
[[package]]
name = "termcolor"
@@ -3072,6 +3073,43 @@ typing-extensions = ">=4.8.0"
opt-einsum = ["opt-einsum (>=3.3)"]
optree = ["optree (>=0.9.1)"]
[[package]]
name = "torchaudio"
version = "2.2.1"
description = "An audio package for PyTorch"
optional = false
python-versions = "*"
files = [
{file = "torchaudio-2.2.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:580eefd764a01a64d5b6aa260c0c47974be6a6964892d54029a73b17f4611fcd"},
{file = "torchaudio-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ad55c2069b27bbe18e14783a202e3f3f8082fe9e59281436ba797edb0fc94d5"},
{file = "torchaudio-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:55d23986254f7af689695f3fc214c4aa3e73dc931289ecdba7262d73fea7af7a"},
{file = "torchaudio-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b916b7764698ba9319aa3b25519139892de8665d84438969bac5e1d8578c6a11"},
{file = "torchaudio-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:281cd4bdb9e65c0618a028b809df9e06f9bd9592aeef8f2b37b4d8a788ce5f2b"},
{file = "torchaudio-2.2.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:274cb8474bc1e56b768ef347d3188661c5a9d5e68e2df56fc0aff11cc73c916a"},
{file = "torchaudio-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e62c27b17672cc2bdd9663681e533000f9c0984e6a0f3d455f7051bc005bb02"},
{file = "torchaudio-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7df7d5d9100116be38ff7b27b628820dca4a9e3fe79394605141d339e3b3e46d"},
{file = "torchaudio-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:20b2965db4f843021636f53d3fab1075c3f8959c450c647629124d24c7e6cbb0"},
{file = "torchaudio-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:63dd0e840bcf2e4aceb7a98daccfaf7a2a5b3a927647b98bbef449b0b190f2cc"},
{file = "torchaudio-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2c232dc8bee97d303b90833ba934d8905eb7326456236efcd9fa71ccb92fd363"},
{file = "torchaudio-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2419387cf04d33047369337bf09c00c2a7673a8f52f80258454c7eca7d205d23"},
{file = "torchaudio-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2483c0620a68a136359ae90c893608ad5cd73091fb0351b94d33af126a0e3d67"},
{file = "torchaudio-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bd389f33b7dbfc44e5f4070fc6db00cc560992bea8378a952889acfd772b7022"},
{file = "torchaudio-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:d5af725a327b79f3bd8389c53ec51554ee003c18434fc47e68da49b09900132e"},
{file = "torchaudio-2.2.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:81ef88d7693e3b99007d1ee742fd81b9a92399ecbf88eb7ed69949443005ffba"},
{file = "torchaudio-2.2.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f487a7d3177ae6af016750850ee93788e880218a1a310bc6c76901e212f91cd3"},
{file = "torchaudio-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bee5478ec2cb7d0eaa97023d817aa4914010e1ab0c266f64ef1b0db893aceb49"},
{file = "torchaudio-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a4462b3f214f60b6b8f78e12a4cf1291c9bc353deed709ac3dfdedbed513a7a3"},
{file = "torchaudio-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4bc43d11d9e086f0dfb29f6ea99517d8ec06fa80d97283f2c8b83c4cd467dd1a"},
{file = "torchaudio-2.2.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0339fe78ed9c29f704296761b28bb055b5350625ff503ad781704397934e6b58"},
{file = "torchaudio-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68b1d9f8ffe9b26ef04e80d82ae2dc2f74b1a1eb64c3e8ad21b525802b3bc7ac"},
{file = "torchaudio-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3962fea5d2511c9ab2b1dd515b45ec44d0c28e51f3b05c0b9fa7bbcc3c213bc1"},
{file = "torchaudio-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:cb2da08abb7b68dc7b0105748b1a736dd33329f841374013ec02c54e04bedf29"},
{file = "torchaudio-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:54996977ab1c875729e8dedc4695609ca58f876c23756c79979c6b50136b3385"},
]
[package.dependencies]
torch = "2.2.1"
[[package]]
name = "torchrl"
version = "0.4.0+13bef42"
@@ -3311,20 +3349,20 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
[[package]]
name = "zipp"
version = "3.17.0"
version = "3.18.0"
description = "Backport of pathlib-compatible object wrapper for zip files"
optional = false
python-versions = ">=3.8"
files = [
{file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
{file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
{file = "zipp-3.18.0-py3-none-any.whl", hash = "sha256:c1bb803ed69d2cce2373152797064f7e79bc43f0a3748eb494096a867e0ebf79"},
{file = "zipp-3.18.0.tar.gz", hash = "sha256:df8d042b02765029a09b157efd8e820451045890acc30f8e37dd2f94a060221f"},
]
[package.extras]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9"
content-hash = "e0c9fa6894aaa917493f81028c1bcc3fff8c56d9025681af44534fc3dbe7646e"

View File

@@ -51,6 +51,7 @@ torchvision = "^0.17.1"
h5py = "^3.10.0"
dm-control = "1.0.14"
huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
torchaudio = "^2.2.1"
[tool.poetry.group.dev.dependencies]

test.py Normal file
View File

@@ -0,0 +1,330 @@
# TODO(rcadene): add tests
# TODO(rcadene): what is the best format to store/load videos?
import subprocess
from collections.abc import Callable
from pathlib import Path
import einops
import torch
import torchaudio
import torchrl
from matplotlib import pyplot as plt
from tensordict import TensorDict
from torchaudio.utils import ffmpeg_utils
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose
from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
from lerobot.common.utils import set_seed
NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12
def count_frames(video_path):
try:
# Construct the ffprobe command to get the number of frames
cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=nb_frames",
"-of",
"default=nokey=1:noprint_wrappers=1",
video_path,
]
# Execute the ffprobe command and capture the output
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Convert the output to an integer
num_frames = int(result.stdout.strip())
return num_frames
except Exception as e:
print(f"An error occurred: {e}")
return -1
def get_frame_rate(video_path):
try:
cmd = [
"ffprobe",
"-v",
"error",
"-select_streams",
"v:0",
"-show_entries",
"stream=r_frame_rate",
"-of",
"default=nokey=1:noprint_wrappers=1",
video_path,
]
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# The frame rate is typically represented as a fraction (e.g., "30000/1001").
# To convert it to a float, we can evaluate the fraction.
frame_rate = eval(result.stdout.strip())
return frame_rate
except Exception as e:
print(f"An error occurred: {e}")
return -1
def get_frame_timestamps(frame_rate, num_frames):
timestamps = [(1 / frame_rate) * i for i in range(num_frames)]
return timestamps
# class ClearDeviceTransform(Transform):
# invertible = False
# def __init__(self):
# super().__init__()
# def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
# # _reset is called once when the environment reset to normalize the first observation
# tensordict_reset = self._call(tensordict_reset)
# return tensordict_reset
# @dispatch(source="in_keys", dest="out_keys")
# def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# return self._call(tensordict)
# def _call(self, td: TensorDictBase) -> TensorDictBase:
# td.clear_device_()
# return td
class VideoExperienceReplay(TensorDictReplayBuffer):
def __init__(
self,
batch_size: int = None,
*,
root: Path = None,
pin_memory: bool = False,
prefetch: int = None,
sampler: SliceSampler = None,
collate_fn: Callable = None,
writer: Writer = None,
transform: "torchrl.envs.Transform" = None,
):
self.data_dir = root
self.rb_dir = self.data_dir / "replay_buffer"
storage, meta_data = self._load_or_download()
# hack to access video paths
assert isinstance(transform, Compose)
for tf in transform:
if isinstance(tf, DecodeVideoTransform):
tf.set_video_id_to_path(meta_data["video_id_to_path"])
super().__init__(
storage=storage,
sampler=sampler,
writer=ImmutableDatasetWriter() if writer is None else writer,
collate_fn=_collate_id if collate_fn is None else collate_fn,
pin_memory=pin_memory,
prefetch=prefetch,
batch_size=batch_size,
transform=transform,
)
def _load_or_download(self, force_download=False):
if not force_download and self.data_dir.exists():
storage = TensorStorage(TensorDict.load_memmap(self.rb_dir))
meta_data = torch.load(self.data_dir / "meta_data.pth")
else:
storage, meta_data = self._download()
torch.save(meta_data, self.data_dir / "meta_data.pth")
# required to not send cuda frames to cpu by default
storage._storage.clear_device_()
return storage, meta_data
def _download(self):
num_episodes = 1
video_id_to_path = {}
for episode_id in range(num_episodes):
video_path = torchaudio.utils.download_asset(
"tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4"
)
# several episodes can belong to the same video
video_id = episode_id
video_id_to_path[video_id] = video_path
print(f"{video_path=}")
num_frames = count_frames(video_path)
print(f"{num_frames=}")
frame_rate = get_frame_rate(video_path)
print(f"{frame_rate=}")
frame_timestamps = get_frame_timestamps(frame_rate, num_frames)
reward = torch.zeros(num_frames, 1, dtype=torch.float32)
success = torch.zeros(num_frames, 1, dtype=torch.bool)
done = torch.zeros(num_frames, 1, dtype=torch.bool)
state = torch.randn(num_frames, NUM_STATE_CHANNELS, dtype=torch.float32)
action = torch.randn(num_frames, NUM_ACTION_CHANNELS, dtype=torch.float32)
timestamp = torch.tensor(frame_timestamps)
frame_id = torch.arange(0, num_frames, 1)
episode_id_tensor = torch.tensor([episode_id] * num_frames, dtype=torch.int)
video_id_tensor = torch.tensor([video_id] * num_frames, dtype=torch.int)
# last step of demonstration is considered done
done[-1] = True
ep_td = TensorDict(
{
("observation", "frame", "video_id"): video_id_tensor[:-1],
("observation", "frame", "timestamp"): timestamp[:-1],
("observation", "state"): state[:-1],
"action": action[:-1],
"episode": episode_id_tensor[:-1],
"frame_id": frame_id[:-1],
("next", "observation", "frame", "video_id"): video_id_tensor[1:],
("next", "observation", "frame", "timestamp"): timestamp[1:],
("next", "observation", "state"): state[1:],
("next", "reward"): reward[1:],
("next", "done"): done[1:],
("next", "success"): success[1:],
},
batch_size=num_frames - 1,
)
# TODO:
total_frames = num_frames - 1
if episode_id == 0:
# hack to initialize tensordict data structure to store episodes
td_data = ep_td[0].expand(total_frames).memmap_like(self.rb_dir)
td_data[:] = ep_td
meta_data = {
"video_id_to_path": video_id_to_path,
}
return TensorStorage(td_data.lock_()), meta_data
if __name__ == "__main__":
import time
import tqdm
print("FFmpeg Library versions:")
for k, ver in ffmpeg_utils.get_versions().items():
print(f" {k}:\t{'.'.join(str(v) for v in ver)}")
print("Available NVDEC Decoders:")
for k in ffmpeg_utils.get_video_decoders().keys(): # noqa: SIM118
if "cuvid" in k:
print(f" - {k}")
def create_replay_buffer(device, format=None):
data_dir = Path("tmp/2024_03_17_data_video/pusht")
num_slices = 1
horizon = 2
batch_size = num_slices * horizon
sampler = SliceSamplerWithoutReplacement(
num_slices=num_slices,
strict_length=True,
shuffle=False,
)
transforms = [
# ClearDeviceTransform(),
ViewSliceHorizonTransform(num_slices, horizon),
KeepFrames(positions=[0], in_keys=[("observation")]),
DecodeVideoTransform(
data_dir=data_dir,
device=device,
frame_rate=None,
format=format,
in_keys=[("observation", "frame")],
out_keys=[("observation", "frame", "data")],
),
]
replay_buffer = VideoExperienceReplay(
root=data_dir,
batch_size=batch_size,
# prefetch=4,
transform=Compose(*transforms),
sampler=sampler,
)
return replay_buffer
def test_time():
replay_buffer = create_replay_buffer(device="cuda")
start = time.monotonic()
for _ in tqdm.tqdm(range(2)):
# include_info=False is required to avoid a batch_size mismatch error with the truncated key (2,8) != (16, 1)
replay_buffer.sample(include_info=False)
torch.cuda.synchronize()
print(time.monotonic() - start)
start = time.monotonic()
for _ in tqdm.tqdm(range(10)):
replay_buffer.sample(include_info=False)
torch.cuda.synchronize()
print(time.monotonic() - start)
def test_plot(seed=1337):
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
n_rows = 2 # len(replay_buffer)
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
for i in range(n_rows):
set_seed(seed + i)
batch_cpu = rb_cpu.sample(include_info=False)
print("frame_ids cpu", batch_cpu["frame_id"].tolist())
print("episode cpu", batch_cpu["episode"].tolist())
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
frames = batch_cpu["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][0].imshow(frames[0])
set_seed(seed + i)
batch_cuda = rb_cuda.sample(include_info=False)
print("frame_ids cuda", batch_cuda["frame_id"].tolist())
print("episode cuda", batch_cuda["episode"].tolist())
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
frames = batch_cuda["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][1].imshow(frames[0])
frames = batch_cuda["observation", "image"].type(torch.uint8)
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][2].imshow(frames[0])
axes[0][0].set_title("Software decoder")
axes[0][1].set_title("HW decoder")
axes[0][2].set_title("uint8")
plt.setp(axes, xticks=[], yticks=[])
plt.tight_layout()
fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)
# test_time()
test_plot()
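As written, the script only runs test_plot() (the test_time() call above it is commented out). Assuming a converted dataset already exists under tmp/2024_03_17_data_video/pusht, it samples matching batches through the CPU and NVDEC pipelines and saves a side-by-side comparison:

python test.py  # writes <data_dir>/test.png comparing software decoding, HW decoding and raw uint8 frames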