forked from tangger/lerobot
WIP
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
|
import einops
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDictBase
|
from tensordict import TensorDictBase
|
||||||
from tensordict.nn import dispatch
|
from tensordict.nn import dispatch
|
||||||
@@ -9,6 +10,46 @@ from torchaudio.io import StreamReader
|
|||||||
from torchrl.envs.transforms import Transform
|
from torchrl.envs.transforms import Transform
|
||||||
|
|
||||||
|
|
||||||
|
def yuv_to_rgb(frames):
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
assert frames.ndim == 4
|
||||||
|
assert frames.shape[1] == 3
|
||||||
|
|
||||||
|
frames = frames.cpu().to(torch.float)
|
||||||
|
y = frames[..., 0, :, :]
|
||||||
|
u = frames[..., 1, :, :]
|
||||||
|
v = frames[..., 2, :, :]
|
||||||
|
|
||||||
|
y /= 255
|
||||||
|
u = u / 255 - 0.5
|
||||||
|
v = v / 255 - 0.5
|
||||||
|
|
||||||
|
r = y + 1.13983 * v
|
||||||
|
g = y + -0.39465 * u - 0.58060 * v
|
||||||
|
b = y + 2.03211 * u
|
||||||
|
|
||||||
|
rgb = torch.stack([r, g, b], 1)
|
||||||
|
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
||||||
|
return rgb
|
||||||
|
|
||||||
|
|
||||||
|
def yuv_to_rgb_cv2(frames, return_hwc=True):
|
||||||
|
assert frames.dtype == torch.uint8
|
||||||
|
assert frames.ndim == 4
|
||||||
|
assert frames.shape[1] == 3
|
||||||
|
frames = frames.cpu()
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
||||||
|
frames = frames.numpy()
|
||||||
|
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
|
||||||
|
frames = [torch.from_numpy(frame) for frame in frames]
|
||||||
|
frames = torch.stack(frames)
|
||||||
|
if not return_hwc:
|
||||||
|
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
class ViewSliceHorizonTransform(Transform):
|
class ViewSliceHorizonTransform(Transform):
|
||||||
invertible = False
|
invertible = False
|
||||||
|
|
||||||
@@ -77,6 +118,7 @@ class DecodeVideoTransform(Transform):
|
|||||||
self,
|
self,
|
||||||
data_dir: Path | str,
|
data_dir: Path | str,
|
||||||
device="cpu",
|
device="cpu",
|
||||||
|
decoding_lib: str = "torchaudio",
|
||||||
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
|
# format options are None=yuv420p (usually), rgb24, bgr24, etc.
|
||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
frame_rate: int | None = None,
|
frame_rate: int | None = None,
|
||||||
@@ -89,6 +131,7 @@ class DecodeVideoTransform(Transform):
|
|||||||
):
|
):
|
||||||
self.data_dir = Path(data_dir)
|
self.data_dir = Path(data_dir)
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.decoding_lib = decoding_lib
|
||||||
self.format = format
|
self.format = format
|
||||||
self.frame_rate = frame_rate
|
self.frame_rate = frame_rate
|
||||||
self.width = width
|
self.width = width
|
||||||
@@ -153,66 +196,17 @@ class DecodeVideoTransform(Transform):
|
|||||||
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
|
first_frame_ts = timestamps[FIRST_FRAME].squeeze(0).item()
|
||||||
num_contiguous_frames = len(timestamps)
|
num_contiguous_frames = len(timestamps)
|
||||||
|
|
||||||
filter_desc = []
|
if self.decoding_lib == "torchaudio":
|
||||||
video_stream_kwgs = {
|
frames = self._decode_frames_torchaudio(video_path, first_frame_ts, num_contiguous_frames)
|
||||||
"frames_per_chunk": num_contiguous_frames,
|
elif self.decoding_lib == "ffmpegio":
|
||||||
"buffer_chunk_size": num_contiguous_frames,
|
frames = self._decode_frames_ffmpegio(video_path, first_frame_ts, num_contiguous_frames)
|
||||||
}
|
elif self.decoding_lib == "decord":
|
||||||
|
frames = self._decode_frames_decord(video_path, first_frame_ts, num_contiguous_frames)
|
||||||
# choice of decoder
|
|
||||||
if self.device == "cuda":
|
|
||||||
video_stream_kwgs["hw_accel"] = "cuda"
|
|
||||||
video_stream_kwgs["decoder"] = "h264_cuvid"
|
|
||||||
# video_stream_kwgs["decoder"] = "hevc_cuvid"
|
|
||||||
# video_stream_kwgs["decoder"] = "av1_cuvid"
|
|
||||||
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
|
|
||||||
else:
|
else:
|
||||||
video_stream_kwgs["decoder"] = "h264"
|
raise ValueError(self.decoding_lib)
|
||||||
# video_stream_kwgs["decoder"] = "hevc"
|
|
||||||
# video_stream_kwgs["decoder"] = "av1"
|
|
||||||
# video_stream_kwgs["decoder"] = "ffv1"
|
|
||||||
|
|
||||||
# resize
|
assert frames.ndim == 4
|
||||||
resize_width = self.width is not None
|
assert frames.shape[1] == 3
|
||||||
resize_height = self.height is not None
|
|
||||||
if resize_width or resize_height:
|
|
||||||
if self.device == "cuda":
|
|
||||||
assert resize_width and resize_height
|
|
||||||
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
|
|
||||||
else:
|
|
||||||
scales = []
|
|
||||||
if resize_width:
|
|
||||||
scales.append(f"width={self.width}")
|
|
||||||
if resize_height:
|
|
||||||
scales.append(f"height={self.height}")
|
|
||||||
filter_desc.append(f"scale={':'.join(scales)}")
|
|
||||||
|
|
||||||
# choice of format
|
|
||||||
if self.format is not None:
|
|
||||||
if self.device == "cuda":
|
|
||||||
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
|
|
||||||
raise NotImplementedError()
|
|
||||||
# filter_desc = f"scale=format={self.format}"
|
|
||||||
# filter_desc = f"scale_cuda=format={self.format}"
|
|
||||||
# filter_desc = f"scale_npp=format={self.format}"
|
|
||||||
else:
|
|
||||||
filter_desc.append(f"format=pix_fmts={self.format}")
|
|
||||||
|
|
||||||
# choice of frame rate
|
|
||||||
if self.frame_rate is not None:
|
|
||||||
filter_desc.append(f"fps={self.frame_rate}")
|
|
||||||
|
|
||||||
if len(filter_desc) > 0:
|
|
||||||
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
|
|
||||||
|
|
||||||
# create a stream and load a certain number of frame at a certain frame rate
|
|
||||||
# TODO(rcadene): make sure it's the most optimal way to do it
|
|
||||||
# s = StreamReader(str(video_path).replace('.mp4','.mkv'))
|
|
||||||
s = StreamReader(str(video_path))
|
|
||||||
s.seek(first_frame_ts)
|
|
||||||
s.add_video_stream(**video_stream_kwgs)
|
|
||||||
s.fill_buffer()
|
|
||||||
(frames,) = s.pop_chunks()
|
|
||||||
|
|
||||||
b_frames.append(frames)
|
b_frames.append(frames)
|
||||||
|
|
||||||
@@ -222,3 +216,95 @@ class DecodeVideoTransform(Transform):
|
|||||||
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
|
# make sure we return a cuda tensor, since the frames can be unwillingly sent to cpu
|
||||||
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
|
assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
|
||||||
return td
|
return td
|
||||||
|
|
||||||
|
def _decode_frames_torchaudio(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||||
|
filter_desc = []
|
||||||
|
video_stream_kwgs = {
|
||||||
|
"frames_per_chunk": num_contiguous_frames,
|
||||||
|
"buffer_chunk_size": num_contiguous_frames,
|
||||||
|
}
|
||||||
|
|
||||||
|
# choice of decoder
|
||||||
|
if self.device == "cuda":
|
||||||
|
video_stream_kwgs["hw_accel"] = "cuda"
|
||||||
|
video_stream_kwgs["decoder"] = "h264_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "hevc_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "av1_cuvid"
|
||||||
|
# video_stream_kwgs["decoder"] = "ffv1_cuvid"
|
||||||
|
else:
|
||||||
|
video_stream_kwgs["decoder"] = "h264"
|
||||||
|
# video_stream_kwgs["decoder"] = "hevc"
|
||||||
|
# video_stream_kwgs["decoder"] = "av1"
|
||||||
|
# video_stream_kwgs["decoder"] = "ffv1"
|
||||||
|
|
||||||
|
# resize
|
||||||
|
resize_width = self.width is not None
|
||||||
|
resize_height = self.height is not None
|
||||||
|
if resize_width or resize_height:
|
||||||
|
if self.device == "cuda":
|
||||||
|
assert resize_width and resize_height
|
||||||
|
video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
|
||||||
|
else:
|
||||||
|
scales = []
|
||||||
|
if resize_width:
|
||||||
|
scales.append(f"width={self.width}")
|
||||||
|
if resize_height:
|
||||||
|
scales.append(f"height={self.height}")
|
||||||
|
filter_desc.append(f"scale={':'.join(scales)}")
|
||||||
|
|
||||||
|
# choice of format
|
||||||
|
if self.format is not None:
|
||||||
|
if self.device == "cuda":
|
||||||
|
# TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
|
||||||
|
raise NotImplementedError()
|
||||||
|
# filter_desc = f"scale=format={self.format}"
|
||||||
|
# filter_desc = f"scale_cuda=format={self.format}"
|
||||||
|
# filter_desc = f"scale_npp=format={self.format}"
|
||||||
|
else:
|
||||||
|
filter_desc.append(f"format=pix_fmts={self.format}")
|
||||||
|
|
||||||
|
# choice of frame rate
|
||||||
|
if self.frame_rate is not None:
|
||||||
|
filter_desc.append(f"fps={self.frame_rate}")
|
||||||
|
|
||||||
|
filter_desc.append("scale=in_range=limited:out_range=full")
|
||||||
|
|
||||||
|
if len(filter_desc) > 0:
|
||||||
|
video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
|
||||||
|
|
||||||
|
# create a stream and load a certain number of frame at a certain frame rate
|
||||||
|
# TODO(rcadene): make sure it's the most optimal way to do it
|
||||||
|
s = StreamReader(str(video_path))
|
||||||
|
s.seek(first_frame_ts)
|
||||||
|
s.add_video_stream(**video_stream_kwgs)
|
||||||
|
s.fill_buffer()
|
||||||
|
(frames,) = s.pop_chunks()
|
||||||
|
|
||||||
|
if "yuv" in self.format:
|
||||||
|
frames = yuv_to_rgb(frames)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def _decode_frames_ffmpegio(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||||
|
import ffmpegio
|
||||||
|
|
||||||
|
fs, frames = ffmpegio.video.read(
|
||||||
|
str(video_path), ss=str(first_frame_ts), vframes=num_contiguous_frames, pix_fmt=self.format
|
||||||
|
)
|
||||||
|
frames = torch.from_numpy(frames)
|
||||||
|
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
if self.device == "cuda":
|
||||||
|
frames = frames.to(self.device)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
def _decode_frames_decord(self, video_path, first_frame_ts, num_contiguous_frames):
|
||||||
|
from decord import VideoReader, cpu, gpu
|
||||||
|
|
||||||
|
with open(str(video_path), "rb") as f:
|
||||||
|
ctx = gpu if self.device == "cuda" else cpu
|
||||||
|
vr = VideoReader(f, ctx=ctx(0)) # noqa: F841
|
||||||
|
raise NotImplementedError("Convert `first_frame_ts` into frame_id")
|
||||||
|
# frame_id = frame_ids[0].item()
|
||||||
|
# frames = vr.get_batch([frame_id])
|
||||||
|
# frames = torch.from_numpy(frames.asnumpy())
|
||||||
|
# frames = einops.rearrange(frames, "b h w c -> b c h w")
|
||||||
|
# return frames
|
||||||
|
|||||||
@@ -61,7 +61,8 @@ def cat_and_write_video(video_path, frames, fps):
|
|||||||
"-crf",
|
"-crf",
|
||||||
"0", # Lossless option
|
"0", # Lossless option
|
||||||
"-pix_fmt",
|
"-pix_fmt",
|
||||||
"yuv420p", # Specify pixel format
|
# "yuv420p", # Specify pixel format
|
||||||
|
"yuv444p", # Specify pixel format
|
||||||
video_path,
|
video_path,
|
||||||
# video_path.replace(".mp4", ".mkv")
|
# video_path.replace(".mp4", ".mkv")
|
||||||
]
|
]
|
||||||
|
|||||||
51
test.py
51
test.py
@@ -25,46 +25,6 @@ NUM_STATE_CHANNELS = 12
|
|||||||
NUM_ACTION_CHANNELS = 12
|
NUM_ACTION_CHANNELS = 12
|
||||||
|
|
||||||
|
|
||||||
def yuv_to_rgb(frames):
|
|
||||||
assert frames.dtype == torch.uint8
|
|
||||||
assert frames.ndim == 4
|
|
||||||
assert frames.shape[1] == 3
|
|
||||||
|
|
||||||
frames = frames.cpu().to(torch.float)
|
|
||||||
y = frames[..., 0, :, :]
|
|
||||||
u = frames[..., 1, :, :]
|
|
||||||
v = frames[..., 2, :, :]
|
|
||||||
|
|
||||||
y /= 255
|
|
||||||
u = u / 255 - 0.5
|
|
||||||
v = v / 255 - 0.5
|
|
||||||
|
|
||||||
r = y + 1.13983 * v
|
|
||||||
g = y + -0.39465 * u - 0.58060 * v
|
|
||||||
b = y + 2.03211 * u
|
|
||||||
|
|
||||||
rgb = torch.stack([r, g, b], 1)
|
|
||||||
rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
|
|
||||||
return rgb
|
|
||||||
|
|
||||||
|
|
||||||
def yuv_to_rgb_cv2(frames, return_hwc=True):
|
|
||||||
assert frames.dtype == torch.uint8
|
|
||||||
assert frames.ndim == 4
|
|
||||||
assert frames.shape[1] == 3
|
|
||||||
frames = frames.cpu()
|
|
||||||
import cv2
|
|
||||||
|
|
||||||
frames = einops.rearrange(frames, "b c h w -> b h w c")
|
|
||||||
frames = frames.numpy()
|
|
||||||
frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
|
|
||||||
frames = [torch.from_numpy(frame) for frame in frames]
|
|
||||||
frames = torch.stack(frames)
|
|
||||||
if not return_hwc:
|
|
||||||
frames = einops.rearrange(frames, "b h w c -> b c h w")
|
|
||||||
return frames
|
|
||||||
|
|
||||||
|
|
||||||
def count_frames(video_path):
|
def count_frames(video_path):
|
||||||
try:
|
try:
|
||||||
# Construct the ffprobe command to get the number of frames
|
# Construct the ffprobe command to get the number of frames
|
||||||
@@ -272,7 +232,7 @@ if __name__ == "__main__":
|
|||||||
if "cuvid" in k:
|
if "cuvid" in k:
|
||||||
print(f" - {k}")
|
print(f" - {k}")
|
||||||
|
|
||||||
def create_replay_buffer(device):
|
def create_replay_buffer(device, format=None):
|
||||||
data_dir = Path("tmp/2024_03_17_data_video/pusht")
|
data_dir = Path("tmp/2024_03_17_data_video/pusht")
|
||||||
|
|
||||||
num_slices = 1
|
num_slices = 1
|
||||||
@@ -293,6 +253,7 @@ if __name__ == "__main__":
|
|||||||
data_dir=data_dir,
|
data_dir=data_dir,
|
||||||
device=device,
|
device=device,
|
||||||
frame_rate=None,
|
frame_rate=None,
|
||||||
|
format=format,
|
||||||
in_keys=[("observation", "frame")],
|
in_keys=[("observation", "frame")],
|
||||||
out_keys=[("observation", "frame", "data")],
|
out_keys=[("observation", "frame", "data")],
|
||||||
),
|
),
|
||||||
@@ -324,8 +285,8 @@ if __name__ == "__main__":
|
|||||||
print(time.monotonic() - start)
|
print(time.monotonic() - start)
|
||||||
|
|
||||||
def test_plot(seed=1337):
|
def test_plot(seed=1337):
|
||||||
rb_cuda = create_replay_buffer(device="cuda")
|
rb_cuda = create_replay_buffer(device="cuda", format="yuv444p")
|
||||||
rb_cpu = create_replay_buffer(device="cuda")
|
rb_cpu = create_replay_buffer(device="cpu", format="yuv444p")
|
||||||
|
|
||||||
n_rows = 2 # len(replay_buffer)
|
n_rows = 2 # len(replay_buffer)
|
||||||
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
|
fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
|
||||||
@@ -337,7 +298,7 @@ if __name__ == "__main__":
|
|||||||
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
|
print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
|
||||||
frames = batch_cpu["observation", "frame", "data"]
|
frames = batch_cpu["observation", "frame", "data"]
|
||||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||||
assert frames.shape[0] == 1
|
assert frames.shape[0] == 1
|
||||||
axes[i][0].imshow(frames[0])
|
axes[i][0].imshow(frames[0])
|
||||||
|
|
||||||
@@ -348,7 +309,7 @@ if __name__ == "__main__":
|
|||||||
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
|
print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
|
||||||
frames = batch_cuda["observation", "frame", "data"]
|
frames = batch_cuda["observation", "frame", "data"]
|
||||||
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
|
||||||
frames = yuv_to_rgb(frames, return_hwc=True)
|
frames = einops.rearrange(frames, "bt c h w -> bt h w c")
|
||||||
assert frames.shape[0] == 1
|
assert frames.shape[0] == 1
|
||||||
axes[i][1].imshow(frames[0])
|
axes[i][1].imshow(frames[0])
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user