Feat: Add Batched Video Encoding for Faster Dataset Recording (#1390)

* LeRobotDataset video encoding: updated `save_episode` method and added `batch_encode_videos` method to handle video encoding based on `batch_encoding_size`, allowing for both immediate and batched encoding.

* LeRobotDataset video cleanup: Enabled individual episode cleanup and check for remaining PNG files before removing the `images` directory.

* LeRobotDataset - VideoEncodingManager: added proper handling of pending episodes (encoding, cleaning) on exit or recording failures.

* LeRobotDatasetMetadata: removed `update_video_info` to only update video info at episode index 0 encoding.

* Adjusted the `record` function to utilize the new encoding management logic.

* Removed `encode_videos` method from `LeRobotDataset` and `encode_episode_videos` outputs as they are nowhere used.

---------

Signed-off-by: Xingdong Zuo <zuoxingdong@users.noreply.github.com>
Co-authored-by: Xingdong Zuo <xingdong.zuo@navercorp.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
This commit is contained in:
Xingdong Zuo
2025-07-18 19:18:52 +09:00
committed by GitHub
parent 862a4439ea
commit e6e1f085d4
3 changed files with 172 additions and 57 deletions

View File

@@ -260,8 +260,6 @@ class LeRobotDatasetMetadata:
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
self.info["total_videos"] += len(self.video_keys)
if len(self.video_keys) > 0:
self.update_video_info()
write_info(self.info, self.root)
@@ -342,6 +340,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
force_cache_sync: bool = False,
download_videos: bool = True,
video_backend: str | None = None,
batch_encoding_size: int = 1,
):
"""
2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -443,6 +442,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
True.
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
"""
super().__init__()
self.repo_id = repo_id
@@ -454,6 +455,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.revision = revision if revision else CODEBASE_VERSION
self.video_backend = video_backend if video_backend else get_safe_default_codec()
self.delta_indices = None
self.batch_encoding_size = batch_encoding_size
self.episodes_since_last_encoding = 0
# Unused attributes
self.image_writer = None
@@ -811,6 +814,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
"""
This will save to disk the current episode in self.episode_buffer.
Video encoding is handled automatically based on batch_encoding_size:
- If batch_encoding_size == 1: Videos are encoded immediately after each episode
- If batch_encoding_size > 1: Videos are encoded in batches.
Args:
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
@@ -850,14 +857,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._save_episode_table(episode_buffer, episode_index)
ep_stats = compute_episode_stats(episode_buffer, self.features)
if len(self.meta.video_keys) > 0:
video_paths = self.encode_episode_videos(episode_index)
for key in self.meta.video_keys:
episode_buffer[key] = video_paths[key]
has_video_keys = len(self.meta.video_keys) > 0
use_batched_encoding = self.batch_encoding_size > 1
# `meta.save_episode` be executed after encoding the videos
if has_video_keys and not use_batched_encoding:
self.encode_episode_videos(episode_index)
# `meta.save_episode` should be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
# Check if we should trigger batch encoding
if has_video_keys and use_batched_encoding:
self.episodes_since_last_encoding += 1
if self.episodes_since_last_encoding == self.batch_encoding_size:
start_ep = self.num_episodes - self.batch_encoding_size
end_ep = self.num_episodes
logging.info(
f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}"
)
self.batch_encode_videos(start_ep, end_ep)
self.episodes_since_last_encoding = 0
# Episode data index and timestamp checking
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
@@ -868,16 +889,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.tolerance_s,
)
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
# Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes
parquet_files = list(self.root.rglob("*.parquet"))
assert len(parquet_files) == self.num_episodes
# delete images
img_dir = self.root / "images"
if img_dir.is_dir():
shutil.rmtree(self.root / "images")
video_files = list(self.root.rglob("*.mp4"))
assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len(
self.meta.video_keys
)
if not episode_data: # Reset the buffer
self.episode_buffer = self.create_episode_buffer()
@@ -894,6 +912,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
def clear_episode_buffer(self) -> None:
episode_index = self.episode_buffer["episode_index"]
# Clean up image files for the current episode buffer
if self.image_writer is not None:
for cam_key in self.meta.camera_keys:
img_dir = self._get_image_file_path(
@@ -930,25 +950,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
def encode_videos(self) -> None:
def encode_episode_videos(self, episode_index: int) -> None:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
for ep_idx in range(self.meta.total_episodes):
self.encode_episode_videos(ep_idx)
def encode_episode_videos(self, episode_index: int) -> dict:
This method handles video encoding steps:
- Video encoding via ffmpeg
- Video info updating in metadata
- Raw image cleanup
Args:
episode_index (int): Index of the episode to encode.
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
video_paths = {}
for key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
video_paths[key] = str(video_path)
if video_path.is_file():
# Skip if video is already encoded. Could be the case when resuming data recording.
continue
@@ -956,8 +973,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_index=episode_index, image_key=key, frame_index=0
).parent
encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
shutil.rmtree(img_dir)
return video_paths
# Update video info (only needed when first episode is encoded since it reads from episode 0)
if len(self.meta.video_keys) > 0 and episode_index == 0:
self.meta.update_video_info()
write_info(self.meta.info, self.meta.root) # ensure video info always written properly
def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None:
"""
Batch encode videos for multiple episodes.
Args:
start_episode: Starting episode index (inclusive)
end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode
"""
if end_episode is None:
end_episode = self.meta.total_episodes
logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}")
# Encode all episodes with cleanup enabled for individual episodes
for ep_idx in range(start_episode, end_episode):
logging.info(f"Encoding videos for episode {ep_idx}")
self.encode_episode_videos(ep_idx)
logging.info("Batch video encoding completed")
@classmethod
def create(
@@ -972,6 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_writer_processes: int = 0,
image_writer_threads: int = 0,
video_backend: str | None = None,
batch_encoding_size: int = 1,
) -> "LeRobotDataset":
"""Create a LeRobot Dataset from scratch in order to record data."""
obj = cls.__new__(cls)
@@ -988,6 +1030,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
obj.revision = None
obj.tolerance_s = tolerance_s
obj.image_writer = None
obj.batch_encoding_size = batch_encoding_size
obj.episodes_since_last_encoding = 0
if image_writer_processes or image_writer_threads:
obj.start_image_writer(image_writer_processes, image_writer_threads)

View File

@@ -16,6 +16,7 @@
import glob
import importlib
import logging
import shutil
import warnings
from dataclasses import dataclass, field
from pathlib import Path
@@ -451,3 +452,66 @@ def get_image_pixel_channels(image: Image):
return 4 # RGBA
else:
raise ValueError("Unknown format")
class VideoEncodingManager:
"""
Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
This manager handles:
- Batch encoding for any remaining episodes when recording interrupted
- Cleaning up temporary image files from interrupted episodes
- Removing empty image directories
Args:
dataset: The LeRobotDataset instance
"""
def __init__(self, dataset):
self.dataset = dataset
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Handle any remaining episodes that haven't been batch encoded
if self.dataset.episodes_since_last_encoding > 0:
if exc_type is not None:
logging.info("Exception occurred. Encoding remaining episodes before exit...")
else:
logging.info("Recording stopped. Encoding remaining episodes...")
start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
end_ep = self.dataset.num_episodes
logging.info(
f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
f"from episode {start_ep} to {end_ep - 1}"
)
self.dataset.batch_encode_videos(start_ep, end_ep)
# Clean up episode images if recording was interrupted
if exc_type is not None:
interrupted_episode_index = self.dataset.num_episodes
for key in self.dataset.meta.video_keys:
img_dir = self.dataset._get_image_file_path(
episode_index=interrupted_episode_index, image_key=key, frame_index=0
).parent
if img_dir.exists():
logging.debug(
f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
)
shutil.rmtree(img_dir)
# Clean up any remaining images directory if it's empty
img_dir = self.dataset.root / "images"
# Check for any remaining PNG files
png_files = list(img_dir.rglob("*.png"))
if len(png_files) == 0:
# Only remove the images directory if no PNG files remain
if img_dir.exists():
shutil.rmtree(img_dir)
logging.debug("Cleaned up empty images directory")
else:
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
return False # Don't suppress the original exception

View File

@@ -73,6 +73,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.robots import ( # noqa: F401
@@ -145,6 +146,9 @@ class DatasetRecordConfig:
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
def __post_init__(self):
if self.single_task is None:
@@ -298,6 +302,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
@@ -318,6 +323,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
)
# Load pretrained policy
@@ -329,46 +335,47 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
listener, events = init_keyboard_listener()
recorded_episodes = 0
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
with VideoEncodingManager(dataset):
recorded_episodes = 0
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
control_time_s=cfg.dataset.reset_time_s,
policy=policy,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
fps=cfg.dataset.fps,
teleop=teleop,
control_time_s=cfg.dataset.reset_time_s,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
dataset.save_episode()
recorded_episodes += 1
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded_episodes += 1
log_say("Stop recording", cfg.play_sounds, blocking=True)