From b8f7e401d42a17d1ac90355f39a1ee7171afb58f Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 10 Oct 2025 12:32:07 +0200 Subject: [PATCH] Dataset tools (#2100) * feat(dataset-tools): add dataset utilities and example script - Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets. - Added an example script demonstrating the usage of these utilities. - Implemented comprehensive tests for all new functionalities to ensure reliability and correctness. * style fixes * move example to dataset dir * missing lisence * fixes mostly path * clean comments * move tests to functions instead of class based * - fix video editting, decode, delete frames and rencode video - copy unchanged video and parquet files to avoid recreating the entire dataset * Fortify tooling tests * Fix type issue resulting from saving numpy arrays with shape 3,1,1 * added lerobot_edit_dataset * - revert changes in examples - remove hardcoded split names * update comment * fix comment add lerobot-edit-dataset shortcut * Apply suggestion from @Copilot Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Michel Aractingi * style nit after copilot review * fix: bug in dataset root when editing the dataset in place (without setting new_repo_id * Fix bug in aggregate.py when accumelating video timestamps; add tests to fortify aggregate videos * Added missing output repo id * migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again. Co-authored-by: Caroline Pascal * added modified suffix in case repo_id is not set in delete_episode * adding docs for dataset tools * bump av version and add back time_base assignment * linter * modified push_to_hub logic in lerobot_edit_dataset * fix(progress bar): fixing the progress bar issue in dataset tools * chore(concatenate): removing no longer needed concatenate_datasets usage * fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets * style fix * refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging There were three critical bugs in aggregate.py that prevented correct dataset merging: 1. Video file indices: Changed from += to = assignment to correctly reference merged video files 2. Video timestamps: Implemented per-source-file offset tracking to maintain continuous timestamps when merging split datasets (was causing non-monotonic timestamp warnings) 3. 
File rotation offsets: Store timestamp offsets after rotation decision to prevent out-of-bounds frame access (was causing "Invalid frame index" errors with small file size limits) Changes: - Updated update_meta_data() to apply per-source-file timestamp offsets - Updated aggregate_videos() to track offsets correctly during file rotation - Added get_video_duration_in_s import for duration calculation * Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes * chore(docs): update merge documentation details Signed-off-by: Steven Palma --------- Co-authored-by: CarolinePascal Co-authored-by: Jack Vial Co-authored-by: Steven Palma --- docs/source/_toctree.yml | 2 + docs/source/using_dataset_tools.mdx | 102 ++ examples/dataset/use_dataset_tools.py | 117 +++ pyproject.toml | 3 +- src/lerobot/datasets/aggregate.py | 82 +- src/lerobot/datasets/dataset_tools.py | 1004 +++++++++++++++++++ src/lerobot/datasets/lerobot_dataset.py | 14 +- src/lerobot/datasets/utils.py | 9 +- src/lerobot/datasets/video_utils.py | 3 + src/lerobot/scripts/lerobot_edit_dataset.py | 286 ++++++ src/lerobot/utils/utils.py | 20 + tests/datasets/test_aggregate.py | 90 ++ tests/datasets/test_dataset_tools.py | 891 ++++++++++++++++ 13 files changed, 2593 insertions(+), 30 deletions(-) create mode 100644 docs/source/using_dataset_tools.mdx create mode 100644 examples/dataset/use_dataset_tools.py create mode 100644 src/lerobot/datasets/dataset_tools.py create mode 100644 src/lerobot/scripts/lerobot_edit_dataset.py create mode 100644 tests/datasets/test_dataset_tools.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 3b6cccc95..568bd6380 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -25,6 +25,8 @@ title: Using LeRobotDataset - local: porting_datasets_v3 title: Porting Large Datasets + - local: using_dataset_tools + title: Using the Dataset Tools title: "Datasets" - sections: - local: act diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx new file mode 100644 index 000000000..affca0ee5 --- /dev/null +++ b/docs/source/using_dataset_tools.mdx @@ -0,0 +1,102 @@ +# Using Dataset Tools + +This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets. + +## Overview + +LeRobot provides several utilities for manipulating datasets: + +1. **Delete Episodes** - Remove specific episodes from a dataset +2. **Split Dataset** - Divide a dataset into multiple smaller datasets +3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids` +4. **Add Features** - Add new features to a dataset +5. **Remove Features** - Remove features from a dataset + +The core implementation is in `lerobot.datasets.dataset_tools`. +An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. + +## Command-Line Tool: lerobot-edit-dataset + +`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features. + +Run `lerobot-edit-dataset --help` for more information on the configuration of each operation. + +### Usage Examples + +#### Delete Episodes + +Remove specific episodes from a dataset. This is useful for filtering out undesired data. 
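+
+The same operation is also available from the Python API in `lerobot.datasets.dataset_tools`. A minimal sketch (the resulting dataset is written to the local LeRobot cache under the given `repo_id`):
+
+```python
+from lerobot.datasets.dataset_tools import delete_episodes
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+dataset = LeRobotDataset("lerobot/pusht")
+
+# Remove episodes 0, 2 and 5 and save the result as a new dataset.
+filtered = delete_episodes(dataset, episode_indices=[0, 2, 5], repo_id="lerobot/pusht_filtered")
+print(f"{filtered.meta.total_episodes} episodes remain")
+```
+
+The equivalent command-line invocations are: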
+ +```bash +# Delete episodes 0, 2, and 5 (modifies original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +# Delete episodes and save to a new dataset (preserves original dataset) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" +``` + +#### Split Dataset + +Divide a dataset into multiple subsets. + +```bash +# Split by fractions (e.g. 80% train, 20% test, 20% val) +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}' + +# Split by specific episode indices +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}' +``` + +There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`. + +#### Merge Datasets + +Combine multiple datasets into a single dataset. + +```bash +# Merge train and validation splits back into one dataset +lerobot-edit-dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" +``` + +#### Remove Features + +Remove features from a dataset. + +```bash +# Remove a camera feature +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" +``` + +### Push to Hub + +Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: + +```bash +lerobot-edit-dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_after_deletion \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" \ + --push_to_hub +``` + +There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`. diff --git a/examples/dataset/use_dataset_tools.py b/examples/dataset/use_dataset_tools.py new file mode 100644 index 000000000..244259872 --- /dev/null +++ b/examples/dataset/use_dataset_tools.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script demonstrating dataset tools utilities. + +This script shows how to: +1. Delete episodes from a dataset +2. Split a dataset into train/val sets +3. Add/remove features +4. 
Merge datasets + +Usage: + python examples/dataset/use_dataset_tools.py +""" + +import numpy as np + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def main(): + dataset = LeRobotDataset("lerobot/pusht") + + print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames") + print(f"Features: {list(dataset.meta.features.keys())}") + + print("\n1. Deleting episodes 0 and 2...") + filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered") + print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes") + + print("\n2. Splitting dataset into train/val...") + splits = split_dataset( + dataset, + splits={"train": 0.8, "val": 0.2}, + ) + print(f"Train split: {splits['train'].meta.total_episodes} episodes") + print(f"Val split: {splits['val'].meta.total_episodes} episodes") + + print("\n3. Adding a reward feature...") + + reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32) + dataset_with_reward = add_feature( + dataset, + feature_name="reward", + feature_values=reward_values, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward", + ) + + def compute_success(row_dict, episode_index, frame_index): + episode_length = 10 + return float(frame_index >= episode_length - 10) + + dataset_with_success = add_feature( + dataset_with_reward, + feature_name="success", + feature_values=compute_success, + feature_info={ + "dtype": "float32", + "shape": (1,), + "names": None, + }, + repo_id="lerobot/pusht_with_reward_and_success", + ) + + print(f"New features: {list(dataset_with_success.meta.features.keys())}") + + print("\n4. Removing the success feature...") + dataset_cleaned = remove_feature( + dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned" + ) + print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}") + + print("\n5. Merging train and val splits back together...") + merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged") + print(f"Merged dataset: {merged.meta.total_episodes} episodes") + + print("\n6. Complex workflow example...") + + if len(dataset.meta.camera_keys) > 1: + camera_to_remove = dataset.meta.camera_keys[0] + print(f"Removing camera: {camera_to_remove}") + dataset_no_cam = remove_feature( + dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera" + ) + print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}") + + print("\nDone! 
Check ~/.cache/huggingface/lerobot/ for the created datasets.") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c67b481f0..a70208cb2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ dependencies = [ "cmake>=3.29.0.1,<4.2.0", "einops>=0.8.0,<0.9.0", "opencv-python-headless>=4.9.0,<4.13.0", - "av>=14.2.0,<16.0.0", + "av>=15.0.0,<16.0.0", "jsonlines>=4.0.0,<5.0.0", "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", @@ -175,6 +175,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" +lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 803645f29..e7ea59ed0 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -39,7 +39,7 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.datasets.video_utils import concatenate_video_files +from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]): @@ -130,10 +130,34 @@ def update_meta_data( df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): - df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"] - df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"] - df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] - df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + # Store original video file indices before updating + orig_chunk_col = f"videos/{key}/chunk_index" + orig_file_col = f"videos/{key}/file_index" + df["_orig_chunk"] = df[orig_chunk_col].copy() + df["_orig_file"] = df[orig_file_col].copy() + + # Update chunk and file indices to point to destination + df[orig_chunk_col] = video_idx["chunk"] + df[orig_file_col] = video_idx["file"] + + # Apply per-source-file timestamp offsets + src_to_offset = video_idx.get("src_to_offset", {}) + if src_to_offset: + # Apply offset based on original source file + for idx in df.index: + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) + offset = src_to_offset.get(src_key, 0) + df.at[idx, f"videos/{key}/from_timestamp"] += offset + df.at[idx, f"videos/{key}/to_timestamp"] += offset + else: + # Fallback to simple offset (for backward compatibility) + df[f"videos/{key}/from_timestamp"] = ( + df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"] + ) + df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"] + + # Clean up temporary columns + df = df.drop(columns=["_orig_chunk", "_orig_file"]) df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"] df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"] @@ -193,6 +217,9 @@ def aggregate_datasets( robot_type=robot_type, features=features, root=aggr_root, + chunks_size=chunk_size, + data_files_size_in_mb=data_files_size_in_mb, + 
video_files_size_in_mb=video_files_size_in_mb, ) logging.info("Find all tasks") @@ -236,6 +263,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu Returns: dict: Updated videos_idx with current chunk and file indices. """ + for key in videos_idx: + videos_idx[key]["episode_duration"] = 0 + # Track offset for each source (chunk, file) pair + videos_idx[key]["src_to_offset"] = {} + for key, video_idx in videos_idx.items(): unique_chunk_file_pairs = { (chunk, file) @@ -249,6 +281,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu chunk_idx = video_idx["chunk"] file_idx = video_idx["file"] + current_offset = video_idx["latest_duration"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( @@ -263,21 +296,24 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu file_index=file_idx, ) - # If a new file is created, we don't want to increment the latest_duration - update_latest_duration = False + src_duration = get_video_duration_in_s(src_path) if not dst_path.exists(): - # First write to this destination file + # Store offset before incrementing + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) - continue # not accumulating further, already copied the file in place + videos_idx[key]["episode_duration"] += src_duration + current_offset += src_duration + continue - # Check file sizes before appending src_size = get_video_size_in_mb(src_path) dst_size = get_video_size_in_mb(dst_path) if dst_size + src_size >= video_files_size_in_mb: - # Rotate to a new chunk/file + # Rotate to a new file, this source becomes start of new destination + # So its offset should be 0 + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size) dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, @@ -286,25 +322,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu ) dst_path.parent.mkdir(parents=True, exist_ok=True) shutil.copy(str(src_path), str(dst_path)) + # Reset offset for next file + current_offset = src_duration else: - # Get the timestamps shift for this video - timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"] - - # Append to existing video file + # Append to existing video file - use current accumulated offset + videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset concatenate_video_files( [dst_path, src_path], dst_path, ) - # Update the latest_duration when appending (shifts timestamps!) 
- update_latest_duration = not update_latest_duration + current_offset += src_duration + + videos_idx[key]["episode_duration"] += src_duration - # Update the videos_idx with the final chunk and file indices for this key videos_idx[key]["chunk"] = chunk_idx videos_idx[key]["file"] = file_idx - if update_latest_duration: - videos_idx[key]["latest_duration"] += timestamps_shift_s - return videos_idx @@ -389,9 +422,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - for k in videos_idx: - videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] - meta_idx = append_or_create_parquet_file( df, src_path, @@ -403,6 +433,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): aggr_root=dst_meta.root, ) + # Increment latest_duration by the total duration added from this source dataset + for k in videos_idx: + videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"] + return meta_idx diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py new file mode 100644 index 000000000..fdeb24a72 --- /dev/null +++ b/src/lerobot/datasets/dataset_tools.py @@ -0,0 +1,1004 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset tools utilities for LeRobotDataset. + +This module provides utilities for: +- Deleting episodes from datasets +- Splitting datasets into multiple smaller datasets +- Adding/removing features from datasets +- Merging datasets (wrapper around aggregate functionality) +""" + +import logging +import shutil +from collections.abc import Callable +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from lerobot.datasets.aggregate import aggregate_datasets +from lerobot.datasets.compute_stats import aggregate_stats +from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.datasets.utils import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_DATA_FILE_SIZE_IN_MB, + DEFAULT_DATA_PATH, + DEFAULT_EPISODES_PATH, + get_parquet_file_size_in_mb, + to_parquet_with_hf_images, + update_chunk_file_indices, + write_info, + write_stats, + write_tasks, +) +from lerobot.utils.constants import HF_LEROBOT_HOME + + +def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: + """Load a single episode's metadata including stats from parquet file. 
+ + Args: + src_dataset: Source dataset + episode_idx: Episode index to load + + Returns: + dict containing episode metadata and stats + """ + ep_meta = src_dataset.meta.episodes[episode_idx] + chunk_idx = ep_meta["meta/episodes/chunk_index"] + file_idx = ep_meta["meta/episodes/file_index"] + + parquet_path = src_dataset.root / DEFAULT_EPISODES_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + df = pd.read_parquet(parquet_path) + + episode_row = df[df["episode_index"] == episode_idx].iloc[0] + + return episode_row.to_dict() + + +def delete_episodes( + dataset: LeRobotDataset, + episode_indices: list[int], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Delete episodes from a LeRobotDataset and create a new dataset. + + Args: + dataset: The source LeRobotDataset. + episode_indices: List of episode indices to delete. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if not episode_indices: + raise ValueError("No episodes to delete") + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_indices) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + logging.info(f"Deleting {len(episode_indices)} episodes from dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + episodes_to_keep = [i for i in range(dataset.meta.total_episodes) if i not in episode_indices] + if not episodes_to_keep: + raise ValueError("Cannot delete all episodes from dataset") + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(episodes_to_keep)} + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + logging.info(f"Created new dataset with {len(episodes_to_keep)} episodes") + return new_dataset + + +def split_dataset( + dataset: LeRobotDataset, + splits: dict[str, float | list[int]], + output_dir: str | Path | None = None, +) -> dict[str, LeRobotDataset]: + """Split a LeRobotDataset into multiple smaller datasets. + + Args: + dataset: The source LeRobotDataset to split. + splits: Either a dict mapping split names to episode indices, or a dict mapping + split names to fractions (must sum to <= 1.0). + output_dir: Base directory for output datasets. If None, uses default location. 
+ + Examples: + Split by specific episodes + splits = {"train": [0, 1, 2], "val": [3, 4]} + datasets = split_dataset(dataset, splits) + + Split by fractions + splits = {"train": 0.8, "val": 0.2} + datasets = split_dataset(dataset, splits) + """ + if not splits: + raise ValueError("No splits provided") + + if all(isinstance(v, float) for v in splits.values()): + splits = _fractions_to_episode_indices(dataset.meta.total_episodes, splits) + + all_episodes = set() + for split_name, episodes in splits.items(): + if not episodes: + raise ValueError(f"Split '{split_name}' has no episodes") + episode_set = set(episodes) + if episode_set & all_episodes: + raise ValueError("Episodes cannot appear in multiple splits") + all_episodes.update(episode_set) + + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = all_episodes - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + if output_dir is not None: + output_dir = Path(output_dir) + + result_datasets = {} + + for split_name, episodes in splits.items(): + logging.info(f"Creating split '{split_name}' with {len(episodes)} episodes") + + split_repo_id = f"{dataset.repo_id}_{split_name}" + + split_output_dir = ( + output_dir / split_name if output_dir is not None else HF_LEROBOT_HOME / split_repo_id + ) + + episode_mapping = {old_idx: new_idx for new_idx, old_idx in enumerate(sorted(episodes))} + + new_meta = LeRobotDatasetMetadata.create( + repo_id=split_repo_id, + fps=dataset.meta.fps, + features=dataset.meta.features, + robot_type=dataset.meta.robot_type, + root=split_output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + video_metadata = None + if dataset.meta.video_keys: + video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping) + + data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping) + + _copy_and_reindex_episodes_metadata(dataset, new_meta, episode_mapping, data_metadata, video_metadata) + + new_dataset = LeRobotDataset( + repo_id=split_repo_id, + root=split_output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + result_datasets[split_name] = new_dataset + + return result_datasets + + +def merge_datasets( + datasets: list[LeRobotDataset], + output_repo_id: str, + output_dir: str | Path | None = None, +) -> LeRobotDataset: + """Merge multiple LeRobotDatasets into a single dataset. + + This is a wrapper around the aggregate_datasets functionality with a cleaner API. + + Args: + datasets: List of LeRobotDatasets to merge. + output_repo_id: Repository ID for the merged dataset. + output_dir: Directory to save the merged dataset. If None, uses default location. 
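+
+    Example:
+        Merge two previously split datasets back together (illustrative variable names):
+        merged = merge_datasets([train_ds, val_ds], output_repo_id="lerobot/pusht_merged")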
+ """ + if not datasets: + raise ValueError("No datasets to merge") + + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / output_repo_id + + repo_ids = [ds.repo_id for ds in datasets] + roots = [ds.root for ds in datasets] + + aggregate_datasets( + repo_ids=repo_ids, + aggr_repo_id=output_repo_id, + roots=roots, + aggr_root=output_dir, + ) + + merged_dataset = LeRobotDataset( + repo_id=output_repo_id, + root=output_dir, + image_transforms=datasets[0].image_transforms, + delta_timestamps=datasets[0].delta_timestamps, + tolerance_s=datasets[0].tolerance_s, + ) + + return merged_dataset + + +def add_feature( + dataset: LeRobotDataset, + feature_name: str, + feature_values: np.ndarray | torch.Tensor | Callable, + feature_info: dict, + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Add a new feature to a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_name: Name of the new feature. + feature_values: Either: + - Array/tensor of shape (num_frames, ...) with values for each frame + - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value + feature_info: Dictionary with feature metadata (dtype, shape, names). + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. + """ + if feature_name in dataset.meta.features: + raise ValueError(f"Feature '{feature_name}' already exists in dataset") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + required_keys = {"dtype", "shape"} + if not required_keys.issubset(feature_info.keys()): + raise ValueError(f"feature_info must contain keys: {required_keys}") + + new_features = dataset.meta.features.copy() + new_features[feature_name] = feature_info + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(dataset.meta.video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + add_features={feature_name: (feature_values, feature_info)}, + ) + + if dataset.meta.video_keys: + _copy_videos(dataset, new_meta) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def remove_feature( + dataset: LeRobotDataset, + feature_names: str | list[str], + output_dir: str | Path | None = None, + repo_id: str | None = None, +) -> LeRobotDataset: + """Remove features from a LeRobotDataset. + + Args: + dataset: The source LeRobotDataset. + feature_names: Name(s) of features to remove. Can be a single string or list. + output_dir: Directory to save the new dataset. If None, uses default location. + repo_id: Repository ID for the new dataset. If None, appends "_modified" to original. 
+ + """ + if isinstance(feature_names, str): + feature_names = [feature_names] + + for name in feature_names: + if name not in dataset.meta.features: + raise ValueError(f"Feature '{name}' not found in dataset") + + required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"} + if any(name in required_features for name in feature_names): + raise ValueError(f"Cannot remove required features: {required_features}") + + if repo_id is None: + repo_id = f"{dataset.repo_id}_modified" + output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id + + new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names} + + video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys] + + remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove] + + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=len(remaining_video_keys) > 0, + ) + + _copy_data_with_feature_changes( + dataset=dataset, + new_meta=new_meta, + remove_features=feature_names, + ) + + if new_meta.video_keys: + _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove) + + new_dataset = LeRobotDataset( + repo_id=repo_id, + root=output_dir, + image_transforms=dataset.image_transforms, + delta_timestamps=dataset.delta_timestamps, + tolerance_s=dataset.tolerance_s, + ) + + return new_dataset + + +def _fractions_to_episode_indices( + total_episodes: int, + splits: dict[str, float], +) -> dict[str, list[int]]: + """Convert split fractions to episode indices.""" + if sum(splits.values()) > 1.0: + raise ValueError("Split fractions must sum to <= 1.0") + + indices = list(range(total_episodes)) + result = {} + start_idx = 0 + + for split_name, fraction in splits.items(): + num_episodes = int(total_episodes * fraction) + if num_episodes == 0: + logging.warning(f"Split '{split_name}' has no episodes, skipping...") + continue + end_idx = start_idx + num_episodes + if split_name == list(splits.keys())[-1]: + end_idx = total_episodes + result[split_name] = indices[start_idx:end_idx] + start_idx = end_idx + + return result + + +def _copy_and_reindex_data( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], +) -> dict[int, dict]: + """Copy and filter data files, only modifying files with deleted episodes. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its data file metadata (chunk_index, file_index, etc.) 
+ """ + file_to_episodes: dict[Path, set[int]] = {} + for old_idx in episode_mapping: + file_path = src_dataset.meta.get_data_file_path(old_idx) + if file_path not in file_to_episodes: + file_to_episodes[file_path] = set() + file_to_episodes[file_path].add(old_idx) + + global_index = 0 + episode_data_metadata: dict[int, dict] = {} + + if dst_meta.tasks is None: + all_task_indices = set() + for src_path in file_to_episodes: + df = pd.read_parquet(src_dataset.root / src_path) + mask = df["episode_index"].isin(list(episode_mapping.keys())) + task_series: pd.Series = df[mask]["task_index"] + all_task_indices.update(task_series.unique().tolist()) + tasks = [src_dataset.meta.tasks.iloc[idx].name for idx in all_task_indices] + dst_meta.save_episode_tasks(list(set(tasks))) + + task_mapping = {} + for old_task_idx in range(len(src_dataset.meta.tasks)): + task_name = src_dataset.meta.tasks.iloc[old_task_idx].name + new_task_idx = dst_meta.get_task_index(task_name) + if new_task_idx is not None: + task_mapping[old_task_idx] = new_task_idx + + for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"): + df = pd.read_parquet(src_dataset.root / src_path) + + all_episodes_in_file = set(df["episode_index"].unique()) + episodes_to_keep = file_to_episodes[src_path] + + if all_episodes_in_file == episodes_to_keep: + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + else: + mask = df["episode_index"].isin(list(episode_mapping.keys())) + df = df[mask].copy().reset_index(drop=True) + + if len(df) == 0: + continue + + df["episode_index"] = df["episode_index"].replace(episode_mapping) + df["index"] = range(global_index, global_index + len(df)) + df["task_index"] = df["task_index"].replace(task_mapping) + + first_ep_old_idx = min(episodes_to_keep) + src_ep = src_dataset.meta.episodes[first_ep_old_idx] + chunk_idx = src_ep["data/chunk_index"] + file_idx = src_ep["data/file_index"] + + dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + dst_path.parent.mkdir(parents=True, exist_ok=True) + + if len(dst_meta.image_keys) > 0: + to_parquet_with_hf_images(df, dst_path) + else: + df.to_parquet(dst_path, index=False) + + for ep_old_idx in episodes_to_keep: + ep_new_idx = episode_mapping[ep_old_idx] + ep_df = df[df["episode_index"] == ep_new_idx] + episode_data_metadata[ep_new_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + global_index += len(df) + + return episode_data_metadata + + +def _keep_episodes_from_video_with_av( + input_path: Path, + output_path: Path, + episodes_to_keep: list[tuple[float, float]], + fps: float, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> None: + """Keep only specified episodes from a video file using PyAV. + + This function decodes frames from specified time ranges and re-encodes them with + properly reset timestamps to ensure monotonic progression. + + Args: + input_path: Source video file path. + output_path: Destination video file path. + episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + fps: Frame rate of the video. 
+ vcodec: Video codec to use for encoding. + pix_fmt: Pixel format for output video. + """ + from fractions import Fraction + + import av + + if not episodes_to_keep: + raise ValueError("No episodes to keep") + + in_container = av.open(str(input_path)) + + # Check if video stream exists. + if not in_container.streams.video: + raise ValueError( + f"No video streams found in {input_path}. " + "The video file may be corrupted or empty. " + "Try re-downloading the dataset or checking the video file." + ) + + v_in = in_container.streams.video[0] + + out = av.open(str(output_path), mode="w") + + # Convert fps to Fraction for PyAV compatibility. + fps_fraction = Fraction(fps).limit_denominator(1000) + v_out = out.add_stream(vcodec, rate=fps_fraction) + + # PyAV type stubs don't distinguish video streams from audio/subtitle streams. + v_out.width = v_in.codec_context.width + v_out.height = v_in.codec_context.height + v_out.pix_fmt = pix_fmt + + # Set time_base to match the frame rate for proper timestamp handling. + v_out.time_base = Fraction(1, int(fps)) + + out.start_encoding() + + # Create set of (start, end) ranges for fast lookup. + # Convert to a sorted list for efficient checking. + time_ranges = sorted(episodes_to_keep) + + # Track frame index for setting PTS and current range being processed. + frame_count = 0 + range_idx = 0 + + # Read through entire video once and filter frames. + for packet in in_container.demux(v_in): + for frame in packet.decode(): + if frame is None: + continue + + # Get frame timestamp. + frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 + + # Check if frame is in any of our desired time ranges. + # Skip ranges that have already passed. + while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + range_idx += 1 + + # If we've passed all ranges, stop processing. + if range_idx >= len(time_ranges): + break + + # Check if frame is in current range. + start_ts, end_ts = time_ranges[range_idx] + if frame_time < start_ts: + continue + + # Frame is in range - create a new frame with reset timestamps. + # We need to create a copy to avoid modifying the original. + new_frame = frame.reformat(width=v_out.width, height=v_out.height, format=v_out.pix_fmt) + new_frame.pts = frame_count + new_frame.time_base = Fraction(1, int(fps)) + + # Encode and mux the frame. + for pkt in v_out.encode(new_frame): + out.mux(pkt) + + frame_count += 1 + + # Flush encoder. + for pkt in v_out.encode(): + out.mux(pkt) + + out.close() + in_container.close() + + +def _copy_and_reindex_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", +) -> dict[int, dict]: + """Copy and filter video files, only re-encoding files with deleted episodes. + + For video files that only contain kept episodes, we copy them directly. + For files with mixed kept/deleted episodes, we use PyAV filters to efficiently + re-encode only the desired segments. 
+ + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + + Returns: + dict mapping episode index to its video metadata (chunk_index, file_index, timestamps) + """ + + episodes_video_metadata: dict[int, dict] = {new_idx: {} for new_idx in episode_mapping.values()} + + for video_key in src_dataset.meta.video_keys: + logging.info(f"Processing videos for {video_key}") + + if dst_meta.video_path is None: + raise ValueError("Destination metadata has no video_path defined") + + file_to_episodes: dict[tuple[int, int], list[int]] = {} + for old_idx in episode_mapping: + src_ep = src_dataset.meta.episodes[old_idx] + chunk_idx = src_ep[f"videos/{video_key}/chunk_index"] + file_idx = src_ep[f"videos/{video_key}/file_index"] + file_key = (chunk_idx, file_idx) + if file_key not in file_to_episodes: + file_to_episodes[file_key] = [] + file_to_episodes[file_key].append(old_idx) + + for (src_chunk_idx, src_file_idx), episodes_in_file in tqdm( + sorted(file_to_episodes.items()), desc=f"Processing {video_key} video files" + ): + all_episodes_in_file = [ + ep_idx + for ep_idx in range(src_dataset.meta.total_episodes) + if src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/chunk_index") == src_chunk_idx + and src_dataset.meta.episodes[ep_idx].get(f"videos/{video_key}/file_index") == src_file_idx + ] + + episodes_to_keep_set = set(episodes_in_file) + all_in_file_set = set(all_episodes_in_file) + + if all_in_file_set == episodes_to_keep_set: + assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_video_path, dst_video_path) + + for old_idx in episodes_in_file: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = src_ep[ + f"videos/{video_key}/from_timestamp" + ] + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = src_ep[ + f"videos/{video_key}/to_timestamp" + ] + else: + # Build list of time ranges to keep, in sorted order. + sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) + episodes_to_keep_ranges: list[tuple[float, float]] = [] + + for old_idx in sorted_keep_episodes: + src_ep = src_dataset.meta.episodes[old_idx] + from_ts = src_ep[f"videos/{video_key}/from_timestamp"] + to_ts = src_ep[f"videos/{video_key}/to_timestamp"] + episodes_to_keep_ranges.append((from_ts, to_ts)) + + # Use PyAV filters to efficiently re-encode only the desired segments. 
+ assert src_dataset.meta.video_path is not None + src_video_path = src_dataset.root / src_dataset.meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path = dst_meta.root / dst_meta.video_path.format( + video_key=video_key, chunk_index=src_chunk_idx, file_index=src_file_idx + ) + dst_video_path.parent.mkdir(parents=True, exist_ok=True) + + logging.info( + f"Re-encoding {video_key} (chunk {src_chunk_idx}, file {src_file_idx}) " + f"with {len(episodes_to_keep_ranges)} episodes" + ) + _keep_episodes_from_video_with_av( + src_video_path, + dst_video_path, + episodes_to_keep_ranges, + src_dataset.meta.fps, + vcodec, + pix_fmt, + ) + + cumulative_ts = 0.0 + for old_idx in sorted_keep_episodes: + new_idx = episode_mapping[old_idx] + src_ep = src_dataset.meta.episodes[old_idx] + ep_length = src_ep["length"] + ep_duration = ep_length / src_dataset.meta.fps + + episodes_video_metadata[new_idx][f"videos/{video_key}/chunk_index"] = src_chunk_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/file_index"] = src_file_idx + episodes_video_metadata[new_idx][f"videos/{video_key}/from_timestamp"] = cumulative_ts + episodes_video_metadata[new_idx][f"videos/{video_key}/to_timestamp"] = ( + cumulative_ts + ep_duration + ) + + cumulative_ts += ep_duration + + return episodes_video_metadata + + +def _copy_and_reindex_episodes_metadata( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_mapping: dict[int, int], + data_metadata: dict[int, dict], + video_metadata: dict[int, dict] | None = None, +) -> None: + """Copy and reindex episodes metadata using provided data and video metadata. + + Args: + src_dataset: Source dataset to copy from + dst_meta: Destination metadata object + episode_mapping: Mapping from old episode indices to new indices + data_metadata: Dict mapping new episode index to its data file metadata + video_metadata: Optional dict mapping new episode index to its video metadata + """ + from lerobot.datasets.utils import flatten_dict + + all_stats = [] + total_frames = 0 + + for old_idx, new_idx in tqdm( + sorted(episode_mapping.items(), key=lambda x: x[1]), desc="Processing episodes metadata" + ): + src_episode_full = _load_episode_with_stats(src_dataset, old_idx) + + src_episode = src_dataset.meta.episodes[old_idx] + + episode_meta = data_metadata[new_idx].copy() + + if video_metadata and new_idx in video_metadata: + episode_meta.update(video_metadata[new_idx]) + + # Extract episode statistics from parquet metadata. + # Note (maractingi): When pandas/pyarrow serializes numpy arrays with shape (3, 1, 1) to parquet, + # they are being deserialized as nested object arrays like: + # array([array([array([0.])]), array([array([0.])]), array([array([0.])])]) + # This happens particularly with image/video statistics. We need to detect and flatten + # these nested structures back to proper (3, 1, 1) arrays so aggregate_stats can process them. 
+ episode_stats = {} + for key in src_episode_full: + if key.startswith("stats/"): + stat_key = key.replace("stats/", "") + parts = stat_key.split("/") + if len(parts) == 2: + feature_name, stat_name = parts + if feature_name not in episode_stats: + episode_stats[feature_name] = {} + + value = src_episode_full[key] + + if feature_name in src_dataset.meta.features: + feature_dtype = src_dataset.meta.features[feature_name]["dtype"] + if feature_dtype in ["image", "video"] and stat_name != "count": + if isinstance(value, np.ndarray) and value.dtype == object: + flat_values = [] + for item in value: + while isinstance(item, np.ndarray): + item = item.flatten()[0] + flat_values.append(item) + value = np.array(flat_values, dtype=np.float64).reshape(3, 1, 1) + elif isinstance(value, np.ndarray) and value.shape == (3,): + value = value.reshape(3, 1, 1) + + episode_stats[feature_name][stat_name] = value + + all_stats.append(episode_stats) + + episode_dict = { + "episode_index": new_idx, + "tasks": src_episode["tasks"], + "length": src_episode["length"], + } + episode_dict.update(episode_meta) + episode_dict.update(flatten_dict({"stats": episode_stats})) + dst_meta._save_episode_metadata(episode_dict) + + total_frames += src_episode["length"] + + dst_meta.info.update( + { + "total_episodes": len(episode_mapping), + "total_frames": total_frames, + "total_tasks": len(dst_meta.tasks) if dst_meta.tasks is not None else 0, + "splits": {"train": f"0:{len(episode_mapping)}"}, + } + ) + write_info(dst_meta.info, dst_meta.root) + + if not all_stats: + logging.warning("No statistics found to aggregate") + return + + logging.info(f"Aggregating statistics for {len(all_stats)} episodes") + aggregated_stats = aggregate_stats(all_stats) + filtered_stats = {k: v for k, v in aggregated_stats.items() if k in dst_meta.features} + write_stats(filtered_stats, dst_meta.root) + + +def _save_data_chunk( + df: pd.DataFrame, + meta: LeRobotDatasetMetadata, + chunk_idx: int = 0, + file_idx: int = 0, +) -> tuple[int, int, dict[int, dict]]: + """Save a data chunk and return updated indices and episode metadata. 
+ + Returns: + tuple: (next_chunk_idx, next_file_idx, episode_metadata_dict) + where episode_metadata_dict maps episode_index to its data file metadata + """ + path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx) + path.parent.mkdir(parents=True, exist_ok=True) + + if len(meta.image_keys) > 0: + to_parquet_with_hf_images(df, path) + else: + df.to_parquet(path, index=False) + + episode_metadata = {} + for ep_idx in df["episode_index"].unique(): + ep_df = df[df["episode_index"] == ep_idx] + episode_metadata[ep_idx] = { + "data/chunk_index": chunk_idx, + "data/file_index": file_idx, + "dataset_from_index": int(ep_df["index"].min()), + "dataset_to_index": int(ep_df["index"].max() + 1), + } + + file_size = get_parquet_file_size_in_mb(path) + if file_size >= DEFAULT_DATA_FILE_SIZE_IN_MB * 0.9: + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, DEFAULT_CHUNK_SIZE) + + return chunk_idx, file_idx, episode_metadata + + +def _copy_data_with_feature_changes( + dataset: LeRobotDataset, + new_meta: LeRobotDatasetMetadata, + add_features: dict[str, tuple] | None = None, + remove_features: list[str] | None = None, +) -> None: + """Copy data while adding or removing features.""" + file_paths = set() + for ep_idx in range(dataset.meta.total_episodes): + file_paths.add(dataset.meta.get_data_file_path(ep_idx)) + + frame_idx = 0 + + for src_path in tqdm(sorted(file_paths), desc="Processing data files"): + df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True) + + if remove_features: + df = df.drop(columns=remove_features, errors="ignore") + + if add_features: + for feature_name, (values, _) in add_features.items(): + if callable(values): + feature_values = [] + for _, row in df.iterrows(): + ep_idx = row["episode_index"] + frame_in_ep = row["frame_index"] + value = values(row.to_dict(), ep_idx, frame_in_ep) + if isinstance(value, np.ndarray) and value.size == 1: + value = value.item() + feature_values.append(value) + df[feature_name] = feature_values + else: + end_idx = frame_idx + len(df) + feature_slice = values[frame_idx:end_idx] + if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1: + df[feature_name] = feature_slice.flatten() + else: + df[feature_name] = feature_slice + frame_idx = end_idx + + _save_data_chunk(df, new_meta) + + _copy_episodes_metadata_and_stats(dataset, new_meta) + + +def _copy_videos( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + exclude_keys: list[str] | None = None, +) -> None: + """Copy video files, optionally excluding certain keys.""" + if exclude_keys is None: + exclude_keys = [] + + for video_key in src_dataset.meta.video_keys: + if video_key in exclude_keys: + continue + + video_files = set() + for ep_idx in range(len(src_dataset.meta.episodes)): + try: + video_files.add(src_dataset.meta.get_video_file_path(ep_idx, video_key)) + except KeyError: + continue + + for src_path in tqdm(sorted(video_files), desc=f"Copying {video_key} videos"): + dst_path = dst_meta.root / src_path + dst_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy(src_dataset.root / src_path, dst_path) + + +def _copy_episodes_metadata_and_stats( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, +) -> None: + """Copy episodes metadata and recalculate stats.""" + if src_dataset.meta.tasks is not None: + write_tasks(src_dataset.meta.tasks, dst_meta.root) + dst_meta.tasks = src_dataset.meta.tasks.copy() + + episodes_dir = src_dataset.root / "meta/episodes" + dst_episodes_dir = 
dst_meta.root / "meta/episodes" + if episodes_dir.exists(): + shutil.copytree(episodes_dir, dst_episodes_dir, dirs_exist_ok=True) + + dst_meta.info.update( + { + "total_episodes": src_dataset.meta.total_episodes, + "total_frames": src_dataset.meta.total_frames, + "total_tasks": src_dataset.meta.total_tasks, + "splits": src_dataset.meta.info.get("splits", {"train": f"0:{src_dataset.meta.total_episodes}"}), + } + ) + + if dst_meta.video_keys and src_dataset.meta.video_keys: + for key in dst_meta.video_keys: + if key in src_dataset.meta.features: + dst_meta.info["features"][key]["info"] = src_dataset.meta.info["features"][key].get( + "info", {} + ) + + write_info(dst_meta.info, dst_meta.root) + + if set(dst_meta.features.keys()) != set(src_dataset.meta.features.keys()): + logging.info("Recalculating dataset statistics...") + if src_dataset.meta.stats: + new_stats = {} + for key in dst_meta.features: + if key in src_dataset.meta.stats: + new_stats[key] = src_dataset.meta.stats[key] + write_stats(new_stats, dst_meta.root) + else: + if src_dataset.meta.stats: + write_stats(src_dataset.meta.stats, dst_meta.root) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b661b21b0..229d37641 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -438,6 +438,9 @@ class LeRobotDatasetMetadata: robot_type: str | None = None, root: str | Path | None = None, use_videos: bool = True, + chunks_size: int | None = None, + data_files_size_in_mb: int | None = None, + video_files_size_in_mb: int | None = None, ) -> "LeRobotDatasetMetadata": """Creates metadata for a LeRobotDataset.""" obj = cls.__new__(cls) @@ -452,7 +455,16 @@ class LeRobotDatasetMetadata: obj.tasks = None obj.episodes = None obj.stats = None - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, + fps, + features, + use_videos, + robot_type, + chunks_size, + data_files_size_in_mb, + video_files_size_in_mb, + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index a2f285014..422a7010a 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -30,7 +30,7 @@ import pandas import pandas as pd import pyarrow.parquet as pq import torch -from datasets import Dataset, concatenate_datasets +from datasets import Dataset from datasets.table import embed_table_storage from huggingface_hub import DatasetCard, DatasetCardData, HfApi from huggingface_hub.errors import RevisionNotFoundError @@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import ( ForwardCompatibilityError, ) from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR -from lerobot.utils.utils import is_valid_numpy_dtype_string +from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file @@ -123,8 +123,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") # TODO(rcadene): set num_proc to accelerate conversion to pyarrow - datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] - return concatenate_datasets(datasets) + with SuppressProgressBars(): 
+ datasets = Dataset.from_parquet([str(path) for path in paths], features=features) + return datasets def get_parquet_num_frames(parquet_path: str | Path) -> int: diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 1d4f07c76..620ba863a 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -452,6 +452,9 @@ def concatenate_video_files( template=input_stream, opaque=True ) + # set the time base to the input stream time base (missing in the codec context) + stream_map[input_stream.index].time_base = input_stream.time_base + # Demux + remux packets (no re-encode) for packet in input_container.demux(): # Skip packets from un-mapped streams diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py new file mode 100644 index 000000000..83ba027bc --- /dev/null +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Edit LeRobot datasets using various transformation tools. + +This script allows you to delete episodes, split datasets, merge datasets, +and remove features. When new_repo_id is specified, creates a new dataset. 
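+When new_repo_id is omitted, the dataset is edited in place and the previous version is kept locally with an "_old" suffix.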
+ +Usage Examples: + +Delete episodes 0, 2, and 5 from a dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Delete episodes and save to a new dataset: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --new_repo_id lerobot/pusht_filtered \ + --operation.type delete_episodes \ + --operation.episode_indices "[0, 2, 5]" + +Split dataset by fractions: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.8, "val": 0.2}' + +Split dataset by episode indices: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' + +Split into more than two splits: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type split \ + --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' + +Merge multiple datasets: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_merged \ + --operation.type merge \ + --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" + +Remove camera feature: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type remove_feature \ + --operation.feature_names "['observation.images.top']" + +Using JSON config file: + python -m lerobot.scripts.lerobot_edit_dataset \ + --config_path path/to/edit_config.json +""" + +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path + +from lerobot.configs import parser +from lerobot.datasets.dataset_tools import ( + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.utils import init_logging + + +@dataclass +class DeleteEpisodesConfig: + type: str = "delete_episodes" + episode_indices: list[int] | None = None + + +@dataclass +class SplitConfig: + type: str = "split" + splits: dict[str, float | list[int]] | None = None + + +@dataclass +class MergeConfig: + type: str = "merge" + repo_ids: list[str] | None = None + + +@dataclass +class RemoveFeatureConfig: + type: str = "remove_feature" + feature_names: list[str] | None = None + + +@dataclass +class EditDatasetConfig: + repo_id: str + operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig + root: str | None = None + new_repo_id: str | None = None + push_to_hub: bool = False + + +def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]: + if new_repo_id: + output_repo_id = new_repo_id + output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id + else: + output_repo_id = repo_id + dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id + old_path = Path(str(dataset_path) + "_old") + + if dataset_path.exists(): + if old_path.exists(): + shutil.rmtree(old_path) + shutil.move(str(dataset_path), str(old_path)) + + output_dir = dataset_path + + return output_repo_id, output_dir + + +def handle_delete_episodes(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, DeleteEpisodesConfig): + raise ValueError("Operation config must be DeleteEpisodesConfig") + + if not cfg.operation.episode_indices: + raise ValueError("episode_indices must be specified for 
delete_episodes operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}") + new_dataset = delete_episodes( + dataset, + episode_indices=cfg.operation.episode_indices, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +def handle_split(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, SplitConfig): + raise ValueError("Operation config must be SplitConfig") + + if not cfg.operation.splits: + raise ValueError( + "splits dict must be specified with split names as keys and fractions/episode lists as values" + ) + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + + logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}") + split_datasets = split_dataset(dataset, splits=cfg.operation.splits) + + for split_name, split_ds in split_datasets.items(): + split_repo_id = f"{cfg.repo_id}_{split_name}" + logging.info( + f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing {split_name} split to hub as {split_repo_id}") + LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub() + + +def handle_merge(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, MergeConfig): + raise ValueError("Operation config must be MergeConfig") + + if not cfg.operation.repo_ids: + raise ValueError("repo_ids must be specified for merge operation") + + if not cfg.repo_id: + raise ValueError("repo_id must be specified as the output repository for merged dataset") + + logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge") + datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids] + + output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id + + logging.info(f"Merging datasets into {cfg.repo_id}") + merged_dataset = merge_datasets( + datasets, + output_repo_id=cfg.repo_id, + output_dir=output_dir, + ) + + logging.info(f"Merged dataset saved to {output_dir}") + logging.info( + f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}" + ) + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub() + + +def handle_remove_feature(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, RemoveFeatureConfig): + raise ValueError("Operation config must be RemoveFeatureConfig") + + if not cfg.operation.feature_names: + raise ValueError("feature_names must be specified for remove_feature operation") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + output_repo_id, output_dir = get_output_path( + cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None + ) + + if cfg.new_repo_id is None: + dataset.root = Path(str(dataset.root) + "_old") + + logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}") + new_dataset = 
remove_feature( + dataset, + feature_names=cfg.operation.feature_names, + output_dir=output_dir, + repo_id=output_repo_id, + ) + + logging.info(f"Dataset saved to {output_dir}") + logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {output_repo_id}") + LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() + + +@parser.wrap() +def edit_dataset(cfg: EditDatasetConfig) -> None: + operation_type = cfg.operation.type + + if operation_type == "delete_episodes": + handle_delete_episodes(cfg) + elif operation_type == "split": + handle_split(cfg) + elif operation_type == "merge": + handle_merge(cfg) + elif operation_type == "remove_feature": + handle_remove_feature(cfg) + else: + raise ValueError( + f"Unknown operation type: {operation_type}\n" + f"Available operations: delete_episodes, split, merge, remove_feature" + ) + + +def main() -> None: + init_logging() + edit_dataset() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 8777d5a9d..dfcd4a6b1 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -27,6 +27,7 @@ from statistics import mean import numpy as np import torch +from datasets.utils.logging import disable_progress_bar, enable_progress_bar def inside_slurm(): @@ -247,6 +248,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float): return days, hours, minutes, seconds +class SuppressProgressBars: + """ + Context manager to suppress progress bars. + + Example + -------- + ```python + with SuppressProgressBars(): + # Code that would normally show progress bars + ``` + """ + + def __enter__(self): + disable_progress_bar() + + def __exit__(self, exc_type, exc_val, exc_tb): + enable_progress_bar() + + class TimerManager: """ Lightweight utility to measure elapsed time. diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 4f316f80e..b710a3a4b 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds): pass +def assert_video_timestamps_within_bounds(aggr_ds): + """Test that all video timestamps are within valid bounds for their respective video files. + + This catches bugs where timestamps point to frames beyond the actual video length, + which would cause "Invalid frame index" errors during data loading. 
+ """ + try: + from torchcodec.decoders import VideoDecoder + except ImportError: + return + + for ep_idx in range(aggr_ds.num_episodes): + ep = aggr_ds.meta.episodes[ep_idx] + + for vid_key in aggr_ds.meta.video_keys: + from_ts = ep[f"videos/{vid_key}/from_timestamp"] + to_ts = ep[f"videos/{vid_key}/to_timestamp"] + video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key) + + if not video_path.exists(): + continue + + from_frame_idx = round(from_ts * aggr_ds.fps) + to_frame_idx = round(to_ts * aggr_ds.fps) + + try: + decoder = VideoDecoder(str(video_path)) + num_frames = len(decoder) + + # Verify timestamps don't exceed video bounds + assert from_frame_idx >= 0, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0" + ) + assert from_frame_idx < num_frames, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})" + ) + assert to_frame_idx <= num_frames, ( + f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})" + ) + assert from_frame_idx < to_frame_idx, ( + f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})" + ) + except Exception as e: + raise AssertionError( + f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}" + ) from e + + def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): """Test basic aggregation functionality with standard parameters.""" ds_0_num_frames = 400 @@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) @@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): assert_metadata_consistency(aggr_ds, ds_0, ds_1) assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) assert_video_frames_integrity(aggr_ds, ds_0, ds_1) + assert_video_timestamps_within_bounds(aggr_ds) assert_dataset_iteration_works(aggr_ds) # Check that multiple files were actually created due to small size limits @@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory): if video_dir.exists(): video_files = list(video_dir.rglob("*.mp4")) assert len(video_files) > 1, "Small file size limits should create multiple video files" + + +def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): + """Regression test for video timestamp bug when merging datasets. + + This test specifically checks that video timestamps are correctly calculated + and accumulated when merging multiple datasets. 
+ """ + datasets = [] + for i in range(3): + ds = lerobot_dataset_factory( + root=tmp_path / f"regression_{i}", + repo_id=f"{DUMMY_REPO_ID}_regression_{i}", + total_episodes=2, + total_frames=100, + ) + datasets.append(ds) + + aggregate_datasets( + repo_ids=[ds.repo_id for ds in datasets], + roots=[ds.root for ds in datasets], + aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr", + aggr_root=tmp_path / "regression_aggr", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "regression_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr") + + assert_video_timestamps_within_bounds(aggr_ds) + + for i in range(len(aggr_ds)): + item = aggr_ds[i] + for key in aggr_ds.meta.video_keys: + assert key in item, f"Video key {key} missing from item {i}" + assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py new file mode 100644 index 000000000..fe117b35b --- /dev/null +++ b/tests/datasets/test_dataset_tools.py @@ -0,0 +1,891 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for dataset tools utilities.""" + +from unittest.mock import patch + +import numpy as np +import pytest +import torch + +from lerobot.datasets.dataset_tools import ( + add_feature, + delete_episodes, + merge_datasets, + remove_feature, + split_dataset, +) + + +@pytest.fixture +def sample_dataset(tmp_path, empty_lerobot_dataset_factory): + """Create a sample dataset for testing.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset", + features=features, + ) + + for ep_idx in range(5): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset.add_frame(frame) + dataset.save_episode() + + return dataset + + +def test_delete_single_episode(sample_dataset, tmp_path): + """Test deleting a single episode.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 4 + assert new_dataset.meta.total_frames == 40 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2, 3} + + assert len(new_dataset) == 40 + + +def test_delete_multiple_episodes(sample_dataset, tmp_path): + """Test deleting multiple episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]} + assert episode_indices == {0, 1, 2} + + +def test_delete_invalid_episodes(sample_dataset, tmp_path): + """Test error handling for invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + delete_episodes( + sample_dataset, + episode_indices=[10, 20], + output_dir=tmp_path / "filtered", + ) + + +def test_delete_all_episodes(sample_dataset, tmp_path): + """Test error when trying to delete all episodes.""" + with pytest.raises(ValueError, match="Cannot delete all episodes"): + delete_episodes( + sample_dataset, + episode_indices=list(range(5)), + output_dir=tmp_path / "filtered", + ) + + +def test_delete_empty_list(sample_dataset, tmp_path): + """Test error when no episodes specified.""" + with pytest.raises(ValueError, match="No episodes to delete"): + delete_episodes( + sample_dataset, + episode_indices=[], + output_dir=tmp_path / "filtered", + ) + + +def 
test_split_by_episodes(sample_dataset, tmp_path): + """Test splitting dataset by specific episode indices.""" + splits = { + "train": [0, 1, 2], + "val": [3, 4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + if "train" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_train") + elif "val" in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_val") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val"} + + assert result["train"].meta.total_episodes == 3 + assert result["train"].meta.total_frames == 30 + + assert result["val"].meta.total_episodes == 2 + assert result["val"].meta.total_frames == 20 + + train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]} + assert train_episodes == {0, 1, 2} + + val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]} + assert val_episodes == {0, 1} + + +def test_split_by_fractions(sample_dataset, tmp_path): + """Test splitting dataset by fractions.""" + splits = { + "train": 0.6, + "val": 0.4, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 2 + + +def test_split_overlapping_episodes(sample_dataset, tmp_path): + """Test error when episodes appear in multiple splits.""" + splits = { + "train": [0, 1, 2], + "val": [2, 3, 4], + } + + with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_invalid_fractions(sample_dataset, tmp_path): + """Test error when fractions sum to more than 1.""" + splits = { + "train": 0.7, + "val": 0.5, + } + + with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"): + split_dataset(sample_dataset, splits=splits, output_dir=tmp_path) + + +def test_split_empty(sample_dataset, tmp_path): + """Test error with empty splits.""" + with pytest.raises(ValueError, match="No splits provided"): + split_dataset(sample_dataset, splits={}, output_dir=tmp_path) + + +def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging two datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + 
frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 8 # 5 + 3 + assert merged.meta.total_frames == 80 # 50 + 30 + + episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]}) + assert episode_indices == list(range(8)) + + +def test_merge_empty_list(tmp_path): + """Test error when merging empty list.""" + with pytest.raises(ValueError, match="No datasets to merge"): + merge_datasets([], output_repo_id="merged", output_dir=tmp_path) + + +def test_add_feature_with_values(sample_dataset, tmp_path): + """Test adding a feature with pre-computed values.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + assert new_dataset.meta.features["reward"] == feature_info + + assert len(new_dataset) == num_frames + sample_item = new_dataset[0] + assert "reward" in sample_item + assert isinstance(sample_item["reward"], torch.Tensor) + + +def test_add_feature_with_callable(sample_dataset, tmp_path): + """Test adding a feature with a callable.""" + + def compute_reward(frame_dict, episode_idx, frame_idx): + return float(episode_idx * 10 + frame_idx) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=compute_reward, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert "reward" in new_dataset.meta.features + + items = [new_dataset[i] for i in range(10)] + first_episode_items = [item for item in items if item["episode_index"] == 0] + assert len(first_episode_items) == 10 + + first_frame = first_episode_items[0] + assert first_frame["frame_index"] == 0 + assert float(first_frame["reward"]) == 0.0 + + +def test_add_existing_feature(sample_dataset, tmp_path): + """Test error when adding an existing feature.""" + feature_info = 
{"dtype": "float32", "shape": (1,)} + + with pytest.raises(ValueError, match="Feature 'action' already exists"): + add_feature( + sample_dataset, + feature_name="action", + feature_values=np.zeros(50), + feature_info=feature_info, + output_dir=tmp_path / "modified", + ) + + +def test_add_feature_invalid_info(sample_dataset, tmp_path): + """Test error with invalid feature info.""" + with pytest.raises(ValueError, match="feature_info must contain keys"): + add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.zeros(50), + feature_info={"dtype": "float32"}, + output_dir=tmp_path / "modified", + ) + + +def test_remove_single_feature(sample_dataset, tmp_path): + """Test removing a single feature.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + assert "reward" not in dataset_without_reward.meta.features + + sample_item = dataset_without_reward[0] + assert "reward" not in sample_item + + +def test_remove_multiple_features(sample_dataset, tmp_path): + """Test removing multiple features at once.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = sample_dataset + for feature_name in ["reward", "success"]: + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + dataset = add_feature( + dataset, + feature_name=feature_name, + feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / f"with_{feature_name}", + ) + + dataset_clean = remove_feature( + dataset, + feature_names=["reward", "success"], + output_dir=tmp_path / "clean", + ) + + assert "reward" not in dataset_clean.meta.features + assert "success" not in dataset_clean.meta.features + + +def test_remove_nonexistent_feature(sample_dataset, tmp_path): + """Test error when removing non-existent feature.""" + with pytest.raises(ValueError, match="Feature 'nonexistent' not found"): + remove_feature( + sample_dataset, + feature_names="nonexistent", + output_dir=tmp_path / "modified", + ) + + +def test_remove_required_feature(sample_dataset, tmp_path): + """Test error when trying to remove required features.""" + with pytest.raises(ValueError, match="Cannot remove required features"): + remove_feature( + sample_dataset, + feature_names="timestamp", + output_dir=tmp_path / "modified", + ) + + +def test_remove_camera_feature(sample_dataset, tmp_path): + """Test removing a camera feature.""" + camera_keys = sample_dataset.meta.camera_keys + if not camera_keys: + pytest.skip("No camera keys in dataset") + + camera_to_remove = 
camera_keys[0] + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "without_camera") + + dataset_without_camera = remove_feature( + sample_dataset, + feature_names=camera_to_remove, + output_dir=tmp_path / "without_camera", + ) + + assert camera_to_remove not in dataset_without_camera.meta.features + assert camera_to_remove not in dataset_without_camera.meta.camera_keys + + sample_item = dataset_without_camera[0] + assert camera_to_remove not in sample_item + + +def test_complex_workflow_integration(sample_dataset, tmp_path): + """Test a complex workflow combining multiple operations.""" + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info={"dtype": "float32", "shape": (1,), "names": None}, + output_dir=tmp_path / "step1", + ) + + dataset = delete_episodes( + dataset, + episode_indices=[2], + output_dir=tmp_path / "step2", + ) + + splits = split_dataset( + dataset, + splits={"train": 0.75, "val": 0.25}, + output_dir=tmp_path / "step3", + ) + + merged = merge_datasets( + list(splits.values()), + output_repo_id="final_dataset", + output_dir=tmp_path / "step4", + ) + + assert merged.meta.total_episodes == 4 + assert merged.meta.total_frames == 40 + assert "reward" in merged.meta.features + + assert len(merged) == 40 + sample_item = merged[0] + assert "reward" in sample_item + + +def test_delete_episodes_preserves_stats(sample_dataset, tmp_path): + """Test that deleting episodes preserves statistics correctly.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[2], + output_dir=output_dir, + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path): + """Test that tasks are preserved correctly after deletion.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0], + output_dir=output_dir, + ) + + assert new_dataset.meta.tasks is not None + assert len(new_dataset.meta.tasks) == 2 + + tasks_in_dataset = {str(item["task"]) for item in new_dataset} + assert len(tasks_in_dataset) > 0 + + +def 
test_split_three_ways(sample_dataset, tmp_path): + """Test splitting dataset into three splits.""" + splits = { + "train": 0.6, + "val": 0.2, + "test": 0.2, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + assert set(result.keys()) == {"train", "val", "test"} + assert result["train"].meta.total_episodes == 3 + assert result["val"].meta.total_episodes == 1 + assert result["test"].meta.total_episodes == 1 + + total_frames = sum(ds.meta.total_frames for ds in result.values()) + assert total_frames == sample_dataset.meta.total_frames + + +def test_split_preserves_stats(sample_dataset, tmp_path): + """Test that statistics are preserved when splitting.""" + splits = {"train": [0, 1, 2], "val": [3, 4]} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + for split_ds in result.values(): + assert split_ds.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in split_ds.meta.stats + assert "mean" in split_ds.meta.stats[feature] + assert "std" in split_ds.meta.stats[feature] + + +def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test merging three datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + datasets = [sample_dataset] + + for i in range(2): + dataset = empty_lerobot_dataset_factory( + root=tmp_path / f"test_dataset{i + 2}", + features=features, + ) + + for ep_idx in range(2): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx}", + } + dataset.add_frame(frame) + dataset.save_episode() + + datasets.append(dataset) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + datasets, + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.total_episodes == 9 + assert merged.meta.total_frames == 90 + 
+ +def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory): + """Test that statistics are computed for merged datasets.""" + features = { + "action": {"dtype": "float32", "shape": (6,), "names": None}, + "observation.state": {"dtype": "float32", "shape": (4,), "names": None}, + "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None}, + } + + dataset2 = empty_lerobot_dataset_factory( + root=tmp_path / "test_dataset2", + features=features, + ) + + for ep_idx in range(3): + for _ in range(10): + frame = { + "action": np.random.randn(6).astype(np.float32), + "observation.state": np.random.randn(4).astype(np.float32), + "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8), + "task": f"task_{ep_idx % 2}", + } + dataset2.add_frame(frame) + dataset2.save_episode() + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "merged_dataset") + + merged = merge_datasets( + [sample_dataset, dataset2], + output_repo_id="merged_dataset", + output_dir=tmp_path / "merged_dataset", + ) + + assert merged.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in merged.meta.stats + assert "mean" in merged.meta.stats[feature] + assert "std" in merged.meta.stats[feature] + + +def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path): + """Test that adding a feature preserves existing stats.""" + num_frames = sample_dataset.meta.total_frames + reward_values = np.random.randn(num_frames, 1).astype(np.float32) + + feature_info = { + "dtype": "float32", + "shape": (1,), + "names": None, + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "with_reward") + + new_dataset = add_feature( + sample_dataset, + feature_name="reward", + feature_values=reward_values, + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + assert new_dataset.meta.stats is not None + for feature in ["action", "observation.state"]: + assert feature in new_dataset.meta.stats + assert "mean" in new_dataset.meta.stats[feature] + assert "std" in new_dataset.meta.stats[feature] + + +def test_remove_feature_updates_stats(sample_dataset, tmp_path): + """Test that removing a feature removes it from stats.""" + feature_info = {"dtype": "float32", "shape": (1,), "names": None} + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path)) + + dataset_with_reward = add_feature( + sample_dataset, + feature_name="reward", + feature_values=np.random.randn(50, 1).astype(np.float32), + feature_info=feature_info, + output_dir=tmp_path / "with_reward", + ) + + dataset_without_reward = remove_feature( + dataset_with_reward, + feature_names="reward", + output_dir=tmp_path / "without_reward", + ) + + if dataset_without_reward.meta.stats: + assert "reward" not 
in dataset_without_reward.meta.stats + + +def test_delete_consecutive_episodes(sample_dataset, tmp_path): + """Test deleting consecutive episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[1, 2, 3], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 2 + assert new_dataset.meta.total_frames == 20 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1] + + +def test_delete_first_and_last_episodes(sample_dataset, tmp_path): + """Test deleting first and last episodes.""" + output_dir = tmp_path / "filtered" + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(output_dir) + + new_dataset = delete_episodes( + sample_dataset, + episode_indices=[0, 4], + output_dir=output_dir, + ) + + assert new_dataset.meta.total_episodes == 3 + assert new_dataset.meta.total_frames == 30 + + episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}) + assert episode_indices == [0, 1, 2] + + +def test_split_all_episodes_assigned(sample_dataset, tmp_path): + """Test that all episodes can be explicitly assigned to splits.""" + splits = { + "split1": [0, 1], + "split2": [2, 3], + "split3": [4], + } + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + + def mock_snapshot(repo_id, **kwargs): + for split_name in splits: + if split_name in repo_id: + return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}") + return str(kwargs.get("local_dir", tmp_path)) + + mock_snapshot_download.side_effect = mock_snapshot + + result = split_dataset( + sample_dataset, + splits=splits, + output_dir=tmp_path, + ) + + total_episodes = sum(ds.meta.total_episodes for ds in result.values()) + assert total_episodes == sample_dataset.meta.total_episodes
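Beyond the `lerobot-edit-dataset` CLI shown in the module docstring above, the same operations can be driven directly from Python. The snippet below is a minimal, illustrative sketch based on the function signatures exercised by these tests; the repo ids, local paths, episode indices, split fractions, and the `compute_reward` helper are placeholders, not part of the patch.

```python
from pathlib import Path

from lerobot.datasets.dataset_tools import (
    add_feature,
    delete_episodes,
    merge_datasets,
    split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load a dataset that is already available locally (repo id and root are placeholders).
dataset = LeRobotDataset("my_user/my_dataset", root=Path("data/my_dataset"))

# Drop two unwanted episodes and write the result to a new dataset directory.
filtered = delete_episodes(
    dataset,
    episode_indices=[0, 2],
    output_dir=Path("data/my_dataset_filtered"),
    repo_id="my_user/my_dataset_filtered",
)


# Add a per-frame "reward" feature computed by a callable; the callable receives the
# frame dict plus the episode and frame indices (placeholder logic below).
def compute_reward(frame_dict, episode_idx, frame_idx):
    return float(frame_idx)


with_reward = add_feature(
    filtered,
    feature_name="reward",
    feature_values=compute_reward,
    feature_info={"dtype": "float32", "shape": (1,), "names": None},
    output_dir=Path("data/my_dataset_reward"),
)

# Split 80/20 by fraction; one sub-dataset is returned per split name.
splits = split_dataset(with_reward, splits={"train": 0.8, "val": 0.2})

# Merge the splits back into a single dataset under a new repo id.
merged = merge_datasets(
    list(splits.values()),
    output_repo_id="my_user/my_dataset_roundtrip",
    output_dir=Path("data/my_dataset_roundtrip"),
)
print(merged.meta.total_episodes, merged.meta.total_frames)
```

Each call returns a regular LeRobotDataset, so the operations compose end to end, which is what test_complex_workflow_integration checks above.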