From 9cdc24bc0e35e68aa4b6240c786fdcb9f038afe1 Mon Sep 17 00:00:00 2001 From: Cadene Date: Tue, 19 Mar 2024 13:41:49 +0000 Subject: [PATCH] WIP --- lerobot/common/datasets/transforms.py | 202 ++++++++++++++++++-------- lerobot/scripts/visualize_dataset.py | 3 +- test.py | 51 +------ 3 files changed, 152 insertions(+), 104 deletions(-) diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index b53215aa..18d6336a 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -1,6 +1,7 @@ from pathlib import Path from typing import Sequence +import einops import torch from tensordict import TensorDictBase from tensordict.nn import dispatch @@ -9,6 +10,46 @@ from torchaudio.io import StreamReader from torchrl.envs.transforms import Transform +def yuv_to_rgb(frames): + assert frames.dtype == torch.uint8 + assert frames.ndim == 4 + assert frames.shape[1] == 3 + + frames = frames.cpu().to(torch.float) + y = frames[..., 0, :, :] + u = frames[..., 1, :, :] + v = frames[..., 2, :, :] + + y /= 255 + u = u / 255 - 0.5 + v = v / 255 - 0.5 + + r = y + 1.13983 * v + g = y + -0.39465 * u - 0.58060 * v + b = y + 2.03211 * u + + rgb = torch.stack([r, g, b], 1) + rgb = (rgb * 255).clamp(0, 255).to(torch.uint8) + return rgb + + +def yuv_to_rgb_cv2(frames, return_hwc=True): + assert frames.dtype == torch.uint8 + assert frames.ndim == 4 + assert frames.shape[1] == 3 + frames = frames.cpu() + import cv2 + + frames = einops.rearrange(frames, "b c h w -> b h w c") + frames = frames.numpy() + frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames] + frames = [torch.from_numpy(frame) for frame in frames] + frames = torch.stack(frames) + if not return_hwc: + frames = einops.rearrange(frames, "b h w c -> b c h w") + return frames + + class ViewSliceHorizonTransform(Transform): invertible = False @@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform): self, data_dir: Path | str, device="cpu", + decoding_lib: str = "torchaudio", # format options are None=yuv420p (usually), rgb24, bgr24, etc. 
format: str | None = None, frame_rate: int | None = None, @@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform): ): self.data_dir = Path(data_dir) self.device = device + self.decoding_lib = decoding_lib self.format = format self.frame_rate = frame_rate self.width = width @@ -153,66 +196,17 @@ class DecodeVideoTransform(Transform): first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item() num_contiguous_frames = len(timestamps) - filter_desc = [] - video_stream_kwgs = { - "frames_per_chunk": num_contiguous_frames, - "buffer_chunk_size": num_contiguous_frames, - } - - # choice of decoder - if self.device == "cuda": - video_stream_kwgs["hw_accel"] = "cuda" - video_stream_kwgs["decoder"] = "h264_cuvid" - # video_stream_kwgs["decoder"] = "hevc_cuvid" - # video_stream_kwgs["decoder"] = "av1_cuvid" - # video_stream_kwgs["decoder"] = "ffv1_cuvid" + if self.decoding_lib == "torchaudio": + frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames) + elif self.decoding_lib == "ffmpegio": + frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames) + elif self.decoding_lib == "decord": + frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames) else: - video_stream_kwgs["decoder"] = "h264" - # video_stream_kwgs["decoder"] = "hevc" - # video_stream_kwgs["decoder"] = "av1" - # video_stream_kwgs["decoder"] = "ffv1" + raise ValueError(self.decoding_lib) - # resize - resize_width = self.width is not None - resize_height = self.height is not None - if resize_width or resize_height: - if self.device == "cuda": - assert resize_width and resize_height - video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"} - else: - scales = [] - if resize_width: - scales.append(f"width={self.width}") - if resize_height: - scales.append(f"height={self.height}") - filter_desc.append(f"scale={':'.join(scales)}") - - # choice of format - if self.format is not None: - if self.device == "cuda": - # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp - raise NotImplementedError() - # filter_desc = f"scale=format={self.format}" - # filter_desc = f"scale_cuda=format={self.format}" - # filter_desc = f"scale_npp=format={self.format}" - else: - filter_desc.append(f"format=pix_fmts={self.format}") - - # choice of frame rate - if self.frame_rate is not None: - filter_desc.append(f"fps={self.frame_rate}") - - if len(filter_desc) > 0: - video_stream_kwgs["filter_desc"] = ",".join(filter_desc) - - # create a stream and load a certain number of frame at a certain frame rate - # TODO(rcadene): make sure it's the most optimal way to do it - # s = StreamReader(str(video_path).replace('.mp4','.mkv')) - s = StreamReader(str(video_path)) - s.seek(first_frame_ts) - s.add_video_stream(**video_stream_kwgs) - s.fill_buffer() - (frames,) = s.pop_chunks() + assert frames.ndim == 4 + assert frames.shape[1] == 3 b_frames.append(frames) @@ -222,3 +216,95 @@ class DecodeVideoTransform(Transform): # make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda" return td + + def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames): + filter_desc = [] + video_stream_kwgs = { + "frames_per_chunk": num_contiguous_frames, + "buffer_chunk_size": num_contiguous_frames, + } + + # choice of decoder + if self.device == "cuda": + video_stream_kwgs["hw_accel"] = "cuda" + 
video_stream_kwgs["decoder"] = "h264_cuvid" + # video_stream_kwgs["decoder"] = "hevc_cuvid" + # video_stream_kwgs["decoder"] = "av1_cuvid" + # video_stream_kwgs["decoder"] = "ffv1_cuvid" + else: + video_stream_kwgs["decoder"] = "h264" + # video_stream_kwgs["decoder"] = "hevc" + # video_stream_kwgs["decoder"] = "av1" + # video_stream_kwgs["decoder"] = "ffv1" + + # resize + resize_width = self.width is not None + resize_height = self.height is not None + if resize_width or resize_height: + if self.device == "cuda": + assert resize_width and resize_height + video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"} + else: + scales = [] + if resize_width: + scales.append(f"width={self.width}") + if resize_height: + scales.append(f"height={self.height}") + filter_desc.append(f"scale={':'.join(scales)}") + + # choice of format + if self.format is not None: + if self.device == "cuda": + # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp + raise NotImplementedError() + # filter_desc = f"scale=format={self.format}" + # filter_desc = f"scale_cuda=format={self.format}" + # filter_desc = f"scale_npp=format={self.format}" + else: + filter_desc.append(f"format=pix_fmts={self.format}") + + # choice of frame rate + if self.frame_rate is not None: + filter_desc.append(f"fps={self.frame_rate}") + + filter_desc.append("scale=in_range=limited:out_range=full") + + if len(filter_desc) > 0: + video_stream_kwgs["filter_desc"] = ",".join(filter_desc) + + # create a stream and load a certain number of frame at a certain frame rate + # TODO(rcadene): make sure it's the most optimal way to do it + s = StreamReader(str(video_path)) + s.seek(first_frame_ts) + s.add_video_stream(**video_stream_kwgs) + s.fill_buffer() + (frames,) = s.pop_chunks() + + if "yuv" in self.format: + frames = yuv_to_rgb(frames) + return frames + + def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames): + import ffmpegio + + fs, frames = ffmpegio.video.read( + str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format + ) + frames = torch.from_numpy(frames) + frames = einops.rearrange(frames, "b h w c -> b c h w") + if self.device == "cuda": + frames = frames.to(self.device) + return frames + + def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames): + from decord import VideoReader, cpu, gpu + + with open(str(video_path), "rb") as f: + ctx = gpu if self.device == "cuda" else cpu + vr = VideoReader(f, ctx=ctx(0)) # noqa: F841 + raise NotImplementedError("Convert `first_frame_ts` into frame_id") + # frame_id = frame_ids[0].item() + # frames = vr.get_batch([frame_id]) + # frames = torch.from_numpy(frames.asnumpy()) + # frames = einops.rearrange(frames, "b h w c -> b c h w") + # return frames diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index f14196f9..35265947 100644 --- a/lerobot/scripts/visualize_dataset.py +++ b/lerobot/scripts/visualize_dataset.py @@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps): "-crf", "0", # Lossless option "-pix_fmt", - "yuv420p", # Specify pixel format + # "yuv420p", # Specify pixel format + "yuv444p", # Specify pixel format video_path, # video_path.replace(".mp4", ".mkv") ] diff --git a/test.py b/test.py index bf8f877c..7dd9a3dc 100644 --- a/test.py +++ b/test.py @@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12 NUM_ACTION_CHANNELS = 12 -def yuv_to_rgb(frames): - assert frames.dtype == torch.uint8 - assert 
frames.ndim == 4 - assert frames.shape[1] == 3 - - frames = frames.cpu().to(torch.float) - y = frames[..., 0, :, :] - u = frames[..., 1, :, :] - v = frames[..., 2, :, :] - - y /= 255 - u = u / 255 - 0.5 - v = v / 255 - 0.5 - - r = y + 1.13983 * v - g = y + -0.39465 * u - 0.58060 * v - b = y + 2.03211 * u - - rgb = torch.stack([r, g, b], 1) - rgb = (rgb * 255).clamp(0, 255).to(torch.uint8) - return rgb - - -def yuv_to_rgb_cv2(frames, return_hwc=True): - assert frames.dtype == torch.uint8 - assert frames.ndim == 4 - assert frames.shape[1] == 3 - frames = frames.cpu() - import cv2 - - frames = einops.rearrange(frames, "b c h w -> b h w c") - frames = frames.numpy() - frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames] - frames = [torch.from_numpy(frame) for frame in frames] - frames = torch.stack(frames) - if not return_hwc: - frames = einops.rearrange(frames, "b h w c -> b c h w") - return frames - - def count_frames(video_path): try: # Construct the ffprobe command to get the number of frames @@ -272,7 +232,7 @@ if __name__ == "__main__": if "cuvid" in k: print(f" - {k}") - def create_replay_buffer(device): + def create_replay_buffer(device, format=None): data_dir = Path("tmp/2024_03_17_data_video/pusht") num_slices = 1 @@ -293,6 +253,7 @@ if __name__ == "__main__": data_dir=data_dir, device=device, frame_rate=None, + format=format, in_keys=[("observation", "frame")], out_keys=[("observation", "frame", "data")], ), @@ -324,8 +285,8 @@ if __name__ == "__main__": print(time.monotonic() - start) def test_plot(seed=1337): - rb_cuda = create_replay_buffer(device="cuda") - rb_cpu = create_replay_buffer(device="cuda") + rb_cuda = create_replay_buffer(device="cuda", format="yuv444p") + rb_cpu = create_replay_buffer(device="cpu", format="yuv444p") n_rows = 2 # len(replay_buffer) fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0]) @@ -337,7 +298,7 @@ if __name__ == "__main__": print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist()) frames = batch_cpu["observation", "frame", "data"] frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") - frames = yuv_to_rgb(frames, return_hwc=True) + frames = einops.rearrange(frames, "bt c h w -> bt h w c") assert frames.shape[0] == 1 axes[i][0].imshow(frames[0]) @@ -348,7 +309,7 @@ if __name__ == "__main__": print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist()) frames = batch_cuda["observation", "frame", "data"] frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") - frames = yuv_to_rgb(frames, return_hwc=True) + frames = einops.rearrange(frames, "bt c h w -> bt h w c") assert frames.shape[0] == 1 axes[i][1].imshow(frames[0])
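
A minimal usage sketch of the new `decoding_lib` option (illustrative, not part of the patch): it mirrors the wiring of `create_replay_buffer` in test.py, reusing the same data path and tensordict keys. In practice the transform is composed into a TorchRL replay buffer as done there, and `decoding_lib="decord"` still raises NotImplementedError.

    from pathlib import Path

    from lerobot.common.datasets.transforms import DecodeVideoTransform

    transform = DecodeVideoTransform(
        data_dir=Path("tmp/2024_03_17_data_video/pusht"),
        device="cpu",
        decoding_lib="torchaudio",  # or "ffmpegio"
        format="yuv444p",  # yuv* formats are converted to RGB via yuv_to_rgb after decoding
        frame_rate=None,
        in_keys=[("observation", "frame")],
        out_keys=[("observation", "frame", "data")],
    )

    # Sampling a batch through the replay buffer (as in test_plot) yields
    # (b, t, c, h, w) uint8 frames; flatten and move channels last for plt.imshow:
    # frames = einops.rearrange(batch["observation", "frame", "data"], "b t c h w -> (b t) h w c")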