forked from tangger/lerobot
Compare commits
6 Commits
hf-papers...user/rcade
| Author | SHA1 | Date |
|---|---|---|
| | 9cdc24bc0e | |
| | a346469a5a | |
| | 2bef00c317 | |
| | 9954994a4b | |
| | 0fc94b81b3 | |
| | d32a279435 | |
@@ -14,6 +14,8 @@ from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
 from torchrl.envs.transforms.transforms import Compose
 
+from lerobot.common.datasets.transforms import DecodeVideoTransform
+
 
 class AbstractExperienceReplay(TensorDictReplayBuffer):
     def __init__(
@@ -33,7 +35,14 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
         self.dataset_id = dataset_id
         self.shuffle = shuffle
         self.root = root
-        storage = self._download_or_load_dataset()
+        storage, meta_data = self._download_or_load_dataset()
+
+        if transform is not None and "video_id_to_path" in meta_data:
+            # hack to access video paths
+            assert isinstance(transform, Compose)
+            for tf in transform:
+                if isinstance(tf, DecodeVideoTransform):
+                    tf.set_video_id_to_path(meta_data["video_id_to_path"])
 
         super().__init__(
             storage=storage,
@@ -99,7 +108,13 @@ class AbstractExperienceReplay(TensorDictReplayBuffer):
             self.data_dir = Path(snapshot_download(repo_id=f"cadene/{self.dataset_id}", repo_type="dataset"))
         else:
             self.data_dir = self.root / self.dataset_id
-        return TensorStorage(TensorDict.load_memmap(self.data_dir))
+
+        storage = TensorStorage(TensorDict.load_memmap(self.data_dir / "replay_buffer"))
+        # required to not send cuda frames to cpu by default
+        storage._storage.clear_device_()
+
+        meta_data = torch.load(self.data_dir / "meta_data.pth")
+        return storage, meta_data
 
     def _compute_stats(self, num_batch=100, batch_size=32):
         rb = TensorDictReplayBuffer(
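Taken together, these changes mean a dataset directory now carries video metadata next to the memory-mapped tensors. A minimal sketch of the on-disk layout the new loader assumes (file names taken from the diff above; `data/pusht` is an assumed example root):

```python
from pathlib import Path

import torch
from tensordict import TensorDict

data_dir = Path("data/pusht")  # hypothetical dataset root
td = TensorDict.load_memmap(data_dir / "replay_buffer")  # frame-indexed tensors
meta = torch.load(data_dir / "meta_data.pth")
# meta["video_id_to_path"] maps int video ids to relative paths,
# e.g. {0: "videos/episode_0.mp4", ...}
```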
@@ -87,17 +87,30 @@ def make_offline_buffer(
         prefetch=prefetch if isinstance(prefetch, int) else None,
     )
 
-    if cfg.policy.name == "tdmpc":
-        img_keys = []
-        for key in offline_buffer.image_keys:
-            img_keys.append(("next", *key))
-        img_keys += offline_buffer.image_keys
-    else:
-        img_keys = offline_buffer.image_keys
+    transforms = []
 
-    transforms = [Prod(in_keys=img_keys, prod=1 / 255)]
+    # transforms = [
+    #     ViewSliceHorizonTransform(num_slices, cfg.policy.horizon),
+    #     KeepFrames(positions=[0], in_keys=[("observation")]),
+    #     DecodeVideoTransform(
+    #         data_dir=offline_buffer.data_dir,
+    #         device=cfg.device,
+    #         frame_rate=None,
+    #         in_keys=[("observation", "frame")],
+    #         out_keys=[("observation", "frame", "data")],
+    #     ),
+    # ]
 
     if normalize:
+        if cfg.policy.name == "tdmpc":
+            img_keys = []
+            for key in offline_buffer.image_keys:
+                img_keys.append(("next", *key))
+            img_keys += offline_buffer.image_keys
+        else:
+            img_keys = offline_buffer.image_keys
+        transforms.append(Prod(in_keys=img_keys, prod=1 / 255))
+
         # TODO(rcadene): make normalization strategy configurable between mean_std, min_max, manual_min_max, min_max_from_spec
         stats = offline_buffer.compute_or_load_stats()
lerobot/common/datasets/transforms.py (new file, 310 lines)
@@ -0,0 +1,310 @@
from pathlib import Path
from typing import Sequence

import einops
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
from tensordict.utils import NestedKey
from torchaudio.io import StreamReader
from torchrl.envs.transforms import Transform


def yuv_to_rgb(frames):
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    frames = frames.cpu().to(torch.float)
    y = frames[..., 0, :, :]
    u = frames[..., 1, :, :]
    v = frames[..., 2, :, :]

    y /= 255
    u = u / 255 - 0.5
    v = v / 255 - 0.5

    r = y + 1.13983 * v
    g = y + -0.39465 * u - 0.58060 * v
    b = y + 2.03211 * u

    rgb = torch.stack([r, g, b], 1)
    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
    return rgb
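As a quick, illustrative sanity check of the conversion above (a sketch, assuming `yuv_to_rgb` is importable from `lerobot.common.datasets.transforms`): a neutral-chroma mid-gray YUV frame should map back to mid-gray RGB, since the centered U/V terms nearly vanish.

```python
import torch

from lerobot.common.datasets.transforms import yuv_to_rgb

# Y=U=V=128: the chroma offsets (u/255 - 0.5, v/255 - 0.5) are ~0, so RGB stays gray
frames = torch.full((1, 3, 2, 2), 128, dtype=torch.uint8)
rgb = yuv_to_rgb(frames)
print(rgb[0, :, 0, 0])  # approximately tensor([128, 128, 128]) up to rounding
```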
def yuv_to_rgb_cv2(frames, return_hwc=True):
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    frames = frames.cpu()
    import cv2

    frames = einops.rearrange(frames, "b c h w -> b h w c")
    frames = frames.numpy()
    frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
    frames = [torch.from_numpy(frame) for frame in frames]
    frames = torch.stack(frames)
    if not return_hwc:
        frames = einops.rearrange(frames, "b h w c -> b c h w")
    return frames
class ViewSliceHorizonTransform(Transform):
    invertible = False

    def __init__(self, num_slices, horizon):
        super().__init__()
        self.num_slices = num_slices
        self.horizon = horizon

    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment resets, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        td = td.view(self.num_slices, self.horizon)
        return td
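For intuition, `_call` here is just a batched reshape of the sampled tensordict; a minimal sketch (assuming the class is importable as above):

```python
import torch
from tensordict import TensorDict

from lerobot.common.datasets.transforms import ViewSliceHorizonTransform

# a flat batch of num_slices * horizon = 2 * 3 sampled frames
td = TensorDict({"frame_id": torch.arange(6)}, batch_size=[6])
tf = ViewSliceHorizonTransform(num_slices=2, horizon=3)
out = tf._call(td)
print(out.batch_size)   # torch.Size([2, 3])
print(out["frame_id"])  # tensor([[0, 1, 2], [3, 4, 5]])
```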
class KeepFrames(Transform):
    invertible = False

    def __init__(
        self,
        positions,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey] = None,
    ):
        if isinstance(positions, list):
            assert isinstance(positions[0], int)
        # TODO(rcadene): add support for `isinstance(positions, int)`?

        self.positions = positions
        if out_keys is None:
            out_keys = in_keys
        super().__init__(in_keys=in_keys, out_keys=out_keys)

    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment resets, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        # we need to set batch_size=[] before assigning a different shape to td[outkey]
        td.batch_size = []

        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            # TODO(rcadene): don't know how to do `inkey not in td`
            if td.get(inkey, None) is None:
                continue
            td[outkey] = td[inkey][:, self.positions]
        return td
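Concretely, with `positions=[0]` this keeps only the first frame of each horizon slice; a minimal illustration (a sketch, again assuming the import path above):

```python
import torch
from tensordict import TensorDict

from lerobot.common.datasets.transforms import KeepFrames

# (num_slices=2, horizon=3) layout; keep only the first frame of each slice
td = TensorDict({"observation": torch.arange(6).reshape(2, 3)}, batch_size=[])
tf = KeepFrames(positions=[0], in_keys=["observation"])
out = tf._call(td)
print(out["observation"])  # tensor([[0], [3]])
```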
class DecodeVideoTransform(Transform):
    invertible = False

    def __init__(
        self,
        data_dir: Path | str,
        device="cpu",
        decoding_lib: str = "torchaudio",
        # format options are None=yuv420p (usually), rgb24, bgr24, etc.
        format: str | None = None,
        frame_rate: int | None = None,
        width: int | None = None,
        height: int | None = None,
        in_keys: Sequence[NestedKey] = None,
        out_keys: Sequence[NestedKey] = None,
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
    ):
        self.data_dir = Path(data_dir)
        self.device = device
        self.decoding_lib = decoding_lib
        self.format = format
        self.frame_rate = frame_rate
        self.width = width
        self.height = height
        self.video_id_to_path = None
        if out_keys is None:
            out_keys = in_keys
        if in_keys_inv is None:
            in_keys_inv = out_keys
        if out_keys_inv is None:
            out_keys_inv = in_keys
        super().__init__(
            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
        )

    def set_video_id_to_path(self, video_id_to_path):
        self.video_id_to_path = video_id_to_path
    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
        # _reset is called once when the environment resets, to normalize the first observation
        tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @dispatch(source="in_keys", dest="out_keys")
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        return self._call(tensordict)

    def _call(self, td: TensorDictBase) -> TensorDictBase:
        assert (
            self.video_id_to_path is not None
        ), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."

        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
            # TODO(rcadene): don't know how to do `inkey not in td`
            if td.get(inkey, None) is None:
                continue

            bsize = len(td[inkey])  # num episodes in the batch
            b_frames = []
            for i in range(bsize):
                assert (
                    td["observation", "frame", "video_id"].ndim == 3
                    and td["observation", "frame", "video_id"].shape[2] == 1
                ), "We expect 3 dims: number of episodes in the batch, number of observations, 1"

                ep_video_ids = td[inkey]["video_id"][i]
                timestamps = td[inkey]["timestamp"][i]
                frame_ids = td["frame_id"][i]

                unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
                assert unique_video_id

                is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
                assert is_ascending

                is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
                assert is_contiguous

                FIRST_FRAME = 0  # noqa: N806
                video_id = ep_video_ids[FIRST_FRAME].squeeze(0).item()
                video_path = self.data_dir / self.video_id_to_path[video_id]
                first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
                num_contiguous_frames = len(timestamps)

                if self.decoding_lib == "torchaudio":
                    frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
                elif self.decoding_lib == "ffmpegio":
                    frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
                elif self.decoding_lib == "decord":
                    frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
                else:
                    raise ValueError(self.decoding_lib)

                assert frames.ndim == 4
                assert frames.shape[1] == 3

                b_frames.append(frames)

            td[outkey] = torch.stack(b_frames)

        if self.device == "cuda":
            # make sure we return a cuda tensor, since the frames can be unintentionally sent to cpu
            assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
        return td
    def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
        filter_desc = []
        video_stream_kwgs = {
            "frames_per_chunk": num_contiguous_frames,
            "buffer_chunk_size": num_contiguous_frames,
        }

        # choice of decoder
        if self.device == "cuda":
            video_stream_kwgs["hw_accel"] = "cuda"
            video_stream_kwgs["decoder"] = "h264_cuvid"
            # video_stream_kwgs["decoder"] = "hevc_cuvid"
            # video_stream_kwgs["decoder"] = "av1_cuvid"
            # video_stream_kwgs["decoder"] = "ffv1_cuvid"
        else:
            video_stream_kwgs["decoder"] = "h264"
            # video_stream_kwgs["decoder"] = "hevc"
            # video_stream_kwgs["decoder"] = "av1"
            # video_stream_kwgs["decoder"] = "ffv1"

        # resize
        resize_width = self.width is not None
        resize_height = self.height is not None
        if resize_width or resize_height:
            if self.device == "cuda":
                assert resize_width and resize_height
                video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
            else:
                scales = []
                if resize_width:
                    scales.append(f"width={self.width}")
                if resize_height:
                    scales.append(f"height={self.height}")
                filter_desc.append(f"scale={':'.join(scales)}")

        # choice of format
        if self.format is not None:
            if self.device == "cuda":
                # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
                raise NotImplementedError()
                # filter_desc = f"scale=format={self.format}"
                # filter_desc = f"scale_cuda=format={self.format}"
                # filter_desc = f"scale_npp=format={self.format}"
            else:
                filter_desc.append(f"format=pix_fmts={self.format}")

        # choice of frame rate
        if self.frame_rate is not None:
            filter_desc.append(f"fps={self.frame_rate}")

        filter_desc.append("scale=in_range=limited:out_range=full")

        if len(filter_desc) > 0:
            video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

        # create a stream and load a certain number of frames at a certain frame rate
        # TODO(rcadene): make sure it's the most optimal way to do it
        s = StreamReader(str(video_path))
        s.seek(first_frame_ts)
        s.add_video_stream(**video_stream_kwgs)
        s.fill_buffer()
        (frames,) = s.pop_chunks()

        if self.format is None or "yuv" in self.format:
            # format=None keeps the codec's native pixel format, usually yuv420p
            frames = yuv_to_rgb(frames)
        return frames
    def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
        import ffmpegio

        fs, frames = ffmpegio.video.read(
            str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
        )
        frames = torch.from_numpy(frames)
        frames = einops.rearrange(frames, "b h w c -> b c h w")
        if self.device == "cuda":
            frames = frames.to(self.device)
        return frames

    def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
        from decord import VideoReader, cpu, gpu

        with open(str(video_path), "rb") as f:
            ctx = gpu if self.device == "cuda" else cpu
            vr = VideoReader(f, ctx=ctx(0))  # noqa: F841
        raise NotImplementedError("Convert `first_frame_ts` into frame_id")
        # frame_id = frame_ids[0].item()
        # frames = vr.get_batch([frame_id])
        # frames = torch.from_numpy(frames.asnumpy())
        # frames = einops.rearrange(frames, "b h w c -> b c h w")
        # return frames
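Stripped of the transform plumbing, the torchaudio decoding path above boils down to a seek-then-buffer pattern; a minimal standalone sketch (file name and timestamp are made up, and the decoder must match the video's codec):

```python
from torchaudio.io import StreamReader

s = StreamReader("videos/episode_0.mp4")  # hypothetical episode video
s.seek(1.5)  # jump to the first wanted frame's timestamp (seconds)
s.add_video_stream(frames_per_chunk=8, buffer_chunk_size=8, decoder="h264")
s.fill_buffer()
(frames,) = s.pop_chunks()  # uint8 tensor of shape (8, C, H, W)
```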
@@ -3,6 +3,7 @@ import zipfile
 from pathlib import Path
 
 import requests
+import torch
 import tqdm
 
 
@@ -28,3 +29,26 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
         return True
     else:
         return False
+
+
+def yuv_to_rgb(frames):
+    assert frames.dtype == torch.uint8
+    assert frames.ndim == 4
+    assert frames.shape[1] == 3
+
+    frames = frames.cpu().to(torch.float)
+    y = frames[..., 0, :, :]
+    u = frames[..., 1, :, :]
+    v = frames[..., 2, :, :]
+
+    y /= 255
+    u = u / 255 - 0.5
+    v = v / 255 - 0.5
+
+    r = y + 1.14 * v
+    g = y + -0.396 * u - 0.581 * v
+    b = y + 2.029 * u
+
+    rgb = torch.stack([r, g, b], 1)
+    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
+    return rgb
@@ -54,7 +54,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
     def update(self, replay_buffer, step):
         del step
 
-        start_time = time.time()
+        start_time = time.monotonic()
 
         self.train()
 
@@ -104,7 +104,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
         batch = replay_buffer.sample(batch_size)
         batch = process_batch(batch, self.cfg.horizon, num_slices)
 
-        data_s = time.time() - start_time
+        data_s = time.monotonic() - start_time
 
         loss = self.compute_loss(batch)
         loss.backward()
@@ -125,7 +125,7 @@ class ActionChunkingTransformerPolicy(nn.Module):
             # "lr": self.lr_scheduler.get_last_lr()[0],
             "lr": self.cfg.lr,
             "data_s": data_s,
-            "update_s": time.time() - start_time,
+            "update_s": time.monotonic() - start_time,
         }
 
         return info
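The `time.time()` to `time.monotonic()` substitution recurs across the policies, the metric logger, and the eval script below. The rationale: `time.time()` is wall-clock time and can jump when the system clock is adjusted (NTP sync, manual changes), whereas `time.monotonic()` is guaranteed non-decreasing, which is what you want when measuring durations. A minimal illustration:

```python
import time

t0 = time.monotonic()
time.sleep(0.1)
elapsed = time.monotonic() - t0  # always >= 0, even if the wall clock is adjusted
print(f"{elapsed:.3f}s")
```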
@@ -188,8 +188,8 @@ class MetricLogger:
     def log_every(self, iterable, print_freq, header=None):
         if not header:
             header = ""
-        start_time = time.time()
-        end = time.time()
+        start_time = time.monotonic()
+        end = time.monotonic()
         iter_time = SmoothedValue(fmt="{avg:.4f}")
         data_time = SmoothedValue(fmt="{avg:.4f}")
         space_fmt = ":" + str(len(str(len(iterable)))) + "d"
@@ -218,9 +218,9 @@ class MetricLogger:
         )
         mega_b = 1024.0 * 1024.0
         for i, obj in enumerate(iterable):
-            data_time.update(time.time() - end)
+            data_time.update(time.monotonic() - end)
             yield obj
-            iter_time.update(time.time() - end)
+            iter_time.update(time.monotonic() - end)
             if i % print_freq == 0 or i == len(iterable) - 1:
                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
@@ -247,8 +247,8 @@ class MetricLogger:
                     data=str(data_time),
                 )
             )
-            end = time.time()
-        total_time = time.time() - start_time
+            end = time.monotonic()
+        total_time = time.monotonic() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
         print("{} Total time: {} ({:.4f} s / it)".format(header, total_time_str, total_time / len(iterable)))
@@ -111,7 +111,7 @@ class DiffusionPolicy(nn.Module):
         return action
 
     def update(self, replay_buffer, step):
-        start_time = time.time()
+        start_time = time.monotonic()
 
         self.diffusion.train()
 
@@ -158,7 +158,7 @@ class DiffusionPolicy(nn.Module):
         batch = replay_buffer.sample(batch_size)
         batch = process_batch(batch, self.cfg.horizon, num_slices)
 
-        data_s = time.time() - start_time
+        data_s = time.monotonic() - start_time
 
         loss = self.diffusion.compute_loss(batch)
         loss.backward()
@@ -181,7 +181,7 @@ class DiffusionPolicy(nn.Module):
             "grad_norm": float(grad_norm),
             "lr": self.lr_scheduler.get_last_lr()[0],
             "data_s": data_s,
-            "update_s": time.time() - start_time,
+            "update_s": time.monotonic() - start_time,
         }
 
         # TODO(rcadene): remove hardcoding
@@ -291,7 +291,7 @@ class TDMPC(nn.Module):
 
     def update(self, replay_buffer, step, demo_buffer=None):
         """Main update function. Corresponds to one iteration of the model learning."""
-        start_time = time.time()
+        start_time = time.monotonic()
 
         num_slices = self.cfg.batch_size
         batch_size = self.cfg.horizon * num_slices
@@ -405,7 +405,7 @@ class TDMPC(nn.Module):
         self.std = h.linear_schedule(self.cfg.std_schedule, step)
         self.model.train()
 
-        data_s = time.time() - start_time
+        data_s = time.monotonic() - start_time
 
         # Compute targets
         with torch.no_grad():
@@ -501,7 +501,7 @@ class TDMPC(nn.Module):
             "grad_norm": float(grad_norm),
             "lr": self.cfg.lr,
             "data_s": data_s,
-            "update_s": time.time() - start_time,
+            "update_s": time.monotonic() - start_time,
         }
         info["demo_batch_size"] = demo_batch_size
         info["expectile"] = expectile
lerobot/scripts/convert_dataset_uint8_to_mp4.py (new file, 129 lines)
@@ -0,0 +1,129 @@
"""
usage: `python lerobot/scripts/convert_dataset_uint8_to_mp4.py --in-data-dir data/pusht --out-data-dir tests/data/pusht`
"""

import argparse
import shutil
from pathlib import Path

import torch
from tensordict import TensorDict


def convert_dataset_uint8_to_mp4(in_data_dir, out_data_dir, fps, overwrite_num_frames=None):
    assert fps is not None and isinstance(fps, float)
    # load the full dataset as a tensordict
    in_td_data = TensorDict.load_memmap(in_data_dir)

    out_data_dir = Path(out_data_dir)
    # use 1 frame to know the specification of the dataset,
    # and copy it over `n` frames in the test artifact directory
    out_rb_dir = out_data_dir / "replay_buffer"
    if out_rb_dir.exists():
        shutil.rmtree(out_rb_dir)

    num_frames = len(in_td_data) if overwrite_num_frames is None else overwrite_num_frames

    # del in_td_data["observation", "image"]
    # del in_td_data["next", "observation", "image"]

    out_td_data = in_td_data[0].memmap_().clone()

    out_td_data["observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
    out_td_data["observation", "frame", "timestamp"] = torch.zeros(1)
    out_td_data["next", "observation", "frame", "video_id"] = torch.zeros(1, dtype=torch.int)
    out_td_data["next", "observation", "frame", "timestamp"] = torch.zeros(1)

    out_td_data = out_td_data.expand(num_frames)
    out_td_data = out_td_data.memmap_like(out_rb_dir)

    out_vid_dir = out_data_dir / "videos"
    out_vid_dir.mkdir(parents=True, exist_ok=True)

    video_id_to_path = {}

    # copy the first `n` frames so that we have real data
    for key in out_td_data.keys(include_nested=True, leaves_only=True):
        if in_td_data.get(key, None) is None:
            continue
        if overwrite_num_frames is None:
            out_td_data[key].copy_(in_td_data[key].clone())
        else:
            out_td_data[key][:num_frames].copy_(in_td_data[key][:num_frames].clone())

    for i in range(num_frames):
        video_id = in_td_data["episode"][i]
        frame_id = in_td_data["frame_id"][i]

        out_td_data["observation", "frame", "video_id"][i] = video_id
        out_td_data["observation", "frame", "timestamp"][i] = frame_id / fps
        out_td_data["next", "observation", "frame", "video_id"][i] = video_id
        out_td_data["next", "observation", "frame", "timestamp"][i] = (frame_id + 1) / fps

        video_id = video_id.item()
        if video_id not in video_id_to_path:
            video_id_to_path[video_id] = f"videos/episode_{video_id}.mp4"

    # make sure everything has been properly written
    out_td_data.lock_()

    # copy the full statistics of the dataset since it's pretty small
    in_stats_path = Path(in_data_dir) / "stats.pth"
    out_stats_path = Path(out_data_dir) / "stats.pth"
    shutil.copy(in_stats_path, out_stats_path)

    meta_data = {
        "video_id_to_path": video_id_to_path,
    }
    torch.save(meta_data, out_data_dir / "meta_data.pth")


# def write_to_mp4():
#     buffer = io.BytesIO()
#     swriter = StreamWriter(buffer, format="mp4")

#     device = "cuda"

#     c, h, w = in_td_data[0]["observation", "image"].shape

#     swriter.add_video_stream(
#         frame_rate=fps,
#         width=w,
#         height=h,
#         # frame_rate=30000 / 1001,
#         format="yuv444p",
#         encoder="h264_nvenc",
#         encoder_format="yuv444p",
#         hw_accel=device,
#     )

#     for i in range(num_frames):
#         ep_id = in_td_data[i]["episode"]
#         data = in_td_data[i]["observation", "image"]
#         with swriter.open():
#             t0 = time.monotonic()
#             data = data.to(device)
#             swriter.write_video_chunk(0, data)
#         elapsed = time.monotonic() - t0
#         size = buffer.tell()
#         print(f"{elapsed=}")
#         print(f"{size=}")
#         buffer.seek(0)
#         video = buffer.read()

#         vid_path = out_vid_dir / f"episode_{ep_id}.mp4"
#         with open(vid_path, 'wb+') as f:
#             f.write(video)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Create dataset")

    parser.add_argument("--in-data-dir", type=str, help="Path to input data")
    parser.add_argument("--out-data-dir", type=str, help="Path to save the output data")
    parser.add_argument("--fps", type=float)

    args = parser.parse_args()

    convert_dataset_uint8_to_mp4(args.in_data_dir, args.out_data_dir, args.fps)
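After the script runs, the output directory should contain `replay_buffer/`, `videos/`, `stats.pth`, and `meta_data.pth`. A quick way to inspect the mapping it wrote (a sketch, using the output path from the usage string above):

```python
import torch

meta = torch.load("tests/data/pusht/meta_data.pth")
print(meta["video_id_to_path"])  # e.g. {0: "videos/episode_0.mp4", ...}
```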
@@ -32,7 +32,7 @@ def eval_policy(
     fps: int = 15,
     return_first_video: bool = False,
 ):
-    start = time.time()
+    start = time.monotonic()
     sum_rewards = []
     max_rewards = []
     successes = []
@@ -85,8 +85,8 @@ def eval_policy(
         "avg_sum_reward": np.nanmean(sum_rewards),
         "avg_max_reward": np.nanmean(max_rewards),
         "pc_success": np.nanmean(successes) * 100,
-        "eval_s": time.time() - start,
-        "eval_ep_s": (time.time() - start) / num_episodes,
+        "eval_s": time.monotonic() - start,
+        "eval_ep_s": (time.monotonic() - start) / num_episodes,
     }
     if return_first_video:
         return info, first_video
@@ -1,5 +1,8 @@
 import logging
+import shutil
+import subprocess
+import threading
 import time
 from pathlib import Path
 
 import einops
@@ -26,9 +29,49 @@ def visualize_dataset_cli(cfg: dict):
 
 def cat_and_write_video(video_path, frames, fps):
     frames = torch.cat(frames)
-    assert frames.dtype == torch.uint8
+    if frames.dtype != torch.uint8:
+        logging.warning(f"frames are expected to be uint8, got {frames.dtype}")
+        frames = frames.type(torch.uint8)
+
+    _, _, h, w = frames.shape
     frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
-    imageio.mimsave(video_path, frames, fps=fps)
+
+    img_dir = Path(video_path.split(".")[0])
+    if img_dir.exists():
+        shutil.rmtree(img_dir)
+    img_dir.mkdir(parents=True, exist_ok=True)
+
+    for i in range(len(frames)):
+        imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i])
+
+    ffmpeg_command = [
+        "ffmpeg",
+        "-r",
+        str(fps),
+        "-f",
+        "image2",
+        "-s",
+        f"{w}x{h}",
+        "-i",
+        str(img_dir / "frame_%04d.png"),
+        "-vcodec",
+        "libx264",
+        # "-vcodec", "libx265",
+        # "-vcodec", "libaom-av1",
+        "-crf",
+        "0",  # lossless option
+        "-pix_fmt",
+        # "yuv420p",  # specify pixel format
+        "yuv444p",  # specify pixel format
+        video_path,
+        # video_path.replace(".mp4", ".mkv")
+    ]
+    subprocess.run(ffmpeg_command, check=True)
+
+    time.sleep(0.1)
+
+    # clean up the temporary image directory
+    # shutil.rmtree(img_dir)
 
 
 def visualize_dataset(cfg: dict, out_dir=None):
@@ -61,7 +104,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
 
         # TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
         no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
-        new_episode = ep_idx != current_ep_idx
+        new_episode = ep_idx > current_ep_idx
+
+        if ep_idx < current_ep_idx:
+            break
 
         if new_episode:
             logging.info(f"Visualizing episode {current_ep_idx}")
@@ -71,7 +117,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
         # append last observed frames (the ones after last action taken)
         frames[im_key].append(ep_td[("next", *im_key)])
 
-    video_dir = Path(out_dir) / "visualize_dataset"
+    video_dir = Path(out_dir) / "videos"
     video_dir.mkdir(parents=True, exist_ok=True)
 
     if len(offline_buffer.image_keys) > 1:
poetry.lock (generated, 68 changed lines)
@@ -2684,13 +2684,13 @@ test = ["asv", "gmpy2", "hypothesis", "mpmath", "pooch", "pytest", "pytest-cov",
 
 [[package]]
 name = "sentry-sdk"
-version = "1.41.0"
+version = "1.42.0"
 description = "Python client for Sentry (https://sentry.io)"
 optional = false
 python-versions = "*"
 files = [
-    {file = "sentry-sdk-1.41.0.tar.gz", hash = "sha256:4f2d6c43c07925d8cd10dfbd0970ea7cb784f70e79523cca9dbcd72df38e5a46"},
-    {file = "sentry_sdk-1.41.0-py2.py3-none-any.whl", hash = "sha256:be4f8f4b29a80b6a3b71f0f31487beb9e296391da20af8504498a328befed53f"},
+    {file = "sentry-sdk-1.42.0.tar.gz", hash = "sha256:4a8364b8f7edbf47f95f7163e48334c96100d9c098f0ae6606e2e18183c223e6"},
+    {file = "sentry_sdk-1.42.0-py2.py3-none-any.whl", hash = "sha256:a654ee7e497a3f5f6368b36d4f04baeab1fe92b3105f7f6965d6ef0de35a9ba4"},
 ]
 
 [package.dependencies]
@@ -2714,6 +2714,7 @@ grpcio = ["grpcio (>=1.21.1)"]
 httpx = ["httpx (>=0.16.0)"]
 huey = ["huey (>=2)"]
 loguru = ["loguru (>=0.5)"]
+openai = ["openai (>=1.0.0)", "tiktoken (>=0.3.0)"]
 opentelemetry = ["opentelemetry-distro (>=0.35b0)"]
 opentelemetry-experimental = ["opentelemetry-distro (>=0.40b0,<1.0)", "opentelemetry-instrumentation-aiohttp-client (>=0.40b0,<1.0)", "opentelemetry-instrumentation-django (>=0.40b0,<1.0)", "opentelemetry-instrumentation-fastapi (>=0.40b0,<1.0)", "opentelemetry-instrumentation-flask (>=0.40b0,<1.0)", "opentelemetry-instrumentation-requests (>=0.40b0,<1.0)", "opentelemetry-instrumentation-sqlite3 (>=0.40b0,<1.0)", "opentelemetry-instrumentation-urllib (>=0.40b0,<1.0)"]
 pure-eval = ["asttokens", "executing", "pure-eval"]
@@ -2829,18 +2830,18 @@ test = ["pytest"]
 
 [[package]]
 name = "setuptools"
-version = "69.1.1"
+version = "69.2.0"
 description = "Easily download, build, install, upgrade, and uninstall Python packages"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "setuptools-69.1.1-py3-none-any.whl", hash = "sha256:02fa291a0471b3a18b2b2481ed902af520c69e8ae0919c13da936542754b4c56"},
-    {file = "setuptools-69.1.1.tar.gz", hash = "sha256:5c0806c7d9af348e6dd3777b4f4dbb42c7ad85b190104837488eab9a7c945cf8"},
+    {file = "setuptools-69.2.0-py3-none-any.whl", hash = "sha256:c21c49fb1042386df081cb5d86759792ab89efca84cf114889191cd09aacc80c"},
+    {file = "setuptools-69.2.0.tar.gz", hash = "sha256:0ff4183f8f42cd8fa3acea16c45205521a4ef28f73c6391d8a25e92893134f2e"},
 ]
 
 [package.extras]
 docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
-testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
+testing = ["build[virtualenv]", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff (>=0.2.1)", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
 testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.2)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
 
 [[package]]
@@ -2949,7 +2950,7 @@ mpmath = ">=0.19"
 
 [[package]]
 name = "tensordict"
-version = "0.4.0+551331d"
+version = "0.4.0+6a56ecd"
 description = ""
 optional = false
 python-versions = "*"
@@ -2970,7 +2971,7 @@ tests = ["pytest", "pytest-benchmark", "pytest-instafail", "pytest-rerunfailures
 type = "git"
 url = "https://github.com/pytorch/tensordict"
 reference = "HEAD"
-resolved_reference = "ed22554d6860731610df784b2f5d09f31d3dbc7a"
+resolved_reference = "6a56ecd728757feee387f946b7da66dd452b739b"
 
 [[package]]
 name = "termcolor"
@@ -3072,6 +3073,43 @@ typing-extensions = ">=4.8.0"
 opt-einsum = ["opt-einsum (>=3.3)"]
 optree = ["optree (>=0.9.1)"]
 
+[[package]]
+name = "torchaudio"
+version = "2.2.1"
+description = "An audio package for PyTorch"
+optional = false
+python-versions = "*"
+files = [
+    {file = "torchaudio-2.2.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:580eefd764a01a64d5b6aa260c0c47974be6a6964892d54029a73b17f4611fcd"},
+    {file = "torchaudio-2.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8ad55c2069b27bbe18e14783a202e3f3f8082fe9e59281436ba797edb0fc94d5"},
+    {file = "torchaudio-2.2.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:55d23986254f7af689695f3fc214c4aa3e73dc931289ecdba7262d73fea7af7a"},
+    {file = "torchaudio-2.2.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:b916b7764698ba9319aa3b25519139892de8665d84438969bac5e1d8578c6a11"},
+    {file = "torchaudio-2.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:281cd4bdb9e65c0618a028b809df9e06f9bd9592aeef8f2b37b4d8a788ce5f2b"},
+    {file = "torchaudio-2.2.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:274cb8474bc1e56b768ef347d3188661c5a9d5e68e2df56fc0aff11cc73c916a"},
+    {file = "torchaudio-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e62c27b17672cc2bdd9663681e533000f9c0984e6a0f3d455f7051bc005bb02"},
+    {file = "torchaudio-2.2.1-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7df7d5d9100116be38ff7b27b628820dca4a9e3fe79394605141d339e3b3e46d"},
+    {file = "torchaudio-2.2.1-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:20b2965db4f843021636f53d3fab1075c3f8959c450c647629124d24c7e6cbb0"},
+    {file = "torchaudio-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:63dd0e840bcf2e4aceb7a98daccfaf7a2a5b3a927647b98bbef449b0b190f2cc"},
+    {file = "torchaudio-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2c232dc8bee97d303b90833ba934d8905eb7326456236efcd9fa71ccb92fd363"},
+    {file = "torchaudio-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2419387cf04d33047369337bf09c00c2a7673a8f52f80258454c7eca7d205d23"},
+    {file = "torchaudio-2.2.1-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2483c0620a68a136359ae90c893608ad5cd73091fb0351b94d33af126a0e3d67"},
+    {file = "torchaudio-2.2.1-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:bd389f33b7dbfc44e5f4070fc6db00cc560992bea8378a952889acfd772b7022"},
+    {file = "torchaudio-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:d5af725a327b79f3bd8389c53ec51554ee003c18434fc47e68da49b09900132e"},
+    {file = "torchaudio-2.2.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:81ef88d7693e3b99007d1ee742fd81b9a92399ecbf88eb7ed69949443005ffba"},
+    {file = "torchaudio-2.2.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f487a7d3177ae6af016750850ee93788e880218a1a310bc6c76901e212f91cd3"},
+    {file = "torchaudio-2.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:bee5478ec2cb7d0eaa97023d817aa4914010e1ab0c266f64ef1b0db893aceb49"},
+    {file = "torchaudio-2.2.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:a4462b3f214f60b6b8f78e12a4cf1291c9bc353deed709ac3dfdedbed513a7a3"},
+    {file = "torchaudio-2.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:4bc43d11d9e086f0dfb29f6ea99517d8ec06fa80d97283f2c8b83c4cd467dd1a"},
+    {file = "torchaudio-2.2.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:0339fe78ed9c29f704296761b28bb055b5350625ff503ad781704397934e6b58"},
+    {file = "torchaudio-2.2.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68b1d9f8ffe9b26ef04e80d82ae2dc2f74b1a1eb64c3e8ad21b525802b3bc7ac"},
+    {file = "torchaudio-2.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:3962fea5d2511c9ab2b1dd515b45ec44d0c28e51f3b05c0b9fa7bbcc3c213bc1"},
+    {file = "torchaudio-2.2.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:cb2da08abb7b68dc7b0105748b1a736dd33329f841374013ec02c54e04bedf29"},
+    {file = "torchaudio-2.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:54996977ab1c875729e8dedc4695609ca58f876c23756c79979c6b50136b3385"},
+]
+
+[package.dependencies]
+torch = "2.2.1"
+
 [[package]]
 name = "torchrl"
 version = "0.4.0+13bef42"
@@ -3311,20 +3349,20 @@ jupyter = ["ipytree (>=0.2.2)", "ipywidgets (>=8.0.0)", "notebook"]
 
 [[package]]
 name = "zipp"
-version = "3.17.0"
+version = "3.18.0"
 description = "Backport of pathlib-compatible object wrapper for zip files"
 optional = false
 python-versions = ">=3.8"
 files = [
-    {file = "zipp-3.17.0-py3-none-any.whl", hash = "sha256:0e923e726174922dce09c53c59ad483ff7bbb8e572e00c7f7c46b88556409f31"},
-    {file = "zipp-3.17.0.tar.gz", hash = "sha256:84e64a1c28cf7e91ed2078bb8cc8c259cb19b76942096c8d7b84947690cabaf0"},
+    {file = "zipp-3.18.0-py3-none-any.whl", hash = "sha256:c1bb803ed69d2cce2373152797064f7e79bc43f0a3748eb494096a867e0ebf79"},
+    {file = "zipp-3.18.0.tar.gz", hash = "sha256:df8d042b02765029a09b157efd8e820451045890acc30f8e37dd2f94a060221f"},
 ]
 
 [package.extras]
-docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (<7.2.5)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"]
+docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
+testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "ee86b84a795e6a3e9c2d79f244a87b55589adbe46d549ac38adf48be27c04cf9"
+content-hash = "e0c9fa6894aaa917493f81028c1bcc3fff8c56d9025681af44534fc3dbe7646e"

pyproject.toml

@@ -51,6 +51,7 @@ torchvision = "^0.17.1"
 h5py = "^3.10.0"
 dm-control = "1.0.14"
 huggingface-hub = {extras = ["hf-transfer"], version = "^0.21.4"}
+torchaudio = "^2.2.1"
 
 
 [tool.poetry.group.dev.dependencies]
test.py (new file, 330 lines)
@@ -0,0 +1,330 @@
# TODO(rcadene): add tests
# TODO(rcadene): what is the best format to store/load videos?

import subprocess
from collections.abc import Callable
from pathlib import Path

import einops
import torch
import torchaudio
import torchrl
from matplotlib import pyplot as plt
from tensordict import TensorDict
from torchaudio.utils import ffmpeg_utils
from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
from torchrl.envs.transforms.transforms import Compose

from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
from lerobot.common.utils import set_seed

NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12
def count_frames(video_path):
    try:
        # Construct the ffprobe command to get the number of frames
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=nb_frames",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        # Execute the ffprobe command and capture the output
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # Convert the output to an integer
        num_frames = int(result.stdout.strip())

        return num_frames
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1


def get_frame_rate(video_path):
    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-select_streams",
            "v:0",
            "-show_entries",
            "stream=r_frame_rate",
            "-of",
            "default=nokey=1:noprint_wrappers=1",
            video_path,
        ]

        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)

        # The frame rate is typically represented as a fraction (e.g., "30000/1001").
        # To convert it to a float, we can evaluate the fraction.
        frame_rate = eval(result.stdout.strip())

        return frame_rate
    except Exception as e:
        print(f"An error occurred: {e}")
        return -1
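A side note on `get_frame_rate`: `eval` on subprocess output works but is easy to avoid, since `fractions.Fraction` parses ffprobe's `"30000/1001"` form directly (a minimal alternative, not part of the diff):

```python
from fractions import Fraction

# parse ffprobe's rational frame rate without eval()
frame_rate = float(Fraction("30000/1001"))  # 29.97002997002997
```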
def get_frame_timestamps(frame_rate, num_frames):
    timestamps = [(1 / frame_rate) * i for i in range(num_frames)]
    return timestamps


# class ClearDeviceTransform(Transform):
#     invertible = False

#     def __init__(self):
#         super().__init__()

#     def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
#         # _reset is called once when the environment resets, to normalize the first observation
#         tensordict_reset = self._call(tensordict_reset)
#         return tensordict_reset

#     @dispatch(source="in_keys", dest="out_keys")
#     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
#         return self._call(tensordict)

#     def _call(self, td: TensorDictBase) -> TensorDictBase:
#         td.clear_device_()
#         return td
class VideoExperienceReplay(TensorDictReplayBuffer):
    def __init__(
        self,
        batch_size: int = None,
        *,
        root: Path = None,
        pin_memory: bool = False,
        prefetch: int = None,
        sampler: SliceSampler = None,
        collate_fn: Callable = None,
        writer: Writer = None,
        transform: "torchrl.envs.Transform" = None,
    ):
        self.data_dir = root
        self.rb_dir = self.data_dir / "replay_buffer"

        storage, meta_data = self._load_or_download()

        # hack to access video paths
        assert isinstance(transform, Compose)
        for tf in transform:
            if isinstance(tf, DecodeVideoTransform):
                tf.set_video_id_to_path(meta_data["video_id_to_path"])

        super().__init__(
            storage=storage,
            sampler=sampler,
            writer=ImmutableDatasetWriter() if writer is None else writer,
            collate_fn=_collate_id if collate_fn is None else collate_fn,
            pin_memory=pin_memory,
            prefetch=prefetch,
            batch_size=batch_size,
            transform=transform,
        )

    def _load_or_download(self, force_download=False):
        if not force_download and self.data_dir.exists():
            storage = TensorStorage(TensorDict.load_memmap(self.rb_dir))
            meta_data = torch.load(self.data_dir / "meta_data.pth")
        else:
            storage, meta_data = self._download()
            torch.save(meta_data, self.data_dir / "meta_data.pth")

        # required to not send cuda frames to cpu by default
        storage._storage.clear_device_()
        return storage, meta_data

    def _download(self):
        num_episodes = 1
        video_id_to_path = {}
        for episode_id in range(num_episodes):
            video_path = torchaudio.utils.download_asset(
                "tutorial-assets/stream-api/NASAs_Most_Scientifically_Complex_Space_Observatory_Requires_Precision-MP4_small.mp4"
            )
            # several episodes can belong to the same video
            video_id = episode_id
            video_id_to_path[video_id] = video_path

            print(f"{video_path=}")
            num_frames = count_frames(video_path)
            print(f"{num_frames=}")
            frame_rate = get_frame_rate(video_path)
            print(f"{frame_rate=}")

            frame_timestamps = get_frame_timestamps(frame_rate, num_frames)

            reward = torch.zeros(num_frames, 1, dtype=torch.float32)
            success = torch.zeros(num_frames, 1, dtype=torch.bool)
            done = torch.zeros(num_frames, 1, dtype=torch.bool)
            state = torch.randn(num_frames, NUM_STATE_CHANNELS, dtype=torch.float32)
            action = torch.randn(num_frames, NUM_ACTION_CHANNELS, dtype=torch.float32)
            timestamp = torch.tensor(frame_timestamps)
            frame_id = torch.arange(0, num_frames, 1)
            episode_id_tensor = torch.tensor([episode_id] * num_frames, dtype=torch.int)
            video_id_tensor = torch.tensor([video_id] * num_frames, dtype=torch.int)

            # last step of demonstration is considered done
            done[-1] = True

            ep_td = TensorDict(
                {
                    ("observation", "frame", "video_id"): video_id_tensor[:-1],
                    ("observation", "frame", "timestamp"): timestamp[:-1],
                    ("observation", "state"): state[:-1],
                    "action": action[:-1],
                    "episode": episode_id_tensor[:-1],
                    "frame_id": frame_id[:-1],
                    ("next", "observation", "frame", "video_id"): video_id_tensor[1:],
                    ("next", "observation", "frame", "timestamp"): timestamp[1:],
                    ("next", "observation", "state"): state[1:],
                    ("next", "reward"): reward[1:],
                    ("next", "done"): done[1:],
                    ("next", "success"): success[1:],
                },
                batch_size=num_frames - 1,
            )

            # TODO:
            total_frames = num_frames - 1

            if episode_id == 0:
                # hack to initialize tensordict data structure to store episodes
                td_data = ep_td[0].expand(total_frames).memmap_like(self.rb_dir)

            td_data[:] = ep_td

        meta_data = {
            "video_id_to_path": video_id_to_path,
        }

        return TensorStorage(td_data.lock_()), meta_data
if __name__ == "__main__":
    import time

    import tqdm

    print("FFmpeg Library versions:")
    for k, ver in ffmpeg_utils.get_versions().items():
        print(f"  {k}:\t{'.'.join(str(v) for v in ver)}")

    print("Available NVDEC Decoders:")
    for k in ffmpeg_utils.get_video_decoders().keys():  # noqa: SIM118
        if "cuvid" in k:
            print(f" - {k}")

    def create_replay_buffer(device, format=None):
        data_dir = Path("tmp/2024_03_17_data_video/pusht")

        num_slices = 1
        horizon = 2
        batch_size = num_slices * horizon

        sampler = SliceSamplerWithoutReplacement(
            num_slices=num_slices,
            strict_length=True,
            shuffle=False,
        )

        transforms = [
            # ClearDeviceTransform(),
            ViewSliceHorizonTransform(num_slices, horizon),
            KeepFrames(positions=[0], in_keys=[("observation")]),
            DecodeVideoTransform(
                data_dir=data_dir,
                device=device,
                frame_rate=None,
                format=format,
                in_keys=[("observation", "frame")],
                out_keys=[("observation", "frame", "data")],
            ),
        ]

        replay_buffer = VideoExperienceReplay(
            root=data_dir,
            batch_size=batch_size,
            # prefetch=4,
            transform=Compose(*transforms),
            sampler=sampler,
        )
        return replay_buffer

    def test_time():
        replay_buffer = create_replay_buffer(device="cuda")

        start = time.monotonic()
        for _ in tqdm.tqdm(range(2)):
            # include_info=False is required to not have a batch_size mismatch error with the truncated key (2,8) != (16, 1)
            replay_buffer.sample(include_info=False)
            torch.cuda.synchronize()
        print(time.monotonic() - start)

        start = time.monotonic()
        for _ in tqdm.tqdm(range(10)):
            replay_buffer.sample(include_info=False)
            torch.cuda.synchronize()
        print(time.monotonic() - start)

    def test_plot(seed=1337):
        rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
        rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")

        n_rows = 2  # len(replay_buffer)
        fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
        for i in range(n_rows):
            set_seed(seed + i)
            batch_cpu = rb_cpu.sample(include_info=False)
            print("frame_ids cpu", batch_cpu["frame_id"].tolist())
            print("episode cpu", batch_cpu["episode"].tolist())
            print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
            frames = batch_cpu["observation", "frame", "data"]
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][0].imshow(frames[0])

            set_seed(seed + i)
            batch_cuda = rb_cuda.sample(include_info=False)
            print("frame_ids cuda", batch_cuda["frame_id"].tolist())
            print("episode cuda", batch_cuda["episode"].tolist())
            print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
            frames = batch_cuda["observation", "frame", "data"]
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][1].imshow(frames[0])

            frames = batch_cuda["observation", "image"].type(torch.uint8)
            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
            assert frames.shape[0] == 1
            axes[i][2].imshow(frames[0])

        axes[0][0].set_title("Software decoder")
        axes[0][1].set_title("HW decoder")
        axes[0][2].set_title("uint8")
        plt.setp(axes, xticks=[], yticks=[])
        plt.tight_layout()
        fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)

    # test_time()
    test_plot()