This commit is contained in:
Cadene
2024-03-19 13:41:49 +00:00
parent a346469a5a
commit 9cdc24bc0e
3 changed files with 152 additions and 104 deletions

View File

@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Sequence
import einops
import torch
from tensordict import TensorDictBase
from tensordict.nn import dispatch
@@ -9,6 +10,46 @@ from torchaudio.io import StreamReader
from torchrl.envs.transforms import Transform
def yuv_to_rgb(frames):
    """Convert a batch of uint8 YUV frames (B, 3, H, W) to uint8 RGB.

    The conversion runs on CPU in float32 using the standard scalar
    YUV->RGB coefficients and clamps the result back into [0, 255].
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    yuv = frames.cpu().to(torch.float)
    # Normalize: luma to [0, 1], chroma centered on zero.
    luma = yuv[..., 0, :, :] / 255
    cb = yuv[..., 1, :, :] / 255 - 0.5
    cr = yuv[..., 2, :, :] / 255 - 0.5
    red = luma + 1.13983 * cr
    green = luma + -0.39465 * cb - 0.58060 * cr
    blue = luma + 2.03211 * cb
    rgb = torch.stack((red, green, blue), 1)
    return (rgb * 255).clamp(0, 255).to(torch.uint8)
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert uint8 YUV frames (B, 3, H, W) to RGB using OpenCV.

    Args:
        frames: uint8 tensor of shape (B, 3, H, W).
        return_hwc: when True return (B, H, W, C); otherwise (B, C, H, W).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    import cv2

    batch_hwc = einops.rearrange(frames.cpu(), "b c h w -> b h w c").numpy()
    converted = torch.stack(
        [torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_YUV2RGB)) for img in batch_hwc]
    )
    if return_hwc:
        return converted
    return einops.rearrange(converted, "b h w c -> b c h w")
class ViewSliceHorizonTransform(Transform):
invertible = False
@@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform):
self,
data_dir: Path | str,
device="cpu",
decoding_lib: str = "torchaudio",
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
format: str | None = None,
frame_rate: int | None = None,
@@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform):
):
self.data_dir = Path(data_dir)
self.device = device
self.decoding_lib = decoding_lib
self.format = format
self.frame_rate = frame_rate
self.width = width
@@ -153,66 +196,17 @@ class DecodeVideoTransform(Transform):
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
num_contiguous_frames = len(timestamps)
filter_desc = []
video_stream_kwgs = {
"frames_per_chunk": num_contiguous_frames,
"buffer_chunk_size": num_contiguous_frames,
}
# choice of decoder
if self.device == "cuda":
video_stream_kwgs["hw_accel"] = "cuda"
video_stream_kwgs["decoder"] = "h264_cuvid"
# video_stream_kwgs["decoder"] = "hevc_cuvid"
# video_stream_kwgs["decoder"] = "av1_cuvid"
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
if self.decoding_lib == "torchaudio":
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "ffmpegio":
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "decord":
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
else:
video_stream_kwgs["decoder"] = "h264"
# video_stream_kwgs["decoder"] = "hevc"
# video_stream_kwgs["decoder"] = "av1"
# video_stream_kwgs["decoder"] = "ffv1"
raise ValueError(self.decoding_lib)
# resize
resize_width = self.width is not None
resize_height = self.height is not None
if resize_width or resize_height:
if self.device == "cuda":
assert resize_width and resize_height
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
else:
scales = []
if resize_width:
scales.append(f"width={self.width}")
if resize_height:
scales.append(f"height={self.height}")
filter_desc.append(f"scale={':'.join(scales)}")
# choice of format
if self.format is not None:
if self.device == "cuda":
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
raise NotImplementedError()
# filter_desc = f"scale=format={self.format}"
# filter_desc = f"scale_cuda=format={self.format}"
# filter_desc = f"scale_npp=format={self.format}"
else:
filter_desc.append(f"format=pix_fmts={self.format}")
# choice of frame rate
if self.frame_rate is not None:
filter_desc.append(f"fps={self.frame_rate}")
if len(filter_desc) > 0:
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
# create a stream and load a certain number of frame at a certain frame rate
# TODO(rcadene): make sure it's the most optimal way to do it
# s = StreamReader(str(video_path).replace('.mp4','.mkv'))
s = StreamReader(str(video_path))
s.seek(first_frame_ts)
s.add_video_stream(**video_stream_kwgs)
s.fill_buffer()
(frames,) = s.pop_chunks()
assert frames.ndim == 4
assert frames.shape[1] == 3
b_frames.append(frames)
@@ -222,3 +216,95 @@ class DecodeVideoTransform(Transform):
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
return td
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode `num_contiguous_frames` frames from `video_path` starting at
    `first_frame_ts` (seconds) with torchaudio's ffmpeg-backed StreamReader.

    Builds the decoder/filter configuration from the transform's settings
    (`device`, `width`, `height`, `format`, `frame_rate`) and returns a
    (num_contiguous_frames, 3, H, W) tensor. When a yuv* pixel format was
    requested, the chunk is converted to RGB on the CPU before returning.

    Raises:
        NotImplementedError: pixel-format conversion on cuda is not supported
            by the current ffmpeg build.
    """
    filter_desc = []
    video_stream_kwgs = {
        "frames_per_chunk": num_contiguous_frames,
        "buffer_chunk_size": num_contiguous_frames,
    }

    # choice of decoder (other candidates: hevc, av1, ffv1 / *_cuvid on cuda)
    if self.device == "cuda":
        video_stream_kwgs["hw_accel"] = "cuda"
        video_stream_kwgs["decoder"] = "h264_cuvid"
    else:
        video_stream_kwgs["decoder"] = "h264"

    # resize
    resize_width = self.width is not None
    resize_height = self.height is not None
    if resize_width or resize_height:
        if self.device == "cuda":
            # the cuvid resize decoder option needs both dimensions
            assert resize_width and resize_height
            video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
        else:
            scales = []
            if resize_width:
                scales.append(f"width={self.width}")
            if resize_height:
                scales.append(f"height={self.height}")
            filter_desc.append(f"scale={':'.join(scales)}")

    # choice of format
    if self.format is not None:
        if self.device == "cuda":
            # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
            # (would unlock scale_cuda=format=... / scale_npp=format=...)
            raise NotImplementedError()
        else:
            filter_desc.append(f"format=pix_fmts={self.format}")

    # choice of frame rate
    if self.frame_rate is not None:
        filter_desc.append(f"fps={self.frame_rate}")

    # expand limited/studio swing to full range before any yuv->rgb conversion
    filter_desc.append("scale=in_range=limited:out_range=full")

    if len(filter_desc) > 0:
        video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

    # create a stream and load a certain number of frames at a certain frame rate
    # TODO(rcadene): make sure it's the most optimal way to do it
    s = StreamReader(str(video_path))
    s.seek(first_frame_ts)
    s.add_video_stream(**video_stream_kwgs)
    s.fill_buffer()
    (frames,) = s.pop_chunks()

    # BUGFIX: self.format defaults to None, and `"yuv" in None` raises
    # TypeError; only attempt the conversion when a yuv* format was requested.
    if self.format is not None and "yuv" in self.format:
        frames = yuv_to_rgb(frames)
    return frames
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode frames with the ffmpegio library (CPU decode).

    Reads `num_contiguous_frames` frames starting at `first_frame_ts`
    (seconds) in the transform's pixel format, returns them as a
    (B, C, H, W) tensor, moved to cuda when the transform runs there.
    """
    import ffmpegio

    _fs, clip = ffmpegio.video.read(
        str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
    )
    clip = einops.rearrange(torch.from_numpy(clip), "b h w c -> b c h w")
    if self.device == "cuda":
        return clip.to(self.device)
    return clip
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode frames with decord. Unfinished: decord addresses frames by
    index, and mapping `first_frame_ts` (seconds) to a frame id is not
    implemented yet, so this always raises NotImplementedError."""
    from decord import VideoReader, cpu, gpu

    device_ctx = gpu if self.device == "cuda" else cpu
    with open(str(video_path), "rb") as f:
        vr = VideoReader(f, ctx=device_ctx(0))  # noqa: F841
        raise NotImplementedError("Convert `first_frame_ts` into frame_id")
        # frame_id = frame_ids[0].item()
        # frames = vr.get_batch([frame_id])
        # frames = torch.from_numpy(frames.asnumpy())
        # frames = einops.rearrange(frames, "b h w c -> b c h w")
        # return frames

