This commit is contained in:
Cadene
2024-03-19 13:41:49 +00:00
parent a346469a5a
commit 9cdc24bc0e
3 changed files with 152 additions and 104 deletions

View File

@@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
from typing import Sequence from typing import Sequence
import einops
import torch import torch
from tensordict import TensorDictBase from tensordict import TensorDictBase
from tensordict.nn import dispatch from tensordict.nn import dispatch
@@ -9,6 +10,46 @@ from torchaudio.io import StreamReader
from torchrl.envs.transforms import Transform from torchrl.envs.transforms import Transform
def yuv_to_rgb(frames):
    """Convert a batch of YUV frames to RGB on the CPU.

    Expects a (batch, 3, height, width) uint8 tensor with channels ordered
    Y, U, V. Applies BT.601-style conversion coefficients in float, then
    clamps to [0, 255] and truncates back to uint8.

    Returns a (batch, 3, height, width) uint8 RGB tensor.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    yuv = frames.cpu().to(torch.float)
    # Normalize: luma into [0, 1], chroma centered on zero.
    luma = yuv[..., 0, :, :] / 255
    chroma_u = yuv[..., 1, :, :] / 255 - 0.5
    chroma_v = yuv[..., 2, :, :] / 255 - 0.5
    red = luma + 1.13983 * chroma_v
    green = luma - 0.39465 * chroma_u - 0.58060 * chroma_v
    blue = luma + 2.03211 * chroma_u
    rgb = torch.stack((red, green, blue), dim=1)
    return (rgb * 255).clamp(0, 255).to(torch.uint8)
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert a batch of YUV frames to RGB using OpenCV.

    Expects a (batch, 3, height, width) uint8 tensor. Returns channels-last
    (batch, height, width, 3) by default; pass return_hwc=False to get
    channels-first (batch, 3, height, width).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    # Imported lazily so cv2 is only required when this code path is used.
    import cv2

    # OpenCV expects per-image HWC uint8 numpy arrays.
    hwc_batch = frames.cpu().permute(0, 2, 3, 1).numpy()
    converted = torch.stack(
        [torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_YUV2RGB)) for img in hwc_batch]
    )
    if return_hwc:
        return converted
    return converted.permute(0, 3, 1, 2)
class ViewSliceHorizonTransform(Transform): class ViewSliceHorizonTransform(Transform):
invertible = False invertible = False
@@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform):
self, self,
data_dir: Path | str, data_dir: Path | str,
device="cpu", device="cpu",
decoding_lib: str = "torchaudio",
# format options are None=yuv420p (usually), rgb24, bgr24, etc. # format options are None=yuv420p (usually), rgb24, bgr24, etc.
format: str | None = None, format: str | None = None,
frame_rate: int | None = None, frame_rate: int | None = None,
@@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform):
): ):
self.data_dir = Path(data_dir) self.data_dir = Path(data_dir)
self.device = device self.device = device
self.decoding_lib = decoding_lib
self.format = format self.format = format
self.frame_rate = frame_rate self.frame_rate = frame_rate
self.width = width self.width = width
@@ -153,66 +196,17 @@ class DecodeVideoTransform(Transform):
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item() first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
num_contiguous_frames = len(timestamps) num_contiguous_frames = len(timestamps)
filter_desc = [] if self.decoding_lib == "torchaudio":
video_stream_kwgs = { frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
"frames_per_chunk": num_contiguous_frames, elif self.decoding_lib == "ffmpegio":
"buffer_chunk_size": num_contiguous_frames, frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
} elif self.decoding_lib == "decord":
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
# choice of decoder
if self.device == "cuda":
video_stream_kwgs["hw_accel"] = "cuda"
video_stream_kwgs["decoder"] = "h264_cuvid"
# video_stream_kwgs["decoder"] = "hevc_cuvid"
# video_stream_kwgs["decoder"] = "av1_cuvid"
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
else: else:
video_stream_kwgs["decoder"] = "h264" raise ValueError(self.decoding_lib)
# video_stream_kwgs["decoder"] = "hevc"
# video_stream_kwgs["decoder"] = "av1"
# video_stream_kwgs["decoder"] = "ffv1"
# resize assert frames.ndim == 4
resize_width = self.width is not None assert frames.shape[1] == 3
resize_height = self.height is not None
if resize_width or resize_height:
if self.device == "cuda":
assert resize_width and resize_height
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
else:
scales = []
if resize_width:
scales.append(f"width={self.width}")
if resize_height:
scales.append(f"height={self.height}")
filter_desc.append(f"scale={':'.join(scales)}")
# choice of format
if self.format is not None:
if self.device == "cuda":
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
raise NotImplementedError()
# filter_desc = f"scale=format={self.format}"
# filter_desc = f"scale_cuda=format={self.format}"
# filter_desc = f"scale_npp=format={self.format}"
else:
filter_desc.append(f"format=pix_fmts={self.format}")
# choice of frame rate
if self.frame_rate is not None:
filter_desc.append(f"fps={self.frame_rate}")
if len(filter_desc) > 0:
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
# create a stream and load a certain number of frame at a certain frame rate
# TODO(rcadene): make sure it's the most optimal way to do it
# s = StreamReader(str(video_path).replace('.mp4','.mkv'))
s = StreamReader(str(video_path))
s.seek(first_frame_ts)
s.add_video_stream(**video_stream_kwgs)
s.fill_buffer()
(frames,) = s.pop_chunks()
b_frames.append(frames) b_frames.append(frames)
@@ -222,3 +216,95 @@ class DecodeVideoTransform(Transform):
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu # make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda" assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
return td return td
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode a contiguous chunk of frames with torchaudio's StreamReader.

    Args:
        video_path: path to the video file to decode.
        first_frame_ts: timestamp (in seconds) to seek to before decoding.
        num_contiguous_frames: number of frames to pull as a single chunk.

    Returns:
        A (num_frames, 3, h, w) uint8 tensor. Frames are converted from YUV
        to RGB only when a YUV pixel format was explicitly requested.
    """
    filter_desc = []
    video_stream_kwgs = {
        "frames_per_chunk": num_contiguous_frames,
        "buffer_chunk_size": num_contiguous_frames,
    }

    # choice of decoder
    if self.device == "cuda":
        video_stream_kwgs["hw_accel"] = "cuda"
        video_stream_kwgs["decoder"] = "h264_cuvid"
        # video_stream_kwgs["decoder"] = "hevc_cuvid"
        # video_stream_kwgs["decoder"] = "av1_cuvid"
        # video_stream_kwgs["decoder"] = "ffv1_cuvid"
    else:
        video_stream_kwgs["decoder"] = "h264"
        # video_stream_kwgs["decoder"] = "hevc"
        # video_stream_kwgs["decoder"] = "av1"
        # video_stream_kwgs["decoder"] = "ffv1"

    # resize
    resize_width = self.width is not None
    resize_height = self.height is not None
    if resize_width or resize_height:
        if self.device == "cuda":
            # the cuvid decoder option only supports resizing both dimensions at once
            assert resize_width and resize_height
            video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
        else:
            scales = []
            if resize_width:
                scales.append(f"width={self.width}")
            if resize_height:
                scales.append(f"height={self.height}")
            filter_desc.append(f"scale={':'.join(scales)}")

    # choice of format
    if self.format is not None:
        if self.device == "cuda":
            # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
            raise NotImplementedError()
            # filter_desc = f"scale=format={self.format}"
            # filter_desc = f"scale_cuda=format={self.format}"
            # filter_desc = f"scale_npp=format={self.format}"
        else:
            filter_desc.append(f"format=pix_fmts={self.format}")

    # choice of frame rate
    if self.frame_rate is not None:
        filter_desc.append(f"fps={self.frame_rate}")

    # expand limited-range (16-235) video values to full range (0-255)
    filter_desc.append("scale=in_range=limited:out_range=full")

    if len(filter_desc) > 0:
        video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

    # create a stream and load a certain number of frame at a certain frame rate
    # TODO(rcadene): make sure it's the most optimal way to do it
    s = StreamReader(str(video_path))
    s.seek(first_frame_ts)
    s.add_video_stream(**video_stream_kwgs)
    s.fill_buffer()
    (frames,) = s.pop_chunks()

    # BUGFIX: `self.format` defaults to None (the stream is then typically
    # yuv420p, per the constructor comment); `"yuv" in None` would raise
    # TypeError, so guard the membership test explicitly.
    if self.format is not None and "yuv" in self.format:
        frames = yuv_to_rgb(frames)
    return frames
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode a contiguous chunk of frames with ffmpegio.

    Seeks to `first_frame_ts` (seconds) and reads `num_contiguous_frames`
    frames in `self.format`, returning a channels-first (b, c, h, w) tensor,
    moved to CUDA when self.device == "cuda".
    """
    # Imported lazily so ffmpegio is only required when this decoder is selected.
    import ffmpegio

    fs, hwc_frames = ffmpegio.video.read(
        str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
    )
    chw_frames = einops.rearrange(torch.from_numpy(hwc_frames), "b h w c -> b c h w")
    if self.device == "cuda":
        chw_frames = chw_frames.to(self.device)
    return chw_frames
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode frames with decord — not implemented yet.

    decord seeks by frame id, not by timestamp, so mapping `first_frame_ts`
    (seconds) to a frame id is still an open TODO; the sketch below shows
    the intended decode path once that mapping exists.
    """
    # Imported lazily so decord is only required when this decoder is selected.
    from decord import VideoReader, cpu, gpu

    ctx_factory = gpu if self.device == "cuda" else cpu
    with open(str(video_path), "rb") as f:
        vr = VideoReader(f, ctx=ctx_factory(0))  # noqa: F841
        raise NotImplementedError("Convert `first_frame_ts` into frame_id")
        # frame_id = frame_ids[0].item()
        # frames = vr.get_batch([frame_id])
        # frames = torch.from_numpy(frames.asnumpy())
        # frames = einops.rearrange(frames, "b h w c -> b c h w")
        # return frames

