Dataset tools (#2100)

* feat(dataset-tools): add dataset utilities and example script

- Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets.
- Added an example script demonstrating the usage of these utilities.
- Implemented comprehensive tests for all new functionalities to ensure reliability and correctness.

* style fixes

* move example to dataset dir

* missing lisence

* fixes mostly path

* clean comments

* move tests to functions instead of class based

* - fix video editting, decode, delete frames and rencode video
- copy unchanged video and parquet files to avoid recreating the entire dataset

* Fortify tooling tests

* Fix type issue resulting from saving numpy arrays with shape 3,1,1

* added lerobot_edit_dataset

* - revert changes in examples
- remove hardcoded split names

* update comment

* fix comment
add lerobot-edit-dataset shortcut

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* style nit after copilot review

* fix: bug in dataset root when editing the dataset in place (without setting new_repo_id

* Fix bug in aggregate.py when accumelating video timestamps; add tests to fortify aggregate videos

* Added missing output repo id

* migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again.
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>

* added modified suffix in case repo_id is not set in delete_episode

* adding docs for dataset tools

* bump av version and add back time_base assignment

* linter

* modified push_to_hub logic in lerobot_edit_dataset

* fix(progress bar): fixing the progress bar issue in dataset tools

* chore(concatenate): removing no longer needed concatenate_datasets usage

* fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets

* style fix

* refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging

There were three critical bugs in aggregate.py that prevented correct dataset merging:

1. Video file indices: Changed from += to = assignment to correctly reference
   merged video files

2. Video timestamps: Implemented per-source-file offset tracking to maintain
   continuous timestamps when merging split datasets (was causing non-monotonic
   timestamp warnings)

3. File rotation offsets: Store timestamp offsets after rotation decision to
   prevent out-of-bounds frame access (was causing "Invalid frame index" errors
   with small file size limits)

Changes:
- Updated update_meta_data() to apply per-source-file timestamp offsets
- Updated aggregate_videos() to track offsets correctly during file rotation
- Added get_video_duration_in_s import for duration calculation

* Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes

* chore(docs): update merge documentation details

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
Co-authored-by: Jack Vial <vialjack@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Michel Aractingi
2025-10-10 12:32:07 +02:00
committed by GitHub
parent 656fc0f059
commit b8f7e401d4
13 changed files with 2593 additions and 30 deletions

View File

@@ -25,6 +25,8 @@
title: Using LeRobotDataset
- local: porting_datasets_v3
title: Porting Large Datasets
- local: using_dataset_tools
title: Using the Dataset Tools
title: "Datasets"
- sections:
- local: act

View File

@@ -0,0 +1,102 @@
# Using Dataset Tools
This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.
## Overview
LeRobot provides several utilities for manipulating datasets:
1. **Delete Episodes** - Remove specific episodes from a dataset
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset
The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
### Usage Examples
#### Delete Episodes
Remove specific episodes from a dataset. This is useful for filtering out undesired data.
```bash
# Delete episodes 0, 2, and 5 (modifies original dataset)
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
# Delete episodes and save to a new dataset (preserves original dataset)
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
```
#### Split Dataset
Divide a dataset into multiple subsets.
```bash
# Split by fractions (e.g. 80% train, 20% test, 20% val)
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.8, "test": 0.2, "val": 0.2}'
# Split by specific episode indices
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
```
There are no constraints on the split names, they can be determined by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
#### Merge Datasets
Combine multiple datasets into a single dataset.
```bash
# Merge train and validation splits back into one dataset
lerobot-edit-dataset \
--repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
```
#### Remove Features
Remove features from a dataset.
```bash
# Remove a camera feature
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
```
### Push to Hub
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
```bash
lerobot-edit-dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_after_deletion \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]" \
--push_to_hub
```
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.

View File

@@ -0,0 +1,117 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example script demonstrating dataset tools utilities.
This script shows how to:
1. Delete episodes from a dataset
2. Split a dataset into train/val sets
3. Add/remove features
4. Merge datasets
Usage:
python examples/dataset/use_dataset_tools.py
"""
import numpy as np
from lerobot.datasets.dataset_tools import (
add_feature,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
dataset = LeRobotDataset("lerobot/pusht")
print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
print(f"Features: {list(dataset.meta.features.keys())}")
print("\n1. Deleting episodes 0 and 2...")
filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")
print("\n2. Splitting dataset into train/val...")
splits = split_dataset(
dataset,
splits={"train": 0.8, "val": 0.2},
)
print(f"Train split: {splits['train'].meta.total_episodes} episodes")
print(f"Val split: {splits['val'].meta.total_episodes} episodes")
print("\n3. Adding a reward feature...")
reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
dataset_with_reward = add_feature(
dataset,
feature_name="reward",
feature_values=reward_values,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="lerobot/pusht_with_reward",
)
def compute_success(row_dict, episode_index, frame_index):
episode_length = 10
return float(frame_index >= episode_length - 10)
dataset_with_success = add_feature(
dataset_with_reward,
feature_name="success",
feature_values=compute_success,
feature_info={
"dtype": "float32",
"shape": (1,),
"names": None,
},
repo_id="lerobot/pusht_with_reward_and_success",
)
print(f"New features: {list(dataset_with_success.meta.features.keys())}")
print("\n4. Removing the success feature...")
dataset_cleaned = remove_feature(
dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
)
print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")
print("\n5. Merging train and val splits back together...")
merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
print(f"Merged dataset: {merged.meta.total_episodes} episodes")
print("\n6. Complex workflow example...")
if len(dataset.meta.camera_keys) > 1:
camera_to_remove = dataset.meta.camera_keys[0]
print(f"Removing camera: {camera_to_remove}")
dataset_no_cam = remove_feature(
dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
)
print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")
print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")
if __name__ == "__main__":
main()

View File

@@ -67,7 +67,7 @@ dependencies = [
"cmake>=3.29.0.1,<4.2.0",
"einops>=0.8.0,<0.9.0",
"opencv-python-headless>=4.9.0,<4.13.0",
"av>=14.2.0,<16.0.0",
"av>=15.0.0,<16.0.0",
"jsonlines>=4.0.0,<5.0.0",
"packaging>=24.2,<26.0",
"pynput>=1.7.7,<1.9.0",
@@ -175,6 +175,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.packages.find]

View File

@@ -39,7 +39,7 @@ from lerobot.datasets.utils import (
write_stats,
write_tasks,
)
from lerobot.datasets.video_utils import concatenate_video_files
from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s
def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -130,10 +130,34 @@ def update_meta_data(
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
# Store original video file indices before updating
orig_chunk_col = f"videos/{key}/chunk_index"
orig_file_col = f"videos/{key}/file_index"
df["_orig_chunk"] = df[orig_chunk_col].copy()
df["_orig_file"] = df[orig_file_col].copy()
# Update chunk and file indices to point to destination
df[orig_chunk_col] = video_idx["chunk"]
df[orig_file_col] = video_idx["file"]
# Apply per-source-file timestamp offsets
src_to_offset = video_idx.get("src_to_offset", {})
if src_to_offset:
# Apply offset based on original source file
for idx in df.index:
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
offset = src_to_offset.get(src_key, 0)
df.at[idx, f"videos/{key}/from_timestamp"] += offset
df.at[idx, f"videos/{key}/to_timestamp"] += offset
else:
# Fallback to simple offset (for backward compatibility)
df[f"videos/{key}/from_timestamp"] = (
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
# Clean up temporary columns
df = df.drop(columns=["_orig_chunk", "_orig_file"])
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
@@ -193,6 +217,9 @@ def aggregate_datasets(
robot_type=robot_type,
features=features,
root=aggr_root,
chunks_size=chunk_size,
data_files_size_in_mb=data_files_size_in_mb,
video_files_size_in_mb=video_files_size_in_mb,
)
logging.info("Find all tasks")
@@ -236,6 +263,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
Returns:
dict: Updated videos_idx with current chunk and file indices.
"""
for key in videos_idx:
videos_idx[key]["episode_duration"] = 0
# Track offset for each source (chunk, file) pair
videos_idx[key]["src_to_offset"] = {}
for key, video_idx in videos_idx.items():
unique_chunk_file_pairs = {
(chunk, file)
@@ -249,6 +281,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
chunk_idx = video_idx["chunk"]
file_idx = video_idx["file"]
current_offset = video_idx["latest_duration"]
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
@@ -263,21 +296,24 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
file_index=file_idx,
)
# If a new file is created, we don't want to increment the latest_duration
update_latest_duration = False
src_duration = get_video_duration_in_s(src_path)
if not dst_path.exists():
# First write to this destination file
# Store offset before incrementing
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
continue # not accumulating further, already copied the file in place
videos_idx[key]["episode_duration"] += src_duration
current_offset += src_duration
continue
# Check file sizes before appending
src_size = get_video_size_in_mb(src_path)
dst_size = get_video_size_in_mb(dst_path)
if dst_size + src_size >= video_files_size_in_mb:
# Rotate to a new chunk/file
# Rotate to a new file, this source becomes start of new destination
# So its offset should be 0
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
video_key=key,
@@ -286,25 +322,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
)
dst_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(str(src_path), str(dst_path))
# Reset offset for next file
current_offset = src_duration
else:
# Get the timestamps shift for this video
timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
# Append to existing video file
# Append to existing video file - use current accumulated offset
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
concatenate_video_files(
[dst_path, src_path],
dst_path,
)
# Update the latest_duration when appending (shifts timestamps!)
update_latest_duration = not update_latest_duration
current_offset += src_duration
videos_idx[key]["episode_duration"] += src_duration
# Update the videos_idx with the final chunk and file indices for this key
videos_idx[key]["chunk"] = chunk_idx
videos_idx[key]["file"] = file_idx
if update_latest_duration:
videos_idx[key]["latest_duration"] += timestamps_shift_s
return videos_idx
@@ -389,9 +422,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
videos_idx,
)
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
meta_idx = append_or_create_parquet_file(
df,
src_path,
@@ -403,6 +433,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
aggr_root=dst_meta.root,
)
# Increment latest_duration by the total duration added from this source dataset
for k in videos_idx:
videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
return meta_idx

