forked from tangger/lerobot
Add video decoding in dataset (WIP: issue with gray background)
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
import logging
|
||||
import shutil
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
@@ -26,9 +29,48 @@ def visualize_dataset_cli(cfg: dict):
|
||||
|
||||
def cat_and_write_video(video_path, frames, fps):
|
||||
frames = torch.cat(frames)
|
||||
assert frames.dtype == torch.uint8
|
||||
if frames.dtype != torch.uint8:
|
||||
logging.warning(f"frames are expected to be uint8 to {frames.dtype}")
|
||||
frames = frames.type(torch.uint8)
|
||||
|
||||
_, _, h, w = frames.shape
|
||||
frames = einops.rearrange(frames, "b c h w -> b h w c").numpy()
|
||||
imageio.mimsave(video_path, frames, fps=fps)
|
||||
|
||||
img_dir = Path(video_path.split(".")[0])
|
||||
if img_dir.exists():
|
||||
shutil.rmtree(img_dir)
|
||||
img_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i in range(len(frames)):
|
||||
imageio.imwrite(str(img_dir / f"frame_{i:04d}.png"), frames[i])
|
||||
|
||||
ffmpeg_command = [
|
||||
"ffmpeg",
|
||||
"-r",
|
||||
str(fps),
|
||||
"-f",
|
||||
"image2",
|
||||
"-s",
|
||||
f"{w}x{h}",
|
||||
"-i",
|
||||
str(img_dir / "frame_%04d.png"),
|
||||
"-vcodec",
|
||||
"libx264",
|
||||
#'-vcodec', 'libx265',
|
||||
#'-vcodec', 'libaom-av1',
|
||||
"-crf",
|
||||
"0", # Lossless option
|
||||
"-pix_fmt",
|
||||
"yuv420p", # Specify pixel format
|
||||
video_path,
|
||||
# video_path.replace(".mp4", ".mkv")
|
||||
]
|
||||
subprocess.run(ffmpeg_command, check=True)
|
||||
|
||||
time.sleep(0.1)
|
||||
|
||||
# clean temporary image directory
|
||||
# shutil.rmtree(img_dir)
|
||||
|
||||
|
||||
def visualize_dataset(cfg: dict, out_dir=None):
|
||||
@@ -61,7 +103,10 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
|
||||
# TODO(rcaene): modify offline_buffer._sampler._sample_list or sampler to randomly sample an episode, but sequentially sample frames
|
||||
no_more_frames = offline_buffer._sampler._sample_list.numel() == 0
|
||||
new_episode = ep_idx != current_ep_idx
|
||||
new_episode = ep_idx > current_ep_idx
|
||||
|
||||
if ep_idx < current_ep_idx:
|
||||
break
|
||||
|
||||
if new_episode:
|
||||
logging.info(f"Visualizing episode {current_ep_idx}")
|
||||
@@ -71,7 +116,7 @@ def visualize_dataset(cfg: dict, out_dir=None):
|
||||
# append last observed frames (the ones after last action taken)
|
||||
frames[im_key].append(ep_td[("next", *im_key)])
|
||||
|
||||
video_dir = Path(out_dir) / "visualize_dataset"
|
||||
video_dir = Path(out_dir) / "videos"
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if len(offline_buffer.image_keys) > 1:
|
||||
|
||||
Reference in New Issue
Block a user