View File

@@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps):
"-crf", "-crf",
"0", # Lossless option "0", # Lossless option
"-pix_fmt", "-pix_fmt",
"yuv420p", # Specify pixel format # "yuv420p", # Specify pixel format
"yuv444p", # Specify pixel format
video_path, video_path,
# video_path.replace(".mp4", ".mkv") # video_path.replace(".mp4", ".mkv")
] ]

51
test.py
View File

@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12 NUM_ACTION_CHANNELS = 12
def yuv_to_rgb(frames):
    """Convert a batch of (b, 3, h, w) uint8 YUV frames to uint8 RGB on CPU.

    Uses BT.601-style conversion coefficients in float; the result is
    clamped to [0, 255] and truncated back to uint8.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    frames = frames.cpu().to(torch.float)
    y = frames[..., 0, :, :]
    u = frames[..., 1, :, :]
    v = frames[..., 2, :, :]
    # Normalize: luma into [0, 1] (in-place on the fresh float copy), chroma centered on zero.
    y /= 255
    u = u / 255 - 0.5
    v = v / 255 - 0.5
    r = y + 1.13983 * v
    g = y + -0.39465 * u - 0.58060 * v
    b = y + 2.03211 * u
    rgb = torch.stack([r, g, b], 1)
    rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
    return rgb
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert a batch of (b, 3, h, w) uint8 YUV frames to RGB using OpenCV.

    Returns channels-last (b, h, w, 3) by default; channels-first when
    return_hwc is False.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    frames = frames.cpu()
    # Imported lazily so cv2 is only required when this code path is used.
    import cv2
    # OpenCV expects per-image HWC uint8 numpy arrays.
    frames = einops.rearrange(frames, "b c h w -> b h w c")
    frames = frames.numpy()
    frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
    frames = [torch.from_numpy(frame) for frame in frames]
    frames = torch.stack(frames)
    if not return_hwc:
        frames = einops.rearrange(frames, "b h w c -> b c h w")
    return frames
def count_frames(video_path): def count_frames(video_path):
try: try:
# Construct the ffprobe command to get the number of frames # Construct the ffprobe command to get the number of frames
@@ -272,7 +232,7 @@ if __name__ == "__main__":
if "cuvid" in k: if "cuvid" in k:
print(f" - {k}") print(f" - {k}")
def create_replay_buffer(device): def create_replay_buffer(device, format=None):
data_dir = Path("tmp/2024_03_17_data_video/pusht") data_dir = Path("tmp/2024_03_17_data_video/pusht")
num_slices = 1 num_slices = 1
@@ -293,6 +253,7 @@ if __name__ == "__main__":
data_dir=data_dir, data_dir=data_dir,
device=device, device=device,
frame_rate=None, frame_rate=None,
format=format,
in_keys=[("observation", "frame")], in_keys=[("observation", "frame")],
out_keys=[("observation", "frame", "data")], out_keys=[("observation", "frame", "data")],
), ),
@@ -324,8 +285,8 @@ if __name__ == "__main__":
print(time.monotonic() - start) print(time.monotonic() - start)
def test_plot(seed=1337): def test_plot(seed=1337):
rb_cuda = create_replay_buffer(device="cuda") rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
rb_cpu = create_replay_buffer(device="cuda") rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
n_rows = 2 # len(replay_buffer) n_rows = 2 # len(replay_buffer)
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0]) fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
@@ -337,7 +298,7 @@ if __name__ == "__main__":
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist()) print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
frames = batch_cpu["observation", "frame", "data"] frames = batch_cpu["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True) frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1 assert frames.shape[0] == 1
axes[i][0].imshow(frames[0]) axes[i][0].imshow(frames[0])
@@ -348,7 +309,7 @@ if __name__ == "__main__":
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist()) print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
frames = batch_cuda["observation", "frame", "data"] frames = batch_cuda["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True) frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1 assert frames.shape[0] == 1
axes[i][1].imshow(frames[0]) axes[i][1].imshow(frames[0])