File diff suppressed because it is too large Load Diff

View File

@@ -438,6 +438,9 @@ class LeRobotDatasetMetadata:
robot_type: str | None = None,
root: str | Path | None = None,
use_videos: bool = True,
chunks_size: int | None = None,
data_files_size_in_mb: int | None = None,
video_files_size_in_mb: int | None = None,
) -> "LeRobotDatasetMetadata":
"""Creates metadata for a LeRobotDataset."""
obj = cls.__new__(cls)
@@ -452,7 +455,16 @@ class LeRobotDatasetMetadata:
obj.tasks = None
obj.episodes = None
obj.stats = None
obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
obj.info = create_empty_dataset_info(
CODEBASE_VERSION,
fps,
features,
use_videos,
robot_type,
chunks_size,
data_files_size_in_mb,
video_files_size_in_mb,
)
if len(obj.video_keys) > 0 and not use_videos:
raise ValueError()
write_json(obj.info, obj.root / INFO_PATH)

View File

@@ -30,7 +30,7 @@ import pandas
import pandas as pd
import pyarrow.parquet as pq
import torch
from datasets import Dataset, concatenate_datasets
from datasets import Dataset
from datasets.table import embed_table_storage
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
from huggingface_hub.errors import RevisionNotFoundError
@@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import (
ForwardCompatibilityError,
)
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
from lerobot.utils.utils import is_valid_numpy_dtype_string
from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
@@ -123,8 +123,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
return concatenate_datasets(datasets)
with SuppressProgressBars():
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
return datasets
def get_parquet_num_frames(parquet_path: str | Path) -> int:

View File

@@ -452,6 +452,9 @@ def concatenate_video_files(
template=input_stream, opaque=True
)
# set the time base to the input stream time base (missing in the codec context)
stream_map[input_stream.index].time_base = input_stream.time_base
# Demux + remux packets (no re-encode)
for packet in input_container.demux():
# Skip packets from un-mapped streams

View File

@@ -0,0 +1,286 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets,
and remove features. When new_repo_id is specified, creates a new dataset.
Usage Examples:
Delete episodes 0, 2, and 5 from a dataset:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Delete episodes and save to a new dataset:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_filtered \
--operation.type delete_episodes \
--operation.episode_indices "[0, 2, 5]"
Split dataset by fractions:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.8, "val": 0.2}'
Split dataset by episode indices:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
Split into more than two splits:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type split \
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
Merge multiple datasets:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht_merged \
--operation.type merge \
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
Remove camera feature:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type remove_feature \
--operation.feature_names "['observation.images.top']"
Using JSON config file:
python -m lerobot.scripts.lerobot_edit_dataset \
--config_path path/to/edit_config.json
"""
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path
from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging
@dataclass
class DeleteEpisodesConfig:
type: str = "delete_episodes"
episode_indices: list[int] | None = None
@dataclass
class SplitConfig:
type: str = "split"
splits: dict[str, float | list[int]] | None = None
@dataclass
class MergeConfig:
type: str = "merge"
repo_ids: list[str] | None = None
@dataclass
class RemoveFeatureConfig:
type: str = "remove_feature"
feature_names: list[str] | None = None
@dataclass
class EditDatasetConfig:
repo_id: str
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
root: str | None = None
new_repo_id: str | None = None
push_to_hub: bool = False
def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
if new_repo_id:
output_repo_id = new_repo_id
output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
else:
output_repo_id = repo_id
dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
old_path = Path(str(dataset_path) + "_old")
if dataset_path.exists():
if old_path.exists():
shutil.rmtree(old_path)
shutil.move(str(dataset_path), str(old_path))
output_dir = dataset_path
return output_repo_id, output_dir
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, DeleteEpisodesConfig):
raise ValueError("Operation config must be DeleteEpisodesConfig")
if not cfg.operation.episode_indices:
raise ValueError("episode_indices must be specified for delete_episodes operation")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
new_dataset = delete_episodes(
dataset,
episode_indices=cfg.operation.episode_indices,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_split(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, SplitConfig):
raise ValueError("Operation config must be SplitConfig")
if not cfg.operation.splits:
raise ValueError(
"splits dict must be specified with split names as keys and fractions/episode lists as values"
)
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
split_datasets = split_dataset(dataset, splits=cfg.operation.splits)
for split_name, split_ds in split_datasets.items():
split_repo_id = f"{cfg.repo_id}_{split_name}"
logging.info(
f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
)
if cfg.push_to_hub:
logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()
def handle_merge(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, MergeConfig):
raise ValueError("Operation config must be MergeConfig")
if not cfg.operation.repo_ids:
raise ValueError("repo_ids must be specified for merge operation")
if not cfg.repo_id:
raise ValueError("repo_id must be specified as the output repository for merged dataset")
logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]
output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id
logging.info(f"Merging datasets into {cfg.repo_id}")
merged_dataset = merge_datasets(
datasets,
output_repo_id=cfg.repo_id,
output_dir=output_dir,
)
logging.info(f"Merged dataset saved to {output_dir}")
logging.info(
f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}"
)
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {cfg.repo_id}")
LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()
def handle_remove_feature(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, RemoveFeatureConfig):
raise ValueError("Operation config must be RemoveFeatureConfig")
if not cfg.operation.feature_names:
raise ValueError("feature_names must be specified for remove_feature operation")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
new_dataset = remove_feature(
dataset,
feature_names=cfg.operation.feature_names,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
operation_type = cfg.operation.type
if operation_type == "delete_episodes":
handle_delete_episodes(cfg)
elif operation_type == "split":
handle_split(cfg)
elif operation_type == "merge":
handle_merge(cfg)
elif operation_type == "remove_feature":
handle_remove_feature(cfg)
else:
raise ValueError(
f"Unknown operation type: {operation_type}\n"
f"Available operations: delete_episodes, split, merge, remove_feature"
)
def main() -> None:
init_logging()
edit_dataset()
if __name__ == "__main__":
main()

View File

@@ -27,6 +27,7 @@ from statistics import mean
import numpy as np
import torch
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
def inside_slurm():
@@ -247,6 +248,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
return days, hours, minutes, seconds
class SuppressProgressBars:
"""
Context manager to suppress progress bars.
Example
--------
```python
with SuppressProgressBars():
# Code that would normally show progress bars
```
"""
def __enter__(self):
disable_progress_bar()
def __exit__(self, exc_type, exc_val, exc_tb):
enable_progress_bar()
class TimerManager:
"""
Lightweight utility to measure elapsed time.

View File

@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
pass
def assert_video_timestamps_within_bounds(aggr_ds):
"""Test that all video timestamps are within valid bounds for their respective video files.
This catches bugs where timestamps point to frames beyond the actual video length,
which would cause "Invalid frame index" errors during data loading.
"""
try:
from torchcodec.decoders import VideoDecoder
except ImportError:
return
for ep_idx in range(aggr_ds.num_episodes):
ep = aggr_ds.meta.episodes[ep_idx]
for vid_key in aggr_ds.meta.video_keys:
from_ts = ep[f"videos/{vid_key}/from_timestamp"]
to_ts = ep[f"videos/{vid_key}/to_timestamp"]
video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
continue
from_frame_idx = round(from_ts * aggr_ds.fps)
to_frame_idx = round(to_ts * aggr_ds.fps)
try:
decoder = VideoDecoder(str(video_path))
num_frames = len(decoder)
# Verify timestamps don't exceed video bounds
assert from_frame_idx >= 0, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
)
assert from_frame_idx < num_frames, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
)
assert to_frame_idx <= num_frames, (
f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
)
assert from_frame_idx < to_frame_idx, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
)
except Exception as e:
raise AssertionError(
f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
) from e
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters."""
ds_0_num_frames = 400
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
if video_dir.exists():
video_files = list(video_dir.rglob("*.mp4"))
assert len(video_files) > 1, "Small file size limits should create multiple video files"
def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
"""Regression test for video timestamp bug when merging datasets.
This test specifically checks that video timestamps are correctly calculated
and accumulated when merging multiple datasets.
"""
datasets = []
for i in range(3):
ds = lerobot_dataset_factory(
root=tmp_path / f"regression_{i}",
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
total_episodes=2,
total_frames=100,
)
datasets.append(ds)
aggregate_datasets(
repo_ids=[ds.repo_id for ds in datasets],
roots=[ds.root for ds in datasets],
aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
aggr_root=tmp_path / "regression_aggr",
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
assert_video_timestamps_within_bounds(aggr_ds)
for i in range(len(aggr_ds)):
item = aggr_ds[i]
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"

View File

@@ -0,0 +1,891 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset tools utilities."""
from unittest.mock import patch
import numpy as np
import pytest
import torch
from lerobot.datasets.dataset_tools import (
add_feature,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
"""Create a sample dataset for testing."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset",
features=features,
)
for ep_idx in range(5):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset.add_frame(frame)
dataset.save_episode()
return dataset
def test_delete_single_episode(sample_dataset, tmp_path):
"""Test deleting a single episode."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[2],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 4
assert new_dataset.meta.total_frames == 40
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2, 3}
assert len(new_dataset) == 40
def test_delete_multiple_episodes(sample_dataset, tmp_path):
"""Test deleting multiple episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[1, 3],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 3
assert new_dataset.meta.total_frames == 30
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2}
def test_delete_invalid_episodes(sample_dataset, tmp_path):
"""Test error handling for invalid episode indices."""
with pytest.raises(ValueError, match="Invalid episode indices"):
delete_episodes(
sample_dataset,
episode_indices=[10, 20],
output_dir=tmp_path / "filtered",
)
def test_delete_all_episodes(sample_dataset, tmp_path):
"""Test error when trying to delete all episodes."""
with pytest.raises(ValueError, match="Cannot delete all episodes"):
delete_episodes(
sample_dataset,
episode_indices=list(range(5)),
output_dir=tmp_path / "filtered",
)
def test_delete_empty_list(sample_dataset, tmp_path):
"""Test error when no episodes specified."""
with pytest.raises(ValueError, match="No episodes to delete"):
delete_episodes(
sample_dataset,
episode_indices=[],
output_dir=tmp_path / "filtered",
)
def test_split_by_episodes(sample_dataset, tmp_path):
"""Test splitting dataset by specific episode indices."""
splits = {
"train": [0, 1, 2],
"val": [3, 4],
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
if "train" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_train")
elif "val" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_val")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert set(result.keys()) == {"train", "val"}
assert result["train"].meta.total_episodes == 3
assert result["train"].meta.total_frames == 30
assert result["val"].meta.total_episodes == 2
assert result["val"].meta.total_frames == 20
train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
assert train_episodes == {0, 1, 2}
val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
assert val_episodes == {0, 1}
def test_split_by_fractions(sample_dataset, tmp_path):
"""Test splitting dataset by fractions."""
splits = {
"train": 0.6,
"val": 0.4,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert result["train"].meta.total_episodes == 3
assert result["val"].meta.total_episodes == 2
def test_split_overlapping_episodes(sample_dataset, tmp_path):
"""Test error when episodes appear in multiple splits."""
splits = {
"train": [0, 1, 2],
"val": [2, 3, 4],
}
with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_invalid_fractions(sample_dataset, tmp_path):
"""Test error when fractions sum to more than 1."""
splits = {
"train": 0.7,
"val": 0.5,
}
with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_empty(sample_dataset, tmp_path):
"""Test error with empty splits."""
with pytest.raises(ValueError, match="No splits provided"):
split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test merging two datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset2 = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset2",
features=features,
)
for ep_idx in range(3):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset2.add_frame(frame)
dataset2.save_episode()
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
[sample_dataset, dataset2],
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.total_episodes == 8 # 5 + 3
assert merged.meta.total_frames == 80 # 50 + 30
episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
assert episode_indices == list(range(8))
def test_merge_empty_list(tmp_path):
"""Test error when merging empty list."""
with pytest.raises(ValueError, match="No datasets to merge"):
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
def test_add_feature_with_values(sample_dataset, tmp_path):
"""Test adding a feature with pre-computed values."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert "reward" in new_dataset.meta.features
assert new_dataset.meta.features["reward"] == feature_info
assert len(new_dataset) == num_frames
sample_item = new_dataset[0]
assert "reward" in sample_item
assert isinstance(sample_item["reward"], torch.Tensor)
def test_add_feature_with_callable(sample_dataset, tmp_path):
"""Test adding a feature with a callable."""
def compute_reward(frame_dict, episode_idx, frame_idx):
return float(episode_idx * 10 + frame_idx)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=compute_reward,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert "reward" in new_dataset.meta.features
items = [new_dataset[i] for i in range(10)]
first_episode_items = [item for item in items if item["episode_index"] == 0]
assert len(first_episode_items) == 10
first_frame = first_episode_items[0]
assert first_frame["frame_index"] == 0
assert float(first_frame["reward"]) == 0.0
def test_add_existing_feature(sample_dataset, tmp_path):
"""Test error when adding an existing feature."""
feature_info = {"dtype": "float32", "shape": (1,)}
with pytest.raises(ValueError, match="Feature 'action' already exists"):
add_feature(
sample_dataset,
feature_name="action",
feature_values=np.zeros(50),
feature_info=feature_info,
output_dir=tmp_path / "modified",
)
def test_add_feature_invalid_info(sample_dataset, tmp_path):
"""Test error with invalid feature info."""
with pytest.raises(ValueError, match="feature_info must contain keys"):
add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.zeros(50),
feature_info={"dtype": "float32"},
output_dir=tmp_path / "modified",
)
def test_remove_single_feature(sample_dataset, tmp_path):
"""Test removing a single feature."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
dataset_without_reward = remove_feature(
dataset_with_reward,
feature_names="reward",
output_dir=tmp_path / "without_reward",
)
assert "reward" not in dataset_without_reward.meta.features
sample_item = dataset_without_reward[0]
assert "reward" not in sample_item
def test_remove_multiple_features(sample_dataset, tmp_path):
"""Test removing multiple features at once."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = sample_dataset
for feature_name in ["reward", "success"]:
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
dataset = add_feature(
dataset,
feature_name=feature_name,
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / f"with_{feature_name}",
)
dataset_clean = remove_feature(
dataset,
feature_names=["reward", "success"],
output_dir=tmp_path / "clean",
)
assert "reward" not in dataset_clean.meta.features
assert "success" not in dataset_clean.meta.features
def test_remove_nonexistent_feature(sample_dataset, tmp_path):
"""Test error when removing non-existent feature."""
with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
remove_feature(
sample_dataset,
feature_names="nonexistent",
output_dir=tmp_path / "modified",
)
def test_remove_required_feature(sample_dataset, tmp_path):
"""Test error when trying to remove required features."""
with pytest.raises(ValueError, match="Cannot remove required features"):
remove_feature(
sample_dataset,
feature_names="timestamp",
output_dir=tmp_path / "modified",
)
def test_remove_camera_feature(sample_dataset, tmp_path):
"""Test removing a camera feature."""
camera_keys = sample_dataset.meta.camera_keys
if not camera_keys:
pytest.skip("No camera keys in dataset")
camera_to_remove = camera_keys[0]
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "without_camera")
dataset_without_camera = remove_feature(
sample_dataset,
feature_names=camera_to_remove,
output_dir=tmp_path / "without_camera",
)
assert camera_to_remove not in dataset_without_camera.meta.features
assert camera_to_remove not in dataset_without_camera.meta.camera_keys
sample_item = dataset_without_camera[0]
assert camera_to_remove not in sample_item
def test_complex_workflow_integration(sample_dataset, tmp_path):
"""Test a complex workflow combining multiple operations."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info={"dtype": "float32", "shape": (1,), "names": None},
output_dir=tmp_path / "step1",
)
dataset = delete_episodes(
dataset,
episode_indices=[2],
output_dir=tmp_path / "step2",
)
splits = split_dataset(
dataset,
splits={"train": 0.75, "val": 0.25},
output_dir=tmp_path / "step3",
)
merged = merge_datasets(
list(splits.values()),
output_repo_id="final_dataset",
output_dir=tmp_path / "step4",
)
assert merged.meta.total_episodes == 4
assert merged.meta.total_frames == 40
assert "reward" in merged.meta.features
assert len(merged) == 40
sample_item = merged[0]
assert "reward" in sample_item
def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
"""Test that deleting episodes preserves statistics correctly."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[2],
output_dir=output_dir,
)
assert new_dataset.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in new_dataset.meta.stats
assert "mean" in new_dataset.meta.stats[feature]
assert "std" in new_dataset.meta.stats[feature]
def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
"""Test that tasks are preserved correctly after deletion."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[0],
output_dir=output_dir,
)
assert new_dataset.meta.tasks is not None
assert len(new_dataset.meta.tasks) == 2
tasks_in_dataset = {str(item["task"]) for item in new_dataset}
assert len(tasks_in_dataset) > 0
def test_split_three_ways(sample_dataset, tmp_path):
"""Test splitting dataset into three splits."""
splits = {
"train": 0.6,
"val": 0.2,
"test": 0.2,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert set(result.keys()) == {"train", "val", "test"}
assert result["train"].meta.total_episodes == 3
assert result["val"].meta.total_episodes == 1
assert result["test"].meta.total_episodes == 1
total_frames = sum(ds.meta.total_frames for ds in result.values())
assert total_frames == sample_dataset.meta.total_frames
def test_split_preserves_stats(sample_dataset, tmp_path):
"""Test that statistics are preserved when splitting."""
splits = {"train": [0, 1, 2], "val": [3, 4]}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
for split_ds in result.values():
assert split_ds.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in split_ds.meta.stats
assert "mean" in split_ds.meta.stats[feature]
assert "std" in split_ds.meta.stats[feature]
def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test merging three datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
datasets = [sample_dataset]
for i in range(2):
dataset = empty_lerobot_dataset_factory(
root=tmp_path / f"test_dataset{i + 2}",
features=features,
)
for ep_idx in range(2):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx}",
}
dataset.add_frame(frame)
dataset.save_episode()
datasets.append(dataset)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
datasets,
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.total_episodes == 9
assert merged.meta.total_frames == 90
def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test that statistics are computed for merged datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset2 = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset2",
features=features,
)
for ep_idx in range(3):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset2.add_frame(frame)
dataset2.save_episode()
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
[sample_dataset, dataset2],
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in merged.meta.stats
assert "mean" in merged.meta.stats[feature]
assert "std" in merged.meta.stats[feature]
def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
"""Test that adding a feature preserves existing stats."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert new_dataset.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in new_dataset.meta.stats
assert "mean" in new_dataset.meta.stats[feature]
assert "std" in new_dataset.meta.stats[feature]
def test_remove_feature_updates_stats(sample_dataset, tmp_path):
"""Test that removing a feature removes it from stats."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
dataset_without_reward = remove_feature(
dataset_with_reward,
feature_names="reward",
output_dir=tmp_path / "without_reward",
)
if dataset_without_reward.meta.stats:
assert "reward" not in dataset_without_reward.meta.stats
def test_delete_consecutive_episodes(sample_dataset, tmp_path):
"""Test deleting consecutive episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[1, 2, 3],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 2
assert new_dataset.meta.total_frames == 20
episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
assert episode_indices == [0, 1]
def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
"""Test deleting first and last episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[0, 4],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 3
assert new_dataset.meta.total_frames == 30
episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
assert episode_indices == [0, 1, 2]
def test_split_all_episodes_assigned(sample_dataset, tmp_path):
"""Test that all episodes can be explicitly assigned to splits."""
splits = {
"split1": [0, 1],
"split2": [2, 3],
"split3": [4],
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
total_episodes = sum(ds.meta.total_episodes for ds in result.values())
assert total_episodes == sample_dataset.meta.total_episodes