Dataset tools (#2100)

* feat(dataset-tools): add dataset utilities and example script

- Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets.
- Added an example script demonstrating the usage of these utilities.
- Implemented comprehensive tests for all new functionalities to ensure reliability and correctness.

* style fixes

* move example to dataset dir

* missing lisence

* fixes mostly path

* clean comments

* move tests to functions instead of class based

* - fix video editting, decode, delete frames and rencode video
- copy unchanged video and parquet files to avoid recreating the entire dataset

* Fortify tooling tests

* Fix type issue resulting from saving numpy arrays with shape 3,1,1

* added lerobot_edit_dataset

* - revert changes in examples
- remove hardcoded split names

* update comment

* fix comment
add lerobot-edit-dataset shortcut

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* style nit after copilot review

* fix: bug in dataset root when editing the dataset in place (without setting new_repo_id

* Fix bug in aggregate.py when accumelating video timestamps; add tests to fortify aggregate videos

* Added missing output repo id

* migrate delete episode to using pyav instead of decoding, writing frames to disk and encoding again.
Co-authored-by: Caroline Pascal <caroline8.pascal@gmail.com>

* added modified suffix in case repo_id is not set in delete_episode

* adding docs for dataset tools

* bump av version and add back time_base assignment

* linter

* modified push_to_hub logic in lerobot_edit_dataset

* fix(progress bar): fixing the progress bar issue in dataset tools

* chore(concatenate): removing no longer needed concatenate_datasets usage

* fix(file sizes forwarding): forwarding files and chunk sizes in metadata info when splitting and aggregating datasets

* style fix

* refactor(aggregate): Fix video indexing and timestamp bugs in dataset merging

There were three critical bugs in aggregate.py that prevented correct dataset merging:

1. Video file indices: Changed from += to = assignment to correctly reference
   merged video files

2. Video timestamps: Implemented per-source-file offset tracking to maintain
   continuous timestamps when merging split datasets (was causing non-monotonic
   timestamp warnings)

3. File rotation offsets: Store timestamp offsets after rotation decision to
   prevent out-of-bounds frame access (was causing "Invalid frame index" errors
   with small file size limits)

Changes:
- Updated update_meta_data() to apply per-source-file timestamp offsets
- Updated aggregate_videos() to track offsets correctly during file rotation
- Added get_video_duration_in_s import for duration calculation

* Improved docs for split dataset and added a check for the possible case that the split size results in zero episodes

* chore(docs): update merge documentation details

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>

---------

Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
Co-authored-by: Jack Vial <vialjack@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
Michel Aractingi
2025-10-10 12:32:07 +02:00
committed by GitHub
parent 656fc0f059
commit b8f7e401d4
13 changed files with 2593 additions and 30 deletions

View File

@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
pass
def assert_video_timestamps_within_bounds(aggr_ds):
"""Test that all video timestamps are within valid bounds for their respective video files.
This catches bugs where timestamps point to frames beyond the actual video length,
which would cause "Invalid frame index" errors during data loading.
"""
try:
from torchcodec.decoders import VideoDecoder
except ImportError:
return
for ep_idx in range(aggr_ds.num_episodes):
ep = aggr_ds.meta.episodes[ep_idx]
for vid_key in aggr_ds.meta.video_keys:
from_ts = ep[f"videos/{vid_key}/from_timestamp"]
to_ts = ep[f"videos/{vid_key}/to_timestamp"]
video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
continue
from_frame_idx = round(from_ts * aggr_ds.fps)
to_frame_idx = round(to_ts * aggr_ds.fps)
try:
decoder = VideoDecoder(str(video_path))
num_frames = len(decoder)
# Verify timestamps don't exceed video bounds
assert from_frame_idx >= 0, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
)
assert from_frame_idx < num_frames, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
)
assert to_frame_idx <= num_frames, (
f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
)
assert from_frame_idx < to_frame_idx, (
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
)
except Exception as e:
raise AssertionError(
f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
) from e
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
"""Test basic aggregation functionality with standard parameters."""
ds_0_num_frames = 400
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
assert_video_timestamps_within_bounds(aggr_ds)
assert_dataset_iteration_works(aggr_ds)
# Check that multiple files were actually created due to small size limits
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
if video_dir.exists():
video_files = list(video_dir.rglob("*.mp4"))
assert len(video_files) > 1, "Small file size limits should create multiple video files"
def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
"""Regression test for video timestamp bug when merging datasets.
This test specifically checks that video timestamps are correctly calculated
and accumulated when merging multiple datasets.
"""
datasets = []
for i in range(3):
ds = lerobot_dataset_factory(
root=tmp_path / f"regression_{i}",
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
total_episodes=2,
total_frames=100,
)
datasets.append(ds)
aggregate_datasets(
repo_ids=[ds.repo_id for ds in datasets],
roots=[ds.root for ds in datasets],
aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
aggr_root=tmp_path / "regression_aggr",
)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
assert_video_timestamps_within_bounds(aggr_ds)
for i in range(len(aggr_ds)):
item = aggr_ds[i]
for key in aggr_ds.meta.video_keys:
assert key in item, f"Video key {key} missing from item {i}"
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"

View File

@@ -0,0 +1,891 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset tools utilities."""
from unittest.mock import patch
import numpy as np
import pytest
import torch
from lerobot.datasets.dataset_tools import (
add_feature,
delete_episodes,
merge_datasets,
remove_feature,
split_dataset,
)
@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
"""Create a sample dataset for testing."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset",
features=features,
)
for ep_idx in range(5):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset.add_frame(frame)
dataset.save_episode()
return dataset
def test_delete_single_episode(sample_dataset, tmp_path):
"""Test deleting a single episode."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[2],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 4
assert new_dataset.meta.total_frames == 40
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2, 3}
assert len(new_dataset) == 40
def test_delete_multiple_episodes(sample_dataset, tmp_path):
"""Test deleting multiple episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[1, 3],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 3
assert new_dataset.meta.total_frames == 30
episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
assert episode_indices == {0, 1, 2}
def test_delete_invalid_episodes(sample_dataset, tmp_path):
"""Test error handling for invalid episode indices."""
with pytest.raises(ValueError, match="Invalid episode indices"):
delete_episodes(
sample_dataset,
episode_indices=[10, 20],
output_dir=tmp_path / "filtered",
)
def test_delete_all_episodes(sample_dataset, tmp_path):
"""Test error when trying to delete all episodes."""
with pytest.raises(ValueError, match="Cannot delete all episodes"):
delete_episodes(
sample_dataset,
episode_indices=list(range(5)),
output_dir=tmp_path / "filtered",
)
def test_delete_empty_list(sample_dataset, tmp_path):
"""Test error when no episodes specified."""
with pytest.raises(ValueError, match="No episodes to delete"):
delete_episodes(
sample_dataset,
episode_indices=[],
output_dir=tmp_path / "filtered",
)
def test_split_by_episodes(sample_dataset, tmp_path):
"""Test splitting dataset by specific episode indices."""
splits = {
"train": [0, 1, 2],
"val": [3, 4],
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
if "train" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_train")
elif "val" in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_val")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert set(result.keys()) == {"train", "val"}
assert result["train"].meta.total_episodes == 3
assert result["train"].meta.total_frames == 30
assert result["val"].meta.total_episodes == 2
assert result["val"].meta.total_frames == 20
train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
assert train_episodes == {0, 1, 2}
val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
assert val_episodes == {0, 1}
def test_split_by_fractions(sample_dataset, tmp_path):
"""Test splitting dataset by fractions."""
splits = {
"train": 0.6,
"val": 0.4,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert result["train"].meta.total_episodes == 3
assert result["val"].meta.total_episodes == 2
def test_split_overlapping_episodes(sample_dataset, tmp_path):
"""Test error when episodes appear in multiple splits."""
splits = {
"train": [0, 1, 2],
"val": [2, 3, 4],
}
with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_invalid_fractions(sample_dataset, tmp_path):
"""Test error when fractions sum to more than 1."""
splits = {
"train": 0.7,
"val": 0.5,
}
with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)
def test_split_empty(sample_dataset, tmp_path):
"""Test error with empty splits."""
with pytest.raises(ValueError, match="No splits provided"):
split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test merging two datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset2 = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset2",
features=features,
)
for ep_idx in range(3):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset2.add_frame(frame)
dataset2.save_episode()
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
[sample_dataset, dataset2],
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.total_episodes == 8 # 5 + 3
assert merged.meta.total_frames == 80 # 50 + 30
episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
assert episode_indices == list(range(8))
def test_merge_empty_list(tmp_path):
"""Test error when merging empty list."""
with pytest.raises(ValueError, match="No datasets to merge"):
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
def test_add_feature_with_values(sample_dataset, tmp_path):
"""Test adding a feature with pre-computed values."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert "reward" in new_dataset.meta.features
assert new_dataset.meta.features["reward"] == feature_info
assert len(new_dataset) == num_frames
sample_item = new_dataset[0]
assert "reward" in sample_item
assert isinstance(sample_item["reward"], torch.Tensor)
def test_add_feature_with_callable(sample_dataset, tmp_path):
"""Test adding a feature with a callable."""
def compute_reward(frame_dict, episode_idx, frame_idx):
return float(episode_idx * 10 + frame_idx)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=compute_reward,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert "reward" in new_dataset.meta.features
items = [new_dataset[i] for i in range(10)]
first_episode_items = [item for item in items if item["episode_index"] == 0]
assert len(first_episode_items) == 10
first_frame = first_episode_items[0]
assert first_frame["frame_index"] == 0
assert float(first_frame["reward"]) == 0.0
def test_add_existing_feature(sample_dataset, tmp_path):
"""Test error when adding an existing feature."""
feature_info = {"dtype": "float32", "shape": (1,)}
with pytest.raises(ValueError, match="Feature 'action' already exists"):
add_feature(
sample_dataset,
feature_name="action",
feature_values=np.zeros(50),
feature_info=feature_info,
output_dir=tmp_path / "modified",
)
def test_add_feature_invalid_info(sample_dataset, tmp_path):
"""Test error with invalid feature info."""
with pytest.raises(ValueError, match="feature_info must contain keys"):
add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.zeros(50),
feature_info={"dtype": "float32"},
output_dir=tmp_path / "modified",
)
def test_remove_single_feature(sample_dataset, tmp_path):
"""Test removing a single feature."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
dataset_without_reward = remove_feature(
dataset_with_reward,
feature_names="reward",
output_dir=tmp_path / "without_reward",
)
assert "reward" not in dataset_without_reward.meta.features
sample_item = dataset_without_reward[0]
assert "reward" not in sample_item
def test_remove_multiple_features(sample_dataset, tmp_path):
"""Test removing multiple features at once."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = sample_dataset
for feature_name in ["reward", "success"]:
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
dataset = add_feature(
dataset,
feature_name=feature_name,
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / f"with_{feature_name}",
)
dataset_clean = remove_feature(
dataset,
feature_names=["reward", "success"],
output_dir=tmp_path / "clean",
)
assert "reward" not in dataset_clean.meta.features
assert "success" not in dataset_clean.meta.features
def test_remove_nonexistent_feature(sample_dataset, tmp_path):
"""Test error when removing non-existent feature."""
with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
remove_feature(
sample_dataset,
feature_names="nonexistent",
output_dir=tmp_path / "modified",
)
def test_remove_required_feature(sample_dataset, tmp_path):
"""Test error when trying to remove required features."""
with pytest.raises(ValueError, match="Cannot remove required features"):
remove_feature(
sample_dataset,
feature_names="timestamp",
output_dir=tmp_path / "modified",
)
def test_remove_camera_feature(sample_dataset, tmp_path):
"""Test removing a camera feature."""
camera_keys = sample_dataset.meta.camera_keys
if not camera_keys:
pytest.skip("No camera keys in dataset")
camera_to_remove = camera_keys[0]
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "without_camera")
dataset_without_camera = remove_feature(
sample_dataset,
feature_names=camera_to_remove,
output_dir=tmp_path / "without_camera",
)
assert camera_to_remove not in dataset_without_camera.meta.features
assert camera_to_remove not in dataset_without_camera.meta.camera_keys
sample_item = dataset_without_camera[0]
assert camera_to_remove not in sample_item
def test_complex_workflow_integration(sample_dataset, tmp_path):
"""Test a complex workflow combining multiple operations."""
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info={"dtype": "float32", "shape": (1,), "names": None},
output_dir=tmp_path / "step1",
)
dataset = delete_episodes(
dataset,
episode_indices=[2],
output_dir=tmp_path / "step2",
)
splits = split_dataset(
dataset,
splits={"train": 0.75, "val": 0.25},
output_dir=tmp_path / "step3",
)
merged = merge_datasets(
list(splits.values()),
output_repo_id="final_dataset",
output_dir=tmp_path / "step4",
)
assert merged.meta.total_episodes == 4
assert merged.meta.total_frames == 40
assert "reward" in merged.meta.features
assert len(merged) == 40
sample_item = merged[0]
assert "reward" in sample_item
def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
"""Test that deleting episodes preserves statistics correctly."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[2],
output_dir=output_dir,
)
assert new_dataset.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in new_dataset.meta.stats
assert "mean" in new_dataset.meta.stats[feature]
assert "std" in new_dataset.meta.stats[feature]
def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
"""Test that tasks are preserved correctly after deletion."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[0],
output_dir=output_dir,
)
assert new_dataset.meta.tasks is not None
assert len(new_dataset.meta.tasks) == 2
tasks_in_dataset = {str(item["task"]) for item in new_dataset}
assert len(tasks_in_dataset) > 0
def test_split_three_ways(sample_dataset, tmp_path):
"""Test splitting dataset into three splits."""
splits = {
"train": 0.6,
"val": 0.2,
"test": 0.2,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
assert set(result.keys()) == {"train", "val", "test"}
assert result["train"].meta.total_episodes == 3
assert result["val"].meta.total_episodes == 1
assert result["test"].meta.total_episodes == 1
total_frames = sum(ds.meta.total_frames for ds in result.values())
assert total_frames == sample_dataset.meta.total_frames
def test_split_preserves_stats(sample_dataset, tmp_path):
"""Test that statistics are preserved when splitting."""
splits = {"train": [0, 1, 2], "val": [3, 4]}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
for split_ds in result.values():
assert split_ds.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in split_ds.meta.stats
assert "mean" in split_ds.meta.stats[feature]
assert "std" in split_ds.meta.stats[feature]
def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test merging three datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
datasets = [sample_dataset]
for i in range(2):
dataset = empty_lerobot_dataset_factory(
root=tmp_path / f"test_dataset{i + 2}",
features=features,
)
for ep_idx in range(2):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx}",
}
dataset.add_frame(frame)
dataset.save_episode()
datasets.append(dataset)
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
datasets,
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.total_episodes == 9
assert merged.meta.total_frames == 90
def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
"""Test that statistics are computed for merged datasets."""
features = {
"action": {"dtype": "float32", "shape": (6,), "names": None},
"observation.state": {"dtype": "float32", "shape": (4,), "names": None},
"observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
}
dataset2 = empty_lerobot_dataset_factory(
root=tmp_path / "test_dataset2",
features=features,
)
for ep_idx in range(3):
for _ in range(10):
frame = {
"action": np.random.randn(6).astype(np.float32),
"observation.state": np.random.randn(4).astype(np.float32),
"observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
"task": f"task_{ep_idx % 2}",
}
dataset2.add_frame(frame)
dataset2.save_episode()
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")
merged = merge_datasets(
[sample_dataset, dataset2],
output_repo_id="merged_dataset",
output_dir=tmp_path / "merged_dataset",
)
assert merged.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in merged.meta.stats
assert "mean" in merged.meta.stats[feature]
assert "std" in merged.meta.stats[feature]
def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
"""Test that adding a feature preserves existing stats."""
num_frames = sample_dataset.meta.total_frames
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
feature_info = {
"dtype": "float32",
"shape": (1,),
"names": None,
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
new_dataset = add_feature(
sample_dataset,
feature_name="reward",
feature_values=reward_values,
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
assert new_dataset.meta.stats is not None
for feature in ["action", "observation.state"]:
assert feature in new_dataset.meta.stats
assert "mean" in new_dataset.meta.stats[feature]
assert "std" in new_dataset.meta.stats[feature]
def test_remove_feature_updates_stats(sample_dataset, tmp_path):
"""Test that removing a feature removes it from stats."""
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
dataset_with_reward = add_feature(
sample_dataset,
feature_name="reward",
feature_values=np.random.randn(50, 1).astype(np.float32),
feature_info=feature_info,
output_dir=tmp_path / "with_reward",
)
dataset_without_reward = remove_feature(
dataset_with_reward,
feature_names="reward",
output_dir=tmp_path / "without_reward",
)
if dataset_without_reward.meta.stats:
assert "reward" not in dataset_without_reward.meta.stats
def test_delete_consecutive_episodes(sample_dataset, tmp_path):
"""Test deleting consecutive episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[1, 2, 3],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 2
assert new_dataset.meta.total_frames == 20
episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
assert episode_indices == [0, 1]
def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
"""Test deleting first and last episodes."""
output_dir = tmp_path / "filtered"
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(output_dir)
new_dataset = delete_episodes(
sample_dataset,
episode_indices=[0, 4],
output_dir=output_dir,
)
assert new_dataset.meta.total_episodes == 3
assert new_dataset.meta.total_frames == 30
episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
assert episode_indices == [0, 1, 2]
def test_split_all_episodes_assigned(sample_dataset, tmp_path):
"""Test that all episodes can be explicitly assigned to splits."""
splits = {
"split1": [0, 1],
"split2": [2, 3],
"split3": [4],
}
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
def mock_snapshot(repo_id, **kwargs):
for split_name in splits:
if split_name in repo_id:
return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
return str(kwargs.get("local_dir", tmp_path))
mock_snapshot_download.side_effect = mock_snapshot
result = split_dataset(
sample_dataset,
splits=splits,
output_dir=tmp_path,
)
total_episodes = sum(ds.meta.total_episodes for ds in result.values())
assert total_episodes == sample_dataset.meta.total_episodes