forked from tangger/lerobot
Add video decoding in dataset (WIP: issue with gray background)
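For context, the decoding path this commit wires into the dataset is torchaudio's chunked StreamReader API: seek to the first requested timestamp, attach a video stream sized to one observation slice, fill the buffer, and pop the decoded chunk. A minimal sketch of that pattern, with a hypothetical file path and chunk size (the real values come from the sampled slice in the code below):

```python
# Minimal sketch of the StreamReader pattern used by DecodeVideoTransform.
# "episode_0.mp4" and the chunk size are illustrative placeholders.
from torchaudio.io import StreamReader

s = StreamReader("episode_0.mp4")
s.seek(0.0)  # timestamp of the first frame we want
s.add_video_stream(
    frames_per_chunk=2,   # one chunk == one slice of contiguous frames
    buffer_chunk_size=2,
    decoder="h264",       # "h264_cuvid" with hw_accel="cuda" for NVDEC
)
s.fill_buffer()               # decode until one full chunk is buffered
(frames,) = s.pop_chunks()    # uint8 tensor of shape (T, C, H, W)
```

The test script below treats the decoded chunk as YUV and converts it to RGB before plotting, which is where the gray-background symptom shows up.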
test.py
@@ -4,24 +4,21 @@
 import subprocess
 from collections.abc import Callable
 from pathlib import Path
-from typing import Sequence
 
 import einops
 import torch
 import torchaudio
 import torchrl
 from matplotlib import pyplot as plt
-from tensordict import TensorDict, TensorDictBase
-from tensordict.nn import dispatch
-from tensordict.utils import NestedKey
-from torchaudio.io import StreamReader
+from tensordict import TensorDict
+from torchaudio.utils import ffmpeg_utils
 from torchrl.data.replay_buffers.replay_buffers import TensorDictReplayBuffer
 from torchrl.data.replay_buffers.samplers import SliceSampler, SliceSamplerWithoutReplacement
 from torchrl.data.replay_buffers.storages import TensorStorage, _collate_id
 from torchrl.data.replay_buffers.writers import ImmutableDatasetWriter, Writer
-from torchrl.envs.transforms import Transform
 from torchrl.envs.transforms.transforms import Compose
 
+from lerobot.common.datasets.transforms import DecodeVideoTransform, KeepFrames, ViewSliceHorizonTransform
 from lerobot.common.utils import set_seed
 
 NUM_STATE_CHANNELS = 12
@@ -42,15 +39,32 @@ def yuv_to_rgb(frames):
     u = u / 255 - 0.5
     v = v / 255 - 0.5
 
-    r = y + 1.14 * v
-    g = y + -0.396 * u - 0.581 * v
-    b = y + 2.029 * u
+    r = y + 1.13983 * v
+    g = y + -0.39465 * u - 0.58060 * v
+    b = y + 2.03211 * u
 
     rgb = torch.stack([r, g, b], 1)
     rgb = (rgb * 255).clamp(0, 255).to(torch.uint8)
     return rgb
 
 
+def yuv_to_rgb_cv2(frames, return_hwc=True):
+    assert frames.dtype == torch.uint8
+    assert frames.ndim == 4
+    assert frames.shape[1] == 3
+    frames = frames.cpu()
+    import cv2
+
+    frames = einops.rearrange(frames, "b c h w -> b h w c")
+    frames = frames.numpy()
+    frames = [cv2.cvtColor(frame, cv2.COLOR_YUV2RGB) for frame in frames]
+    frames = [torch.from_numpy(frame) for frame in frames]
+    frames = torch.stack(frames)
+    if not return_hwc:
+        frames = einops.rearrange(frames, "b h w c -> b c h w")
+    return frames
+
+
 def count_frames(video_path):
     try:
         # Construct the ffprobe command to get the number of frames
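Both sets of constants above are BT.601-style analog-YUV coefficients (1.14 vs. 1.13983, 2.029 vs. 2.03211, and so on), so this is a precision tweak rather than a change of color space. A hypothetical cross-check against the new OpenCV path, assuming `yuv_to_rgb` accepts uint8 CHW frames the way the tests below use it:

```python
# Hypothetical sanity check (not part of the commit): the manual constants and
# cv2.COLOR_YUV2RGB are both BT.601-derived, so the two paths should agree to
# within a few levels of rounding on the same input.
import torch

yuv = torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8)  # fake YUV frames
rgb_manual = yuv_to_rgb(yuv)                        # tensor math above
rgb_opencv = yuv_to_rgb_cv2(yuv, return_hwc=False)  # OpenCV reference
err = (rgb_manual.float() - rgb_opencv.float()).abs().max().item()
print("max abs difference:", err)  # a large gap here would implicate the conversion itself
```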
@@ -131,211 +145,6 @@ def get_frame_timestamps(frame_rate, num_frames):
     # return td
 
 
-class ViewSliceHorizonTransform(Transform):
-    invertible = False
-
-    def __init__(self, num_slices, horizon):
-        super().__init__()
-        self.num_slices = num_slices
-        self.horizon = horizon
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment resets, to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        td = td.view(self.num_slices, self.horizon)
-        return td
-
-
-class KeepFrames(Transform):
-    invertible = False
-
-    def __init__(
-        self,
-        positions,
-        in_keys: Sequence[NestedKey],
-        out_keys: Sequence[NestedKey] = None,
-    ):
-        if isinstance(positions, list):
-            assert isinstance(positions[0], int)
-        # TODO(rcadene): add support for `isinstance(positions, int)`?
-
-        self.positions = positions
-        if out_keys is None:
-            out_keys = in_keys
-        super().__init__(in_keys=in_keys, out_keys=out_keys)
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment resets, to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        # we need to set batch_size=[] before assigning a different shape to td[outkey]
-        td.batch_size = []
-
-        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-            # TODO(rcadene): don't know how to do `inkey not in td`
-            if td.get(inkey, None) is None:
-                continue
-            td[outkey] = td[inkey][:, self.positions]
-        return td
-
-
-class DecodeVideoTransform(Transform):
-    invertible = False
-
-    def __init__(
-        self,
-        device="cpu",
-        # format options are None=yuv420p (usually), rgb24, bgr24, etc.
-        format: str | None = None,
-        frame_rate: int | None = None,
-        width: int | None = None,
-        height: int | None = None,
-        in_keys: Sequence[NestedKey] = None,
-        out_keys: Sequence[NestedKey] = None,
-        in_keys_inv: Sequence[NestedKey] | None = None,
-        out_keys_inv: Sequence[NestedKey] | None = None,
-    ):
-        self.device = device
-        self.format = format
-        self.frame_rate = frame_rate
-        self.width = width
-        self.height = height
-        self.video_id_to_path = None
-        if out_keys is None:
-            out_keys = in_keys
-        if in_keys_inv is None:
-            in_keys_inv = out_keys
-        if out_keys_inv is None:
-            out_keys_inv = in_keys
-        super().__init__(
-            in_keys=in_keys, out_keys=out_keys, in_keys_inv=in_keys_inv, out_keys_inv=out_keys_inv
-        )
-
-    def set_video_id_to_path(self, video_id_to_path):
-        self.video_id_to_path = video_id_to_path
-
-    def _reset(self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase) -> TensorDictBase:
-        # _reset is called once when the environment resets, to normalize the first observation
-        tensordict_reset = self._call(tensordict_reset)
-        return tensordict_reset
-
-    @dispatch(source="in_keys", dest="out_keys")
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        return self._call(tensordict)
-
-    def _call(self, td: TensorDictBase) -> TensorDictBase:
-        assert (
-            self.video_id_to_path is not None
-        ), "Setting a video_id_to_path dictionary with `self.set_video_id_to_path(video_id_to_path)` is required."
-
-        for inkey, outkey in zip(self.in_keys, self.out_keys, strict=False):
-            # TODO(rcadene): don't know how to do `inkey not in td`
-            if td.get(inkey, None) is None:
-                continue
-
-            bsize = len(td[inkey])  # num episodes in the batch
-            b_frames = []
-            for i in range(bsize):
-                assert (
-                    td["observation", "frame", "video_id"].ndim == 2
-                ), "We expect 2 dims. Respectively, number of episodes in the batch and number of observations"
-
-                ep_video_ids = td[inkey]["video_id"][i]
-                timestamps = td[inkey]["timestamp"][i]
-                frame_ids = td["frame_id"][i]
-
-                unique_video_id = (ep_video_ids.min() == ep_video_ids.max()).item()
-                assert unique_video_id
-
-                is_ascending = torch.all(timestamps[:-1] <= timestamps[1:]).item()
-                assert is_ascending
-
-                is_contiguous = ((frame_ids[1:] - frame_ids[:-1]) == 1).all().item()
-                assert is_contiguous
-
-                FIRST_FRAME = 0  # noqa: N806
-                video_id = ep_video_ids[FIRST_FRAME].item()
-                video_path = self.video_id_to_path[video_id]
-                first_frame_ts = timestamps[FIRST_FRAME].item()
-                num_contiguous_frames = len(timestamps)
-
-                filter_desc = []
-                video_stream_kwgs = {
-                    "frames_per_chunk": num_contiguous_frames,
-                    "buffer_chunk_size": num_contiguous_frames,
-                }
-
-                # choice of decoder
-                if self.device == "cuda":
-                    video_stream_kwgs["hw_accel"] = "cuda"
-                    video_stream_kwgs["decoder"] = "h264_cuvid"
-                else:
-                    video_stream_kwgs["decoder"] = "h264"
-
-                # resize
-                resize_width = self.width is not None
-                resize_height = self.height is not None
-                if resize_width or resize_height:
-                    if self.device == "cuda":
-                        assert resize_width and resize_height
-                        video_stream_kwgs["decoder_option"] = {"resize": f"{self.width}x{self.height}"}
-                    else:
-                        scales = []
-                        if resize_width:
-                            scales.append(f"width={self.width}")
-                        if resize_height:
-                            scales.append(f"height={self.height}")
-                        filter_desc.append(f"scale={':'.join(scales)}")
-
-                # choice of format
-                if self.format is not None:
-                    if self.device == "cuda":
-                        # TODO(rcadene): rebuild ffmpeg with --enable-cuda-nvcc, --enable-cuvid, and --enable-libnpp
-                        raise NotImplementedError()
-                        # filter_desc = f"scale=format={self.format}"
-                        # filter_desc = f"scale_cuda=format={self.format}"
-                        # filter_desc = f"scale_npp=format={self.format}"
-                    else:
-                        filter_desc.append(f"format=pix_fmts={self.format}")
-
-                # choice of frame rate
-                if self.frame_rate is not None:
-                    filter_desc.append(f"fps={self.frame_rate}")
-
-                if len(filter_desc) > 0:
-                    video_stream_kwgs["filter_desc"] = ",".join(filter_desc)
-
-                # create a stream and load a certain number of frames at a certain frame rate
-                # TODO(rcadene): make sure this is the most efficient way to do it
-                s = StreamReader(video_path)
-                s.seek(first_frame_ts)
-                s.add_video_stream(**video_stream_kwgs)
-                s.fill_buffer()
-                (frames,) = s.pop_chunks()
-
-                b_frames.append(frames)
-
-            td[outkey] = torch.stack(b_frames)
-
-        if self.device == "cuda":
-            # make sure we return a cuda tensor, since the frames can be inadvertently sent to cpu
-            assert "cuda" in str(td[outkey].device), f"{td[outkey].device} instead of cuda"
-        return td
-
-
 class VideoExperienceReplay(TensorDictReplayBuffer):
     def __init__(
         self,
@@ -349,7 +158,7 @@ class VideoExperienceReplay(TensorDictReplayBuffer):
         writer: Writer = None,
         transform: "torchrl.envs.Transform" = None,
     ):
-        self.data_dir = root / "2024_03_17_test_dataset"
+        self.data_dir = root
         self.rb_dir = self.data_dir / "replay_buffer"
 
         storage, meta_data = self._load_or_download()
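With this change, `root` is the dataset directory itself rather than a parent folder with a hard-coded date name. A short usage sketch under that assumption, reusing the path from the test script:

```python
# Sketch: `root` now points directly at the dataset directory, and the replay
# buffer storage is then expected under <root>/replay_buffer.
from pathlib import Path

data_dir = Path("tmp/2024_03_17_data_video/pusht")
rb = VideoExperienceReplay(root=data_dir, batch_size=2)
assert rb.data_dir == data_dir
assert rb.rb_dir == data_dir / "replay_buffer"
```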
@@ -454,7 +263,18 @@ if __name__ == "__main__":
 
     import tqdm
+
+    print("FFmpeg Library versions:")
+    for k, ver in ffmpeg_utils.get_versions().items():
+        print(f" {k}:\t{'.'.join(str(v) for v in ver)}")
+
+    print("Available NVDEC Decoders:")
+    for k in ffmpeg_utils.get_video_decoders().keys():  # noqa: SIM118
+        if "cuvid" in k:
+            print(f" - {k}")
+
     def create_replay_buffer(device):
+        data_dir = Path("tmp/2024_03_17_data_video/pusht")
 
         num_slices = 1
         horizon = 2
         batch_size = num_slices * horizon
@@ -470,6 +290,7 @@ if __name__ == "__main__":
             ViewSliceHorizonTransform(num_slices, horizon),
             KeepFrames(positions=[0], in_keys=[("observation")]),
             DecodeVideoTransform(
+                data_dir=data_dir,
                 device=device,
                 frame_rate=None,
                 in_keys=[("observation", "frame")],
@@ -478,7 +299,7 @@ if __name__ == "__main__":
         ]
 
         replay_buffer = VideoExperienceReplay(
-            root=Path("tmp"),
+            root=data_dir,
             batch_size=batch_size,
             # prefetch=4,
             transform=Compose(*transforms),
@@ -489,52 +310,57 @@
     def test_time():
         replay_buffer = create_replay_buffer(device="cuda")
 
-        start = time.time()
+        start = time.monotonic()
         for _ in tqdm.tqdm(range(2)):
             # include_info=False is required to avoid a batch_size mismatch error with the truncated key (2,8) != (16, 1)
             replay_buffer.sample(include_info=False)
             torch.cuda.synchronize()
-        print(time.time() - start)
+        print(time.monotonic() - start)
 
-        start = time.time()
+        start = time.monotonic()
         for _ in tqdm.tqdm(range(10)):
             replay_buffer.sample(include_info=False)
             torch.cuda.synchronize()
-        print(time.time() - start)
+        print(time.monotonic() - start)
 
-    def test_plot():
+    def test_plot(seed=1337):
         rb_cuda = create_replay_buffer(device="cuda")
-        rb_cpu = create_replay_buffer(device="cpu")
+        rb_cpu = create_replay_buffer(device="cuda")
 
         n_rows = 2  # len(replay_buffer)
-        fig, axes = plt.subplots(n_rows, 2, figsize=[12.8, 16.0])
+        fig, axes = plt.subplots(n_rows, 3, figsize=[12.8, 16.0])
         for i in range(n_rows):
-            set_seed(1337 + i)
+            set_seed(seed + i)
             batch_cpu = rb_cpu.sample(include_info=False)
-            print(batch_cpu["frame_id"])
+            print("frame_ids cpu", batch_cpu["frame_id"].tolist())
+            print("episode cpu", batch_cpu["episode"].tolist())
+            print("timestamps cpu", batch_cpu["observation", "frame", "timestamp"].tolist())
             frames = batch_cpu["observation", "frame", "data"]
 
             frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
-            frames = yuv_to_rgb(frames)
-            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
+
+            frames = yuv_to_rgb_cv2(frames, return_hwc=True)
+            assert frames.shape[0] == 1
             axes[i][0].imshow(frames[0])
 
-            set_seed(1337 + i)
+            set_seed(seed + i)
             batch_cuda = rb_cuda.sample(include_info=False)
-            print(batch_cuda["frame_id"])
+            print("frame_ids cuda", batch_cuda["frame_id"].tolist())
+            print("episode cuda", batch_cuda["episode"].tolist())
+            print("timestamps cuda", batch_cuda["observation", "frame", "timestamp"].tolist())
             frames = batch_cuda["observation", "frame", "data"]
 
             frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
-            frames = yuv_to_rgb(frames)
-            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
+
+            frames = yuv_to_rgb_cv2(frames, return_hwc=True)
+            assert frames.shape[0] == 1
             axes[i][1].imshow(frames[0])
 
+            frames = batch_cuda["observation", "image"].type(torch.uint8)
+            frames = einops.rearrange(frames, "b t c h w -> (b t) c h w")
+            frames = einops.rearrange(frames, "bt c h w -> bt h w c")
+            assert frames.shape[0] == 1
+            axes[i][2].imshow(frames[0])
+
         axes[0][0].set_title("Software decoder")
         axes[0][1].set_title("HW decoder")
+        axes[0][2].set_title("uint8")
         plt.setp(axes, xticks=[], yticks=[])
         plt.tight_layout()
         fig.savefig(rb_cuda.data_dir / "test.png", dpi=300)
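The diff does not show how the script ends; presumably the two tests are invoked at the bottom of the `__main__` block. A hypothetical driver for reproducing the comparison figure:

```python
# Hypothetical driver (the page does not show these lines): times the sampling
# path, then writes the 3-panel comparison (software decoder / HW decoder /
# uint8 reference) that exposes the gray-background issue.
test_time()
test_plot(seed=1337)  # saves <data_dir>/test.png
```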