WIP
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import einops
|
||||
import torch
|
||||
from tensordict import TensorDictBase
|
||||
from tensordict.nn import dispatch
|
||||
@@ -9,6 +10,46 @@ from torchaudio.io import StreamReader
|
||||
from torchrl.envs.transforms import Transform
|
||||
|
||||
|
||||
def yuv_to_rgb(frames):
    """Convert a batch of full-range YUV frames to RGB using pure torch ops.

    Args:
        frames: uint8 tensor of shape (batch, 3, H, W), channel order Y, U, V.

    Returns:
        uint8 tensor of shape (batch, 3, H, W) with R, G, B channels.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    # Work on a float copy on CPU; normalize planes so that
    # y is in [0, 1] and u, v are centered around 0.
    planes = frames.cpu().to(torch.float)
    y = planes[..., 0, :, :] / 255
    u = planes[..., 1, :, :] / 255 - 0.5
    v = planes[..., 2, :, :] / 255 - 0.5

    # BT.601-style YUV -> RGB conversion coefficients.
    red = y + 1.13983 * v
    green = y - 0.39465 * u - 0.58060 * v
    blue = y + 2.03211 * u

    rgb = torch.stack([red, green, blue], 1)
    return (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||
|
||||
|
||||
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert a batch of YUV frames to RGB via OpenCV's `cvtColor`.

    Args:
        frames: uint8 tensor of shape (batch, 3, H, W), channel order Y, U, V.
        return_hwc: when True, return (batch, H, W, 3); otherwise (batch, 3, H, W).

    Returns:
        uint8 RGB tensor in the layout selected by `return_hwc`.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    frames = frames.cpu()
    import cv2

    # OpenCV expects HWC numpy images, one frame at a time.
    hwc = einops.rearrange(frames, "b c h w -> b h w c").numpy()
    converted = []
    for img in hwc:
        converted.append(torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_YUV2RGB)))
    out = torch.stack(converted)
    if not return_hwc:
        out = einops.rearrange(out, "b h w c -> b c h w")
    return out
|
||||
|
||||
|
||||
class ViewSliceHorizonTransform(Transform):
|
||||
invertible = False
|
||||
|
||||
@@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform):
|
||||
self,
|
||||
data_dir: Path | str,
|
||||
device="cpu",
|
||||
decoding_lib: str = "torchaudio",
|
||||
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
|
||||
format: str | None = None,
|
||||
frame_rate: int | None = None,
|
||||
@@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform):
|
||||
):
|
||||
self.data_dir = Path(data_dir)
|
||||
self.device = device
|
||||
self.decoding_lib = decoding_lib
|
||||
self.format = format
|
||||
self.frame_rate = frame_rate
|
||||
self.width = width
|
||||
@@ -153,66 +196,17 @@ class DecodeVideoTransform(Transform):
|
||||
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
|
||||
num_contiguous_frames = len(timestamps)
|
||||
|
||||
filter_desc = []
|
||||
video_stream_kwgs = {
|
||||
"frames_per_chunk": num_contiguous_frames,
|
||||
"buffer_chunk_size": num_contiguous_frames,
|
||||
}
|
||||
|
||||
# choice of decoder
|
||||
if self.device == "cuda":
|
||||
video_stream_kwgs["hw_accel"] = "cuda"
|
||||
video_stream_kwgs["decoder"] = "h264_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "hevc_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "av1_cuvid"
|
||||
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
|
||||
if self.decoding_lib == "torchaudio":
|
||||
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
|
||||
elif self.decoding_lib == "ffmpegio":
|
||||
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
|
||||
elif self.decoding_lib == "decord":
|
||||
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
|
||||
else:
|
||||
video_stream_kwgs["decoder"] = "h264"
|
||||
# video_stream_kwgs["decoder"] = "hevc"
|
||||
# video_stream_kwgs["decoder"] = "av1"
|
||||
# video_stream_kwgs["decoder"] = "ffv1"
|
||||
raise ValueError(self.decoding_lib)
|
||||
|
||||
# resize
|
||||
resize_width = self.width is not None
|
||||
resize_height = self.height is not None
|
||||
if resize_width or resize_height:
|
||||
if self.device == "cuda":
|
||||
assert resize_width and resize_height
|
||||
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
|
||||
else:
|
||||
scales = []
|
||||
if resize_width:
|
||||
scales.append(f"width={self.width}")
|
||||
if resize_height:
|
||||
scales.append(f"height={self.height}")
|
||||
filter_desc.append(f"scale={':'.join(scales)}")
|
||||
|
||||
# choice of format
|
||||
if self.format is not None:
|
||||
if self.device == "cuda":
|
||||
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
|
||||
raise NotImplementedError()
|
||||
# filter_desc = f"scale=format={self.format}"
|
||||
# filter_desc = f"scale_cuda=format={self.format}"
|
||||
# filter_desc = f"scale_npp=format={self.format}"
|
||||
else:
|
||||
filter_desc.append(f"format=pix_fmts={self.format}")
|
||||
|
||||
# choice of frame rate
|
||||
if self.frame_rate is not None:
|
||||
filter_desc.append(f"fps={self.frame_rate}")
|
||||
|
||||
if len(filter_desc) > 0:
|
||||
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
|
||||
|
||||
# create a stream and load a certain number of frame at a certain frame rate
|
||||
# TODO(rcadene): make sure it's the most optimal way to do it
|
||||
# s = StreamReader(str(video_path).replace('.mp4','.mkv'))
|
||||
s = StreamReader(str(video_path))
|
||||
s.seek(first_frame_ts)
|
||||
s.add_video_stream(**video_stream_kwgs)
|
||||
s.fill_buffer()
|
||||
(frames,) = s.pop_chunks()
|
||||
assert frames.ndim == 4
|
||||
assert frames.shape[1] == 3
|
||||
|
||||
b_frames.append(frames)
|
||||
|
||||
@@ -222,3 +216,95 @@ class DecodeVideoTransform(Transform):
|
||||
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
|
||||
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
|
||||
return td
|
||||
|
||||
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode `num_contiguous_frames` frames starting at `first_frame_ts` via torchaudio.

    Configures a `StreamReader` over `video_path`: picks the decoder
    (CPU `h264` or CUDA `h264_cuvid` with hw acceleration), and builds an
    ffmpeg filter chain for optional resize, pixel format and frame rate.

    Returns:
        A (num_frames, 3, H, W) uint8 tensor; YUV output is converted to RGB.

    Raises:
        NotImplementedError: if a pixel format is requested on CUDA
            (requires an ffmpeg build with npp/cuvid filter support).
    """
    filter_desc = []
    video_stream_kwgs = {
        "frames_per_chunk": num_contiguous_frames,
        "buffer_chunk_size": num_contiguous_frames,
    }

    # choice of decoder (alternatives: hevc/av1/ffv1 and their _cuvid variants)
    if self.device == "cuda":
        video_stream_kwgs["hw_accel"] = "cuda"
        video_stream_kwgs["decoder"] = "h264_cuvid"
    else:
        video_stream_kwgs["decoder"] = "h264"

    # resize
    resize_width = self.width is not None
    resize_height = self.height is not None
    if resize_width or resize_height:
        if self.device == "cuda":
            # the cuvid decoder option resizes both dimensions at once
            assert resize_width and resize_height
            video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
        else:
            scales = []
            if resize_width:
                scales.append(f"width={self.width}")
            if resize_height:
                scales.append(f"height={self.height}")
            filter_desc.append(f"scale={':'.join(scales)}")

    # choice of format
    if self.format is not None:
        if self.device == "cuda":
            # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
            raise NotImplementedError()
            # filter_desc = f"scale=format={self.format}"
            # filter_desc = f"scale_cuda=format={self.format}"
            # filter_desc = f"scale_npp=format={self.format}"
        else:
            filter_desc.append(f"format=pix_fmts={self.format}")

    # choice of frame rate
    if self.frame_rate is not None:
        filter_desc.append(f"fps={self.frame_rate}")

    # decoded video is limited range by default; expand to full range
    filter_desc.append("scale=in_range=limited:out_range=full")

    if len(filter_desc) > 0:
        video_stream_kwgs["filter_desc"] = ",".join(filter_desc)

    # create a stream and load a certain number of frame at a certain frame rate
    # TODO(rcadene): make sure it's the most optimal way to do it
    s = StreamReader(str(video_path))
    s.seek(first_frame_ts)
    s.add_video_stream(**video_stream_kwgs)
    s.fill_buffer()
    (frames,) = s.pop_chunks()

    # BUGFIX: `self.format` may be None (the documented default, meaning
    # "decoder native, usually yuv420p"); `"yuv" in None` raises TypeError.
    if self.format is not None and "yuv" in self.format:
        frames = yuv_to_rgb(frames)
    return frames
