This commit is contained in:
Cadene
2024-03-19 13:41:49 +00:00
parent a346469a5a
commit 9cdc24bc0e
3 changed files with 152 additions and 104 deletions

View File

@@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
from typing import Sequence from typing import Sequence
import einops
import torch import torch
from tensordict import TensorDictBase from tensordict import TensorDictBase
from tensordict.nn import dispatch from tensordict.nn import dispatch
@@ -9,6 +10,46 @@ from torchaudio.io import StreamReader
from torchrl.envs.transforms import Transform from torchrl.envs.transforms import Transform
def yuv_to_rgb(frames):
    """Convert a batch of uint8 YUV frames to uint8 RGB.

    Args:
        frames: uint8 tensor of shape (batch, 3, height, width) with the
            Y, U, V planes on the channel axis.

    Returns:
        uint8 tensor of the same shape with R, G, B planes, clamped to
        [0, 255]. Computation is done on CPU in float32.

    Note: the coefficients are the classic analog-YUV conversion ones
    (BT.601-style — confirm against the encoder if colors look off).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    planes = frames.cpu().to(torch.float)
    luma = planes[..., 0, :, :] / 255
    chroma_u = planes[..., 1, :, :] / 255 - 0.5
    chroma_v = planes[..., 2, :, :] / 255 - 0.5
    red = luma + 1.13983 * chroma_v
    green = luma + -0.39465 * chroma_u - 0.58060 * chroma_v
    blue = luma + 2.03211 * chroma_u
    rgb = torch.stack([red, green, blue], 1)
    # Scale back to byte range; the uint8 cast truncates toward zero.
    return (rgb * 255).clamp(0, 255).to(torch.uint8)
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert a batch of uint8 YUV frames to RGB using OpenCV.

    Args:
        frames: uint8 tensor of shape (batch, 3, height, width).
        return_hwc: when True (default), return (batch, h, w, c);
            otherwise return channel-first (batch, c, h, w).

    Returns:
        uint8 RGB tensor on CPU, layout controlled by ``return_hwc``.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    # cv2 is imported lazily so the module loads without OpenCV installed.
    import cv2

    batch_hwc = einops.rearrange(frames.cpu(), "b c h w -> b h w c").numpy()
    converted = torch.stack(
        [torch.from_numpy(cv2.cvtColor(image, cv2.COLOR_YUV2RGB)) for image in batch_hwc]
    )
    if return_hwc:
        return converted
    return einops.rearrange(converted, "b h w c -> b c h w")
class ViewSliceHorizonTransform(Transform): class ViewSliceHorizonTransform(Transform):
invertible = False invertible = False
@@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform):
self, self,
data_dir: Path | str, data_dir: Path | str,
device="cpu", device="cpu",
decoding_lib: str = "torchaudio",
# format options are None=yuv420p (usually), rgb24, bgr24, etc. # format options are None=yuv420p (usually), rgb24, bgr24, etc.
format: str | None = None, format: str | None = None,
frame_rate: int | None = None, frame_rate: int | None = None,
@@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform):
): ):
self.data_dir = Path(data_dir) self.data_dir = Path(data_dir)
self.device = device self.device = device
self.decoding_lib = decoding_lib
self.format = format self.format = format
self.frame_rate = frame_rate self.frame_rate = frame_rate
self.width = width self.width = width
@@ -153,6 +196,28 @@ class DecodeVideoTransform(Transform):
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item() first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
num_contiguous_frames = len(timestamps) num_contiguous_frames = len(timestamps)
if self.decoding_lib == "torchaudio":
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "ffmpegio":
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
elif self.decoding_lib == "decord":
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
else:
raise ValueError(self.decoding_lib)
assert frames.ndim == 4
assert frames.shape[1] == 3
b_frames.append(frames)
td[outkey] = torch.stack(b_frames)
if self.device == "cuda":
# make sure we return a cuda tensor, since the frames can be inadvertently sent to cpu
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
return td
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
filter_desc = [] filter_desc = []
video_stream_kwgs = { video_stream_kwgs = {
"frames_per_chunk": num_contiguous_frames, "frames_per_chunk": num_contiguous_frames,
@@ -202,23 +267,44 @@ class DecodeVideoTransform(Transform):
if self.frame_rate is not None: if self.frame_rate is not None:
filter_desc.append(f"fps={self.frame_rate}") filter_desc.append(f"fps={self.frame_rate}")
filter_desc.append("scale=in_range=limited:out_range=full")
if len(filter_desc) > 0: if len(filter_desc) > 0:
video_stream_kwgs["filter_desc"] = ",".join(filter_desc) video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
# create a stream and load a certain number of frame at a certain frame rate # create a stream and load a certain number of frame at a certain frame rate
# TODO(rcadene): make sure it's the most optimal way to do it # TODO(rcadene): make sure it's the most optimal way to do it
# s = StreamReader(str(video_path).replace('.mp4','.mkv'))
s = StreamReader(str(video_path)) s = StreamReader(str(video_path))
s.seek(first_frame_ts) s.seek(first_frame_ts)
s.add_video_stream(**video_stream_kwgs) s.add_video_stream(**video_stream_kwgs)
s.fill_buffer() s.fill_buffer()
(frames,) = s.pop_chunks() (frames,) = s.pop_chunks()
b_frames.append(frames) if "yuv" in self.format:
frames = yuv_to_rgb(frames)
return frames
td[outkey] = torch.stack(b_frames) def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
import ffmpegio
fs, frames = ffmpegio.video.read(
str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
)
frames = torch.from_numpy(frames)
frames = einops.rearrange(frames, "b h w c -> b c h w")
if self.device == "cuda": if self.device == "cuda":
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu frames = frames.to(self.device)
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda" return frames
return td
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
    """Decode frames with the decord library (not implemented yet).

    Opens the video with a GPU context when ``self.device == "cuda"``,
    otherwise CPU. Mapping ``first_frame_ts`` (a timestamp) to a decord
    frame index is still TODO, so this always raises NotImplementedError.
    """
    from decord import VideoReader, cpu, gpu

    device_ctx = gpu if self.device == "cuda" else cpu
    with open(str(video_path), "rb") as file:
        reader = VideoReader(file, ctx=device_ctx(0))  # noqa: F841
        raise NotImplementedError("Convert `first_frame_ts` into frame_id")
    # Sketch of the intended implementation once the mapping exists:
    # frame_id = frame_ids[0].item()
    # frames = reader.get_batch([frame_id])
    # frames = torch.from_numpy(frames.asnumpy())
    # frames = einops.rearrange(frames, "b h w c -> b c h w")
    # return frames

View File

@@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps):
"-crf", "-crf",
"0", # Lossless option "0", # Lossless option
"-pix_fmt", "-pix_fmt",
"yuv420p", # Specify pixel format # "yuv420p", # Specify pixel format
"yuv444p", # Specify pixel format
video_path, video_path,
# video_path.replace(".mp4", ".mkv") # video_path.replace(".mp4", ".mkv")
] ]

51
test.py
View File

@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
NUM_ACTION_CHANNELS = 12 NUM_ACTION_CHANNELS = 12
def yuv_to_rgb(frames):
    """Convert uint8 YUV frames, shape (batch, 3, h, w), to uint8 RGB.

    Runs on CPU in float32 and clamps the result into [0, 255] before
    casting back to uint8 (cast truncates). Coefficients are the classic
    analog-YUV ones (BT.601-style — confirm against the encoder).
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    yuv = frames.cpu().to(torch.float)
    y_plane = yuv[..., 0, :, :] / 255
    u_plane = yuv[..., 1, :, :] / 255 - 0.5
    v_plane = yuv[..., 2, :, :] / 255 - 0.5
    channels = [
        y_plane + 1.13983 * v_plane,
        y_plane + -0.39465 * u_plane - 0.58060 * v_plane,
        y_plane + 2.03211 * u_plane,
    ]
    rgb = torch.stack(channels, 1)
    return (rgb * 255).clamp(0, 255).to(torch.uint8)
