From e6e1f085d4f66f34e9e6bd7c7f00893e9e425f9e Mon Sep 17 00:00:00 2001 From: Xingdong Zuo Date: Fri, 18 Jul 2025 19:18:52 +0900 Subject: [PATCH] Feat: Add Batched Video Encoding for Faster Dataset Recording (#1390) * LeRobotDataset video encoding: updated `save_episode` method and added `batch_encode_videos` method to handle video encoding based on `batch_encoding_size`, allowing for both immediate and batched encoding. * LeRobotDataset video cleanup: Enabled individual episode cleanup and added a check for remaining PNG files before removing the `images` directory. * LeRobotDataset - VideoEncodingManager: added proper handling of pending episodes (encoding, cleaning) on exit or recording failures. * LeRobotDatasetMetadata: removed `update_video_info` to only update video info at episode index 0 encoding. * Adjusted the `record` function to utilize the new encoding management logic. * Removed `encode_videos` method from `LeRobotDataset` and `encode_episode_videos` outputs as they are not used anywhere. 
--------- Signed-off-by: Xingdong Zuo Co-authored-by: Xingdong Zuo Co-authored-by: Caroline Pascal --- src/lerobot/datasets/lerobot_dataset.py | 98 ++++++++++++++++++------- src/lerobot/datasets/video_utils.py | 64 ++++++++++++++++ src/lerobot/record.py | 67 +++++++++-------- 3 files changed, 172 insertions(+), 57 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 46feed2b..72d1a722 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -260,8 +260,6 @@ class LeRobotDatasetMetadata: self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"} self.info["total_videos"] += len(self.video_keys) - if len(self.video_keys) > 0: - self.update_video_info() write_info(self.info, self.root) @@ -342,6 +340,7 @@ class LeRobotDataset(torch.utils.data.Dataset): force_cache_sync: bool = False, download_videos: bool = True, video_backend: str | None = None, + batch_encoding_size: int = 1, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -443,6 +442,8 @@ class LeRobotDataset(torch.utils.data.Dataset): True. video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'. You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. + batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. + Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. 
""" super().__init__() self.repo_id = repo_id @@ -454,6 +455,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.revision = revision if revision else CODEBASE_VERSION self.video_backend = video_backend if video_backend else get_safe_default_codec() self.delta_indices = None + self.batch_encoding_size = batch_encoding_size + self.episodes_since_last_encoding = 0 # Unused attributes self.image_writer = None @@ -811,6 +814,10 @@ class LeRobotDataset(torch.utils.data.Dataset): """ This will save to disk the current episode in self.episode_buffer. + Video encoding is handled automatically based on batch_encoding_size: + - If batch_encoding_size == 1: Videos are encoded immediately after each episode + - If batch_encoding_size > 1: Videos are encoded in batches. + Args: episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to @@ -850,14 +857,28 @@ class LeRobotDataset(torch.utils.data.Dataset): self._save_episode_table(episode_buffer, episode_index) ep_stats = compute_episode_stats(episode_buffer, self.features) - if len(self.meta.video_keys) > 0: - video_paths = self.encode_episode_videos(episode_index) - for key in self.meta.video_keys: - episode_buffer[key] = video_paths[key] + has_video_keys = len(self.meta.video_keys) > 0 + use_batched_encoding = self.batch_encoding_size > 1 - # `meta.save_episode` be executed after encoding the videos + if has_video_keys and not use_batched_encoding: + self.encode_episode_videos(episode_index) + + # `meta.save_episode` should be executed after encoding the videos self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats) + # Check if we should trigger batch encoding + if has_video_keys and use_batched_encoding: + self.episodes_since_last_encoding += 1 + if self.episodes_since_last_encoding == self.batch_encoding_size: + start_ep = self.num_episodes - 
self.batch_encoding_size + end_ep = self.num_episodes + logging.info( + f"Batch encoding {self.batch_encoding_size} videos for episodes {start_ep} to {end_ep - 1}" + ) + self.batch_encode_videos(start_ep, end_ep) + self.episodes_since_last_encoding = 0 + + # Episode data index and timestamp checking ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index]) ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()} check_timestamps_sync( @@ -868,16 +889,13 @@ class LeRobotDataset(torch.utils.data.Dataset): self.tolerance_s, ) - video_files = list(self.root.rglob("*.mp4")) - assert len(video_files) == self.num_episodes * len(self.meta.video_keys) - + # Verify that we have one parquet file per episode and the number of video files matches the number of encoded episodes parquet_files = list(self.root.rglob("*.parquet")) assert len(parquet_files) == self.num_episodes - - # delete images - img_dir = self.root / "images" - if img_dir.is_dir(): - shutil.rmtree(self.root / "images") + video_files = list(self.root.rglob("*.mp4")) + assert len(video_files) == (self.num_episodes - self.episodes_since_last_encoding) * len( + self.meta.video_keys + ) if not episode_data: # Reset the buffer self.episode_buffer = self.create_episode_buffer() @@ -894,6 +912,8 @@ class LeRobotDataset(torch.utils.data.Dataset): def clear_episode_buffer(self) -> None: episode_index = self.episode_buffer["episode_index"] + + # Clean up image files for the current episode buffer if self.image_writer is not None: for cam_key in self.meta.camera_keys: img_dir = self._get_image_file_path( @@ -930,25 +950,22 @@ class LeRobotDataset(torch.utils.data.Dataset): if self.image_writer is not None: self.image_writer.wait_until_done() - def encode_videos(self) -> None: + def encode_episode_videos(self, episode_index: int) -> None: """ Use ffmpeg to convert frames stored as png into mp4 videos. Note: `encode_video_frames` is a blocking call. 
Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. - """ - for ep_idx in range(self.meta.total_episodes): - self.encode_episode_videos(ep_idx) - def encode_episode_videos(self, episode_index: int) -> dict: + This method handles video encoding steps: + - Video encoding via ffmpeg + - Video info updating in metadata + - Raw image cleanup + + Args: + episode_index (int): Index of the episode to encode. """ - Use ffmpeg to convert frames stored as png into mp4 videos. - Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, - since video encoding with ffmpeg is already using multithreading. - """ - video_paths = {} for key in self.meta.video_keys: video_path = self.root / self.meta.get_video_file_path(episode_index, key) - video_paths[key] = str(video_path) if video_path.is_file(): # Skip if video is already encoded. Could be the case when resuming data recording. continue @@ -956,8 +973,32 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index=episode_index, image_key=key, frame_index=0 ).parent encode_video_frames(img_dir, video_path, self.fps, overwrite=True) + shutil.rmtree(img_dir) - return video_paths + # Update video info (only needed when first episode is encoded since it reads from episode 0) + if len(self.meta.video_keys) > 0 and episode_index == 0: + self.meta.update_video_info() + write_info(self.meta.info, self.meta.root) # ensure video info always written properly + + def batch_encode_videos(self, start_episode: int = 0, end_episode: int | None = None) -> None: + """ + Batch encode videos for multiple episodes. + + Args: + start_episode: Starting episode index (inclusive) + end_episode: Ending episode index (exclusive). 
If None, encodes all episodes from start_episode + """ + if end_episode is None: + end_episode = self.meta.total_episodes + + logging.info(f"Starting batch video encoding for episodes {start_episode} to {end_episode - 1}") + + # Encode all episodes with cleanup enabled for individual episodes + for ep_idx in range(start_episode, end_episode): + logging.info(f"Encoding videos for episode {ep_idx}") + self.encode_episode_videos(ep_idx) + + logging.info("Batch video encoding completed") @classmethod def create( @@ -972,6 +1013,7 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_processes: int = 0, image_writer_threads: int = 0, video_backend: str | None = None, + batch_encoding_size: int = 1, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" obj = cls.__new__(cls) @@ -988,6 +1030,8 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.revision = None obj.tolerance_s = tolerance_s obj.image_writer = None + obj.batch_encoding_size = batch_encoding_size + obj.episodes_since_last_encoding = 0 if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 3a77f36e..b05edf6b 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -16,6 +16,7 @@ import glob import importlib import logging +import shutil import warnings from dataclasses import dataclass, field from pathlib import Path @@ -451,3 +452,66 @@ def get_image_pixel_channels(image: Image): return 4 # RGBA else: raise ValueError("Unknown format") + + +class VideoEncodingManager: + """ + Context manager that ensures proper video encoding and data cleanup even if exceptions occur. 
+ + This manager handles: + - Batch encoding for any remaining episodes when recording is interrupted + - Cleaning up temporary image files from interrupted episodes + - Removing empty image directories + + Args: + dataset: The LeRobotDataset instance + """ + + def __init__(self, dataset): + self.dataset = dataset + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Handle any remaining episodes that haven't been batch encoded + if self.dataset.episodes_since_last_encoding > 0: + if exc_type is not None: + logging.info("Exception occurred. Encoding remaining episodes before exit...") + else: + logging.info("Recording stopped. Encoding remaining episodes...") + + start_ep = self.dataset.num_episodes - self.dataset.episodes_since_last_encoding + end_ep = self.dataset.num_episodes + logging.info( + f"Encoding remaining {self.dataset.episodes_since_last_encoding} episodes, " + f"from episode {start_ep} to {end_ep - 1}" + ) + self.dataset.batch_encode_videos(start_ep, end_ep) + + # Clean up episode images if recording was interrupted + if exc_type is not None: + interrupted_episode_index = self.dataset.num_episodes + for key in self.dataset.meta.video_keys: + img_dir = self.dataset._get_image_file_path( + episode_index=interrupted_episode_index, image_key=key, frame_index=0 + ).parent + if img_dir.exists(): + logging.debug( + f"Cleaning up interrupted episode images for episode {interrupted_episode_index}, camera {key}" + ) + shutil.rmtree(img_dir) + + # Clean up any remaining images directory if it's empty + img_dir = self.dataset.root / "images" + # Check for any remaining PNG files + png_files = list(img_dir.rglob("*.png")) + if len(png_files) == 0: + # Only remove the images directory if no PNG files remain + if img_dir.exists(): + shutil.rmtree(img_dir) + logging.debug("Cleaned up empty images directory") + else: + logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + + return False # Don't 
suppress the original exception diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 0b1af192..d662efca 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -73,6 +73,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.robots import ( # noqa: F401 @@ -145,6 +146,9 @@ class DatasetRecordConfig: # Too many threads might cause unstable teleoperation fps due to main thread being blocked. # Not enough threads might cause low camera fps. num_image_writer_threads_per_camera: int = 4 + # Number of episodes to record before batch encoding videos + # Set to 1 for immediate encoding (default behavior), or higher for batched encoding + video_encoding_batch_size: int = 1 def __post_init__(self): if self.single_task is None: @@ -298,6 +302,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: dataset = LeRobotDataset( cfg.dataset.repo_id, root=cfg.dataset.root, + batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) if hasattr(robot, "cameras") and len(robot.cameras) > 0: @@ -318,6 +323,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: use_videos=cfg.dataset.video, image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + batch_encoding_size=cfg.dataset.video_encoding_batch_size, ) # Load pretrained policy @@ -329,46 +335,47 @@ def record(cfg: RecordConfig) -> LeRobotDataset: listener, events = init_keyboard_listener() - recorded_episodes = 0 - while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: - log_say(f"Recording episode 
{dataset.num_episodes}", cfg.play_sounds) - record_loop( - robot=robot, - events=events, - fps=cfg.dataset.fps, - teleop=teleop, - policy=policy, - dataset=dataset, - control_time_s=cfg.dataset.episode_time_s, - single_task=cfg.dataset.single_task, - display_data=cfg.display_data, - ) - - # Execute a few seconds without recording to give time to manually reset the environment - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) + with VideoEncodingManager(dataset): + recorded_episodes = 0 + while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) record_loop( robot=robot, events=events, fps=cfg.dataset.fps, teleop=teleop, - control_time_s=cfg.dataset.reset_time_s, + policy=policy, + dataset=dataset, + control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, ) - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Execute a few seconds without recording to give time to manually reset the environment + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + fps=cfg.dataset.fps, + teleop=teleop, + control_time_s=cfg.dataset.reset_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + 
events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 log_say("Stop recording", cfg.play_sounds, blocking=True)