|
||||
|
||||
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode `num_contiguous_frames` frames starting at `first_frame_ts` via ffmpegio.

    Returns:
        A (num_frames, 3, H, W) tensor, moved to CUDA when `self.device == "cuda"`.
    """
    import ffmpegio

    # first return value is the frame rate, which we do not need here
    _, raw = ffmpegio.video.read(
        str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
    )
    decoded = einops.rearrange(torch.from_numpy(raw), "b h w c -> b c h w")
    if self.device == "cuda":
        decoded = decoded.to(self.device)
    return decoded
|
||||
|
||||
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode frames via decord. Not implemented yet.

    Raises:
        NotImplementedError: always — decord addresses frames by index, and
            the mapping from `first_frame_ts` to a frame id is still missing.
    """
    from decord import VideoReader, cpu, gpu

    with open(str(video_path), "rb") as f:
        ctx = gpu if self.device == "cuda" else cpu
        vr = VideoReader(f, ctx=ctx(0))  # noqa: F841
        raise NotImplementedError("Convert `first_frame_ts` into frame_id")
        # TODO: once frame_id is available, fetch with vr.get_batch([frame_id]),
        # convert via torch.from_numpy(...asnumpy()) and rearrange "b h w c -> b c h w".
|
||||
|
||||
@@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps):
|
||||
"-crf",
|
||||
"0", # Lossless option
|
||||
"-pix_fmt",
|
||||
"yuv420p", # Specify pixel format
|
||||
# "yuv420p", # Specify pixel format
|
||||
"yuv444p", # Specify pixel format
|
||||
video_path,
|
||||
# video_path.replace(".mp4", ".mkv")
|
||||
]
|
||||
|
||||
51
test.py
51
test.py
@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
|
||||
NUM_ACTION_CHANNELS = 12
|
||||
|
||||
|
||||
def yuv_to_rgb(frames):
    """Convert full-range YUV uint8 frames (batch, 3, H, W) to RGB uint8."""
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3

    # Float working copy on CPU, split into normalized channel planes.
    fframes = frames.cpu().to(torch.float)
    y = fframes[..., 0, :, :] / 255
    u = fframes[..., 1, :, :] / 255 - 0.5
    v = fframes[..., 2, :, :] / 255 - 0.5

    # BT.601-style conversion, then back to uint8 with clamping.
    channels = [
        y + 1.13983 * v,            # R
        y - 0.39465 * u - 0.58060 * v,  # G
        y + 2.03211 * u,            # B
    ]
    rgb = torch.stack(channels, 1) * 255
    return rgb.clamp(0, 255).to(torch.uint8)
|
||||
|
||||
|
||||
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert YUV uint8 frames (batch, 3, H, W) to RGB with OpenCV.

    Returns (batch, H, W, 3) when `return_hwc` is True, else (batch, 3, H, W).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    frames = frames.cpu()
    import cv2

    # cv2 works on per-frame HWC numpy arrays.
    np_frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
    rgb = torch.stack(
        [torch.from_numpy(cv2.cvtColor(img, cv2.COLOR_YUV2RGB)) for img in np_frames]
    )
    if not return_hwc:
        rgb = einops.rearrange(rgb, "b h w c -> b c h w")
    return rgb
|
||||
|
||||
|
||||
def count_frames(video_path):
|
||||
try:
|
||||
# Construct the ffprobe command to get the number of frames
|
||||
@@ -272,7 +232,7 @@ if __name__ == "__main__":
|
||||
if "cuvid" in k:
|
||||
print(f" - {k}")
|
||||
|
||||
def create_replay_buffer(device):
|
||||
def create_replay_buffer(device, format=None):
|
||||
data_dir = Path("tmp/2024_03_17_data_video/pusht")
|
||||
|
||||
num_slices = 1
|
||||
@@ -293,6 +253,7 @@ if __name__ == "__main__":
|
||||
data_dir=data_dir,
|
||||
device=device,
|
||||
frame_rate=None,
|
||||
format=format,
|
||||
in_keys=[("observation", "frame")],
|
||||
out_keys=[("observation", "frame", "data")],
|
||||
),
|
||||
@@ -324,8 +285,8 @@ if __name__ == "__main__":
|
||||
print(time.monotonic() - start)
|
||||
|
||||
def test_plot(seed=1337):
|
||||
rb_cuda = create_replay_buffer(device="cuda")
|
||||
rb_cpu = create_replay_buffer(device="cuda")
|
||||
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
|
||||
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
|
||||
|
||||
n_rows = 2 # len(replay_buffer)
|
||||
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
|
||||
@@ -337,7 +298,7 @@ if __name__ == "__main__":
|
||||
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
|
||||
frames = batch_cpu["observation", "frame", "data"]
|
||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
||||
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||
assert frames.shape[0] == 1
|
||||
axes[i][0].imshow(frames[0])
|
||||
|
||||
@@ -348,7 +309,7 @@ if __name__ == "__main__":
|
||||
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
|
||||
frames = batch_cuda["observation", "frame", "data"]
|
||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
||||
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||
assert frames.shape[0] == 1
|
||||
axes[i][1].imshow(frames[0])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user