def yuv_to_rgb_cv2(frames, return_hwc=True):
    """Convert uint8 YUV frames, shape (batch, 3, h, w), to RGB via OpenCV.

    Args:
        frames: uint8 tensor, channel-first, YUV planes.
        return_hwc: True -> (batch, h, w, c) output; False -> (batch, c, h, w).

    Returns:
        uint8 RGB tensor on CPU.
    """
    assert frames.dtype == torch.uint8
    assert frames.ndim == 4
    assert frames.shape[1] == 3
    # Lazy import keeps OpenCV optional at module load time.
    import cv2

    images = einops.rearrange(frames.cpu(), "b c h w -> b h w c").numpy()
    rgb_list = []
    for image in images:
        rgb_list.append(torch.from_numpy(cv2.cvtColor(image, cv2.COLOR_YUV2RGB)))
    result = torch.stack(rgb_list)
    if not return_hwc:
        result = einops.rearrange(result, "b h w c -> b c h w")
    return result
def count_frames(video_path): def count_frames(video_path):
try: try:
# Construct the ffprobe command to get the number of frames # Construct the ffprobe command to get the number of frames
@@ -272,7 +232,7 @@ if __name__ == "__main__":
if "cuvid" in k: if "cuvid" in k:
print(f" - {k}") print(f" - {k}")
def create_replay_buffer(device): def create_replay_buffer(device, format=None):
data_dir = Path("tmp/2024_03_17_data_video/pusht") data_dir = Path("tmp/2024_03_17_data_video/pusht")
num_slices = 1 num_slices = 1
@@ -293,6 +253,7 @@ if __name__ == "__main__":
data_dir=data_dir, data_dir=data_dir,
device=device, device=device,
frame_rate=None, frame_rate=None,
format=format,
in_keys=[("observation", "frame")], in_keys=[("observation", "frame")],
out_keys=[("observation", "frame", "data")], out_keys=[("observation", "frame", "data")],
), ),
@@ -324,8 +285,8 @@ if __name__ == "__main__":
print(time.monotonic() - start) print(time.monotonic() - start)
def test_plot(seed=1337): def test_plot(seed=1337):
rb_cuda = create_replay_buffer(device="cuda") rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
rb_cpu = create_replay_buffer(device="cuda") rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
n_rows = 2 # len(replay_buffer) n_rows = 2 # len(replay_buffer)
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0]) fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
@@ -337,7 +298,7 @@ if __name__ == "__main__":
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist()) print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
frames = batch_cpu["observation", "frame", "data"] frames = batch_cpu["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True) frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1 assert frames.shape[0] == 1
axes[i][0].imshow(frames[0]) axes[i][0].imshow(frames[0])
@@ -348,7 +309,7 @@ if __name__ == "__main__":
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist()) print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
frames = batch_cuda["observation", "frame", "data"] frames = batch_cuda["observation", "frame", "data"]
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w") frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
frames = yuv_to_rgb(frames, return_hwc=True) frames = einops.rearrange(frames, "bt c h w -> bt h w c")
assert frames.shape[0] == 1 assert frames.shape[0] == 1
axes[i][1].imshow(frames[0]) axes[i][1].imshow(frames[0])