View File

@@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps):
"-crf",
"0", # Lossless option
"-pix_fmt",
"yuv420p", # Specify pixel format
# "yuv420p", # Specify pixel format
"yuv444p", # Specify pixel format
video_path,
# video_path.replace(".mp4", ".mkv")
]

51
test.py
View File

@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12
def yuv_to_rgb(frames):
    """Convert uint8 YUV frames of shape (B, 3, H, W) to uint8 RGB, same shape."""
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    planes = frames.cpu().to(torch.float)
    y = planes[..., 0, :, :] / 255
    u = planes[..., 1, :, :] / 255 - 0.5
    v = planes[..., 2, :, :] / 255 - 0.5
    channels = [
        y + 1.13983 * v,                 # R
        y + -0.39465 * u - 0.58060 * v,  # G
        y + 2.03211 * u,                 # B
    ]
    out = torch.stack(channels, 1) * 255
    return out.clamp(0, 255).to(torch.uint8)
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """OpenCV-based YUV->RGB conversion for uint8 (B, 3, H, W) frames.

    Returns (B, H, W, C) when `return_hwc` is True, else (B, C, H, W).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    import cv2

    as_numpy = einops.rearrange(frames.cpu(), "b c h w -> b h w c").numpy()
    rgb_list = [torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_YUV2RGB)) for frame in as_numpy]
    stacked = torch.stack(rgb_list)
    if not return_hwc:
        stacked = einops.rearrange(stacked, "b h w c -> b c h w")
    return stacked
def count_frames(video_path):
try:
# Construct the ffprobe command to get the number of frames
@@ -272,7 +232,7 @@ if __name__ == "__main__":
if "cuvid" in k:
print(f" - {k}")
def create_replay_buffer(device):
def create_replay_buffer(device, format=None):
data_dir = Path("tmp/2024_03_17_data_video/pusht")
num_slices = 1
@@ -293,6 +253,7 @@ if __name__ == "__main__":
data_dir=data_dir,
device=device,
frame_rate=None,
format=format,
in_keys=[("observation", "frame")],
out_keys=[("observation", "frame", "data")],
),
@@ -324,8 +285,8 @@ if __name__ == "__main__":
print(time.monotonic() - start)
def test_plot(seed=1337):
rb_cuda = create_replay_buffer(device="cuda")
rb_cpu = create_replay_buffer(device="cuda")
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
n_rows = 2 # len(replay_buffer)
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
@@ -337,7 +298,7 @@ if __name__ == "__main__":
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
frames = batch_cpu["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True)
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][0].imshow(frames[0])
@@ -348,7 +309,7 @@ if __name__ == "__main__":
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
frames = batch_cuda["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True)
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1
axes[i][1].imshow(frames[0])