Feat: Add Batched Video Encoding for Faster Dataset Recording (#1390)

* LeRobotDataset video encoding: updated the `save_episode` method and added a `batch_encode_videos` method that handles video encoding based on `batch_encoding_size`, allowing both immediate and batched encoding (see the usage sketch after this list).

* LeRobotDataset video cleanup: enabled per-episode image cleanup and added a check for remaining PNG files before removing the `images` directory.

* LeRobotDataset - VideoEncodingManager: added a context manager that properly handles pending episodes (encoding and image cleanup) on normal exit or recording failure.

* LeRobotDatasetMetadata: removed the `update_video_info` call from `save_episode`; video info is now updated only when episode index 0 is encoded.

* Adjusted the `record` function to use the new encoding management logic: the recording loop now runs inside a `VideoEncodingManager` context and forwards `video_encoding_batch_size` to the dataset.

* Removed the `encode_videos` method from `LeRobotDataset` and the return value of `encode_episode_videos`, as neither was used anywhere.
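
For illustration, a rough usage sketch of the new workflow (not part of the diff below; the repo id, feature spec, and episode counts are placeholders, and `LeRobotDataset.create` takes more arguments than shown):

from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import VideoEncodingManager

# Placeholder feature spec; in practice it comes from hw_to_dataset_features() in the record script.
features = {
    "observation.images.cam": {"dtype": "video", "shape": (480, 640, 3), "names": ["height", "width", "channels"]},
}

dataset = LeRobotDataset.create(
    repo_id="user/my_dataset",  # hypothetical repo id
    fps=30,
    features=features,
    batch_encoding_size=5,  # keep PNG frames for 5 episodes, then encode them to mp4 in one batch
)

with VideoEncodingManager(dataset):  # encodes pending episodes and cleans up images on exit or error
    for _ in range(12):
        # ... dataset.add_frame(...) for every frame of the episode ...
        dataset.save_episode()  # every 5th call triggers batch_encode_videos()
# On exiting the context, the 2 episodes recorded since the last batch (12 % 5) are encoded as well.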

---------

Signed-off-by: Xingdong Zuo <zuoxingdong@users.noreply.github.com>
Co-authored-by: Xingdong Zuo <xingdong.zuo@navercorp.com>
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>
Author: Xingdong Zuo
Date: 2025-07-18 19:18:52 +09:00
Committed by: GitHub
Parent: 862a4439ea
Commit: e6e1f085d4
3 changed files with 172 additions and 57 deletions


@@ -260,8 +260,6 @@ class LeRobotDatasetMetadata:
         self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
         self.info["total_videos"] += len(self.video_keys)
-        if len(self.video_keys) > 0:
-            self.update_video_info()

         write_info(self.info, self.root)
@@ -342,6 +340,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         force_cache_sync: bool = False,
         download_videos: bool = True,
         video_backend: str | None = None,
+        batch_encoding_size: int = 1,
     ):
         """
         2 modes are available for instantiating this class, depending on 2 different use cases:
@@ -443,6 +442,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 True.
             video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
                 You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
+            batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
+                Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
         """
         super().__init__()
         self.repo_id = repo_id
@@ -454,6 +455,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.revision = revision if revision else CODEBASE_VERSION
         self.video_backend = video_backend if video_backend else get_safe_default_codec()
         self.delta_indices = None
+        self.batch_encoding_size = batch_encoding_size
+        self.episodes_since_last_encoding = 0

         # Unused attributes
         self.image_writer = None
@@ -811,6 +814,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
         """
         This will save to disk the current episode in self.episode_buffer.

+        Video encoding is handled automatically based on batch_encoding_size:
+        - If batch_encoding_size == 1: Videos are encoded immediately after each episode
+        - If batch_encoding_size > 1: Videos are encoded in batches.
+
         Args:
             episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
                 save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
@@ -850,14 +857,28 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self._save_episode_table(episode_buffer, episode_index)
         ep_stats = compute_episode_stats(episode_buffer, self.features)

-        if len(self.meta.video_keys) > 0:
-            video_paths = self.encode_episode_videos(episode_index)
-            for key in self.meta.video_keys:
-                episode_buffer[key] = video_paths[key]
+        has_video_keys = len(self.meta.video_keys) > 0
+        use_batched_encoding = self.batch_encoding_size > 1

-        # `meta.save_episode` be executed after encoding the videos
+        if has_video_keys and not use_batched_encoding:
+            self.encode_episode_videos(episode_index)
+
+        # `meta.save_episode` should be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

+        # Check if we should trigger batch encoding
+        if has_video_keys and use_batched_encoding:
+            self.episodes_since_last_encoding += 1
+            if self.episodes_since_last_encoding == self.batch_encoding_size:
+                start_ep = self.num_episodes - self.batch_encoding_size
+                end_ep = self.num_episodes
+                logging.info(
+                    f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}"
+                )
+                self.batch_encode_videos(start_ep, end_ep)
+                self.episodes_since_last_encoding = 0
+
+        # Episode data index and timestamp checking
         ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
         ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
         check_timestamps_sync(
@@ -868,16 +889,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
             self.tolerance_s,
         )

-        video_files = list(self.root.rglob("*.mp4"))
-        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)
-
+        # Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes
         parquet_files = list(self.root.rglob("*.parquet"))
         assert len(parquet_files) == self.num_episodes
-
-        # delete images
-        img_dir = self.root / "images"
-        if img_dir.is_dir():
-            shutil.rmtree(self.root / "images")
+        video_files = list(self.root.rglob("*.mp4"))
+        assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len(
+            self.meta.video_keys
+        )

         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
@@ -894,6 +912,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
     def clear_episode_buffer(self) -> None:
         episode_index = self.episode_buffer["episode_index"]
+
+        # Clean up image files for the current episode buffer
         if self.image_writer is not None:
             for cam_key in self.meta.camera_keys:
                 img_dir = self._get_image_file_path(
@@ -930,25 +950,22 @@ class LeRobotDataset(torch.utils.data.Dataset):
         if self.image_writer is not None:
             self.image_writer.wait_until_done()

-    def encode_videos(self) -> None:
+    def encode_episode_videos(self, episode_index: int) -> None:
         """
         Use ffmpeg to convert frames stored as png into mp4 videos.
         Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
         since video encoding with ffmpeg is already using multithreading.
-        """
-        for ep_idx in range(self.meta.total_episodes):
-            self.encode_episode_videos(ep_idx)

-    def encode_episode_videos(self, episode_index: int) -> dict:
-        """
-        Use ffmpeg to convert frames stored as png into mp4 videos.
-        Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
-        since video encoding with ffmpeg is already using multithreading.
+
+        This method handles video encoding steps:
+        - Video encoding via ffmpeg
+        - Video info updating in metadata
+        - Raw image cleanup
+
+        Args:
+            episode_index (int): Index of the episode to encode.
         """
-        video_paths = {}
         for key in self.meta.video_keys:
             video_path = self.root / self.meta.get_video_file_path(episode_index, key)
-            video_paths[key] = str(video_path)
             if video_path.is_file():
                 # Skip if video is already encoded. Could be the case when resuming data recording.
                 continue
@@ -956,8 +973,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 episode_index=episode_index, image_key=key, frame_index=0
             ).parent
             encode_video_frames(img_dir, video_path, self.fps, overwrite=True)
+            shutil.rmtree(img_dir)

-        return video_paths
+        # Update video info (only needed when first episode is encoded since it reads from episode 0)
+        if len(self.meta.video_keys) > 0 and episode_index == 0:
+            self.meta.update_video_info()
+            write_info(self.meta.info, self.meta.root)  # ensure video info always written properly
+
+    def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None:
+        """
+        Batch encode videos for multiple episodes.
+
+        Args:
+            start_episode: Starting episode index (inclusive)
+            end_episode: Ending episode index (exclusive). If None, encodes all episodes from start_episode
+        """
+        if end_episode is None:
+            end_episode = self.meta.total_episodes
+
+        logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}")
+
+        # Encode all episodes with cleanup enabled for individual episodes
+        for ep_idx in range(start_episode, end_episode):
+            logging.info(f"Encoding videos for episode {ep_idx}")
+            self.encode_episode_videos(ep_idx)
+
+        logging.info("Batch video encoding completed")

     @classmethod
     def create(
@@ -972,6 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         image_writer_processes: int = 0,
         image_writer_threads: int = 0,
         video_backend: str | None = None,
+        batch_encoding_size: int = 1,
     ) -> "LeRobotDataset":
         """Create a LeRobot Dataset from scratch in order to record data."""
         obj = cls.__new__(cls)
@@ -988,6 +1030,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         obj.revision = None
         obj.tolerance_s = tolerance_s
         obj.image_writer = None
+        obj.batch_encoding_size = batch_encoding_size
+        obj.episodes_since_last_encoding = 0

         if image_writer_processes or image_writer_threads:
             obj.start_image_writer(image_writer_processes, image_writer_threads)


@@ -16,6 +16,7 @@
 import glob
 import importlib
 import logging
+import shutil
 import warnings
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -451,3 +452,66 @@ def get_image_pixel_channels(image: Image):
         return 4  # RGBA
     else:
         raise ValueError("Unknown format")
+
+
+class VideoEncodingManager:
+    """
+    Context manager that ensures proper video encoding and data cleanup even if exceptions occur.
+
+    This manager handles:
+    - Batch encoding for any remaining episodes when recording interrupted
+    - Cleaning up temporary image files from interrupted episodes
+    - Removing empty image directories
+
+    Args:
+        dataset: The LeRobotDataset instance
+    """
+
+    def __init__(self, dataset):
+        self.dataset = dataset
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # Handle any remaining episodes that haven't been batch encoded
+        if self.dataset.episodes_since_last_encoding > 0:
+            if exc_type is not None:
+                logging.info("Exception occurred. Encoding remaining episodes before exit...")
+            else:
+                logging.info("Recording stopped. Encoding remaining episodes...")
+            start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding
+            end_ep = self.dataset.num_episodes
+            logging.info(
+                f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, "
+                f"from episode {start_ep} to {end_ep - 1}"
+            )
+            self.dataset.batch_encode_videos(start_ep, end_ep)
+
+        # Clean up episode images if recording was interrupted
+        if exc_type is not None:
+            interrupted_episode_index = self.dataset.num_episodes
+            for key in self.dataset.meta.video_keys:
+                img_dir = self.dataset._get_image_file_path(
+                    episode_index=interrupted_episode_index, image_key=key, frame_index=0
+                ).parent
+                if img_dir.exists():
+                    logging.debug(
+                        f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}"
+                    )
+                    shutil.rmtree(img_dir)
+
+        # Clean up any remaining images directory if it's empty
+        img_dir = self.dataset.root / "images"
+        # Check for any remaining PNG files
+        png_files = list(img_dir.rglob("*.png"))
+        if len(png_files) == 0:
+            # Only remove the images directory if no PNG files remain
+            if img_dir.exists():
+                shutil.rmtree(img_dir)
+                logging.debug("Cleaned up empty images directory")
+        else:
+            logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
+
+        return False  # Don't suppress the original exception
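
To make the exit-path semantics concrete, a small hypothetical sketch (the `dataset` object and the interrupt are illustrative, not from this diff): pending episodes are encoded in `__exit__`, leftover PNG frames of the interrupted episode are removed, and the exception still propagates because `__exit__` returns False.

from lerobot.datasets.video_utils import VideoEncodingManager

# dataset: an already-created LeRobotDataset with batch_encoding_size > 1 (assumed to exist)
try:
    with VideoEncodingManager(dataset):
        dataset.save_episode()     # 1 episode now pending encoding (batch not yet full)
        raise KeyboardInterrupt    # simulate the user aborting the recording
except KeyboardInterrupt:
    # By the time we get here, __exit__ has already batch-encoded the pending episode,
    # deleted the interrupted episode's PNG frames, and returned False, so the
    # exception was not suppressed.
    pass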


@@ -73,6 +73,7 @@ from lerobot.configs.policies import PreTrainedConfig
 from lerobot.datasets.image_writer import safe_stop_image_writer
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
+from lerobot.datasets.video_utils import VideoEncodingManager
 from lerobot.policies.factory import make_policy
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.robots import (  # noqa: F401
@@ -145,6 +146,9 @@ class DatasetRecordConfig:
     # Too many threads might cause unstable teleoperation fps due to main thread being blocked.
     # Not enough threads might cause low camera fps.
     num_image_writer_threads_per_camera: int = 4
+    # Number of episodes to record before batch encoding videos
+    # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
+    video_encoding_batch_size: int = 1

     def __post_init__(self):
         if self.single_task is None:
@@ -298,6 +302,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
         dataset = LeRobotDataset(
             cfg.dataset.repo_id,
             root=cfg.dataset.root,
+            batch_encoding_size=cfg.dataset.video_encoding_batch_size,
         )

         if hasattr(robot, "cameras") and len(robot.cameras) > 0:
@@ -318,6 +323,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
             use_videos=cfg.dataset.video,
             image_writer_processes=cfg.dataset.num_image_writer_processes,
             image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
+            batch_encoding_size=cfg.dataset.video_encoding_batch_size,
         )

     # Load pretrained policy
@@ -329,46 +335,47 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
     listener, events = init_keyboard_listener()

-    recorded_episodes = 0
-    while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
-        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
-        record_loop(
-            robot=robot,
-            events=events,
-            fps=cfg.dataset.fps,
-            teleop=teleop,
-            policy=policy,
-            dataset=dataset,
-            control_time_s=cfg.dataset.episode_time_s,
-            single_task=cfg.dataset.single_task,
-            display_data=cfg.display_data,
-        )
-
-        # Execute a few seconds without recording to give time to manually reset the environment
-        # Skip reset for the last episode to be recorded
-        if not events["stop_recording"] and (
-            (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
-        ):
-            log_say("Reset the environment", cfg.play_sounds)
-            record_loop(
-                robot=robot,
-                events=events,
-                fps=cfg.dataset.fps,
-                teleop=teleop,
-                control_time_s=cfg.dataset.reset_time_s,
-                single_task=cfg.dataset.single_task,
-                display_data=cfg.display_data,
-            )
-
-        if events["rerecord_episode"]:
-            log_say("Re-record episode", cfg.play_sounds)
-            events["rerecord_episode"] = False
-            events["exit_early"] = False
-            dataset.clear_episode_buffer()
-            continue
-
-        dataset.save_episode()
-        recorded_episodes += 1
+    with VideoEncodingManager(dataset):
+        recorded_episodes = 0
+        while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
+            log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
+            record_loop(
+                robot=robot,
+                events=events,
+                fps=cfg.dataset.fps,
+                teleop=teleop,
+                policy=policy,
+                dataset=dataset,
+                control_time_s=cfg.dataset.episode_time_s,
+                single_task=cfg.dataset.single_task,
+                display_data=cfg.display_data,
+            )
+
+            # Execute a few seconds without recording to give time to manually reset the environment
+            # Skip reset for the last episode to be recorded
+            if not events["stop_recording"] and (
+                (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
+            ):
+                log_say("Reset the environment", cfg.play_sounds)
+                record_loop(
+                    robot=robot,
+                    events=events,
+                    fps=cfg.dataset.fps,
+                    teleop=teleop,
+                    control_time_s=cfg.dataset.reset_time_s,
+                    single_task=cfg.dataset.single_task,
+                    display_data=cfg.display_data,
+                )
+
+            if events["rerecord_episode"]:
+                log_say("Re-record episode", cfg.play_sounds)
+                events["rerecord_episode"] = False
+                events["exit_early"] = False
+                dataset.clear_episode_buffer()
+                continue
+
+            dataset.save_episode()
+            recorded_episodes += 1

     log_say("Stop recording", cfg.play_sounds, blocking=True)