#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset tools utilities."""

from unittest.mock import patch

import numpy as np
import pytest
import torch

from lerobot.datasets.dataset_tools import (
    add_feature,
    delete_episodes,
    merge_datasets,
    remove_feature,
    split_dataset,
)

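# NOTE: the tests in this module patch `lerobot.datasets.lerobot_dataset.get_safe_version`
# and `lerobot.datasets.lerobot_dataset.snapshot_download` so that reloading an edited
# dataset resolves to a local directory under `tmp_path` instead of querying the
# Hugging Face Hub.
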
@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
    """Create a sample dataset for testing."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset",
        features=features,
    )

    for ep_idx in range(5):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset.add_frame(frame)
        dataset.save_episode()

    return dataset

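# The fixture above yields 5 episodes of 10 frames each (50 frames total), with the
# episode task alternating between "task_0" and "task_1".
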
def test_delete_single_episode(sample_dataset, tmp_path):
    """Test deleting a single episode."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 4
        assert new_dataset.meta.total_frames == 40

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2, 3}

        assert len(new_dataset) == 40

def test_delete_multiple_episodes(sample_dataset, tmp_path):
    """Test deleting multiple episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2}

def test_delete_invalid_episodes(sample_dataset, tmp_path):
    """Test error handling for invalid episode indices."""
    with pytest.raises(ValueError, match="Invalid episode indices"):
        delete_episodes(
            sample_dataset,
            episode_indices=[10, 20],
            output_dir=tmp_path / "filtered",
        )

def test_delete_all_episodes(sample_dataset, tmp_path):
    """Test error when trying to delete all episodes."""
    with pytest.raises(ValueError, match="Cannot delete all episodes"):
        delete_episodes(
            sample_dataset,
            episode_indices=list(range(5)),
            output_dir=tmp_path / "filtered",
        )

def test_delete_empty_list(sample_dataset, tmp_path):
    """Test error when no episodes specified."""
    with pytest.raises(ValueError, match="No episodes to delete"):
        delete_episodes(
            sample_dataset,
            episode_indices=[],
            output_dir=tmp_path / "filtered",
        )

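# The split tests below cover both explicit episode-index assignments and fractional
# split sizes passed to `split_dataset`.
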
def test_split_by_episodes(sample_dataset, tmp_path):
    """Test splitting dataset by specific episode indices."""
    splits = {
        "train": [0, 1, 2],
        "val": [3, 4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            if "train" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_train")
            elif "val" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_val")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val"}

        assert result["train"].meta.total_episodes == 3
        assert result["train"].meta.total_frames == 30

        assert result["val"].meta.total_episodes == 2
        assert result["val"].meta.total_frames == 20

        train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
        assert train_episodes == {0, 1, 2}

        val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
        assert val_episodes == {0, 1}

def test_split_by_fractions(sample_dataset, tmp_path):
    """Test splitting dataset by fractions."""
    splits = {
        "train": 0.6,
        "val": 0.4,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 2

def test_split_overlapping_episodes(sample_dataset, tmp_path):
    """Test error when episodes appear in multiple splits."""
    splits = {
        "train": [0, 1, 2],
        "val": [2, 3, 4],
    }

    with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)

def test_split_invalid_fractions(sample_dataset, tmp_path):
    """Test error when fractions sum to more than 1."""
    splits = {
        "train": 0.7,
        "val": 0.5,
    }

    with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)

def test_split_empty(sample_dataset, tmp_path):
    """Test error with empty splits."""
    with pytest.raises(ValueError, match="No splits provided"):
        split_dataset(sample_dataset, splits={}, output_dir=tmp_path)

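# The merge tests below build additional datasets with the same feature schema as the
# fixture and check that episode indices are renumbered consecutively in the merged output.
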
def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging two datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 8  # 5 + 3
        assert merged.meta.total_frames == 80  # 50 + 30

        episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
        assert episode_indices == list(range(8))

def test_merge_empty_list(tmp_path):
    """Test error when merging empty list."""
    with pytest.raises(ValueError, match="No datasets to merge"):
        merge_datasets([], output_repo_id="merged", output_dir=tmp_path)

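# The feature tests below exercise `add_feature` with both a precomputed value array and a
# per-frame callable, and `remove_feature` with both a single feature name and a list of names.
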
def test_add_feature_with_values(sample_dataset, tmp_path):
    """Test adding a feature with pre-computed values."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features
        assert new_dataset.meta.features["reward"] == feature_info

        assert len(new_dataset) == num_frames
        sample_item = new_dataset[0]
        assert "reward" in sample_item
        assert isinstance(sample_item["reward"], torch.Tensor)

def test_add_feature_with_callable(sample_dataset, tmp_path):
    """Test adding a feature with a callable."""

    def compute_reward(frame_dict, episode_idx, frame_idx):
        return float(episode_idx * 10 + frame_idx)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=compute_reward,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features

        items = [new_dataset[i] for i in range(10)]
        first_episode_items = [item for item in items if item["episode_index"] == 0]
        assert len(first_episode_items) == 10

        first_frame = first_episode_items[0]
        assert first_frame["frame_index"] == 0
        assert float(first_frame["reward"]) == 0.0

def test_add_existing_feature(sample_dataset, tmp_path):
    """Test error when adding an existing feature."""
    feature_info = {"dtype": "float32", "shape": (1,)}

    with pytest.raises(ValueError, match="Feature 'action' already exists"):
        add_feature(
            sample_dataset,
            feature_name="action",
            feature_values=np.zeros(50),
            feature_info=feature_info,
            output_dir=tmp_path / "modified",
        )

def test_add_feature_invalid_info(sample_dataset, tmp_path):
    """Test error with invalid feature info."""
    with pytest.raises(ValueError, match="feature_info must contain keys"):
        add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.zeros(50),
            feature_info={"dtype": "float32"},
            output_dir=tmp_path / "modified",
        )

def test_remove_single_feature(sample_dataset, tmp_path):
    """Test removing a single feature."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        dataset_without_reward = remove_feature(
            dataset_with_reward,
            feature_names="reward",
            output_dir=tmp_path / "without_reward",
        )

        assert "reward" not in dataset_without_reward.meta.features

        sample_item = dataset_without_reward[0]
        assert "reward" not in sample_item

def test_remove_multiple_features(sample_dataset, tmp_path):
    """Test removing multiple features at once."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = sample_dataset
        for feature_name in ["reward", "success"]:
            feature_info = {"dtype": "float32", "shape": (1,), "names": None}
            dataset = add_feature(
                dataset,
                feature_name=feature_name,
                feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
                feature_info=feature_info,
                output_dir=tmp_path / f"with_{feature_name}",
            )

        dataset_clean = remove_feature(
            dataset,
            feature_names=["reward", "success"],
            output_dir=tmp_path / "clean",
        )

        assert "reward" not in dataset_clean.meta.features
        assert "success" not in dataset_clean.meta.features

def test_remove_nonexistent_feature(sample_dataset, tmp_path):
    """Test error when removing non-existent feature."""
    with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
        remove_feature(
            sample_dataset,
            feature_names="nonexistent",
            output_dir=tmp_path / "modified",
        )

def test_remove_required_feature(sample_dataset, tmp_path):
    """Test error when trying to remove required features."""
    with pytest.raises(ValueError, match="Cannot remove required features"):
        remove_feature(
            sample_dataset,
            feature_names="timestamp",
            output_dir=tmp_path / "modified",
        )

def test_remove_camera_feature(sample_dataset, tmp_path):
    """Test removing a camera feature."""
    camera_keys = sample_dataset.meta.camera_keys
    if not camera_keys:
        pytest.skip("No camera keys in dataset")

    camera_to_remove = camera_keys[0]

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "without_camera")

        dataset_without_camera = remove_feature(
            sample_dataset,
            feature_names=camera_to_remove,
            output_dir=tmp_path / "without_camera",
        )

        assert camera_to_remove not in dataset_without_camera.meta.features
        assert camera_to_remove not in dataset_without_camera.meta.camera_keys

        sample_item = dataset_without_camera[0]
        assert camera_to_remove not in sample_item

def test_complex_workflow_integration(sample_dataset, tmp_path):
    """Test a complex workflow combining multiple operations."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info={"dtype": "float32", "shape": (1,), "names": None},
            output_dir=tmp_path / "step1",
        )

        dataset = delete_episodes(
            dataset,
            episode_indices=[2],
            output_dir=tmp_path / "step2",
        )

        splits = split_dataset(
            dataset,
            splits={"train": 0.75, "val": 0.25},
            output_dir=tmp_path / "step3",
        )

        merged = merge_datasets(
            list(splits.values()),
            output_repo_id="final_dataset",
            output_dir=tmp_path / "step4",
        )

        assert merged.meta.total_episodes == 4
        assert merged.meta.total_frames == 40
        assert "reward" in merged.meta.features

        assert len(merged) == 40
        sample_item = merged[0]
        assert "reward" in sample_item

def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
    """Test that deleting episodes preserves statistics correctly."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]

def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
    """Test that tasks are preserved correctly after deletion."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0],
            output_dir=output_dir,
        )

        assert new_dataset.meta.tasks is not None
        assert len(new_dataset.meta.tasks) == 2

        tasks_in_dataset = {str(item["task"]) for item in new_dataset}
        assert len(tasks_in_dataset) > 0

def test_split_three_ways(sample_dataset, tmp_path):
    """Test splitting dataset into three splits."""
    splits = {
        "train": 0.6,
        "val": 0.2,
        "test": 0.2,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val", "test"}
        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 1
        assert result["test"].meta.total_episodes == 1

        total_frames = sum(ds.meta.total_frames for ds in result.values())
        assert total_frames == sample_dataset.meta.total_frames

def test_split_preserves_stats(sample_dataset, tmp_path):
    """Test that statistics are preserved when splitting."""
    splits = {"train": [0, 1, 2], "val": [3, 4]}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        for split_ds in result.values():
            assert split_ds.meta.stats is not None
            for feature in ["action", "observation.state"]:
                assert feature in split_ds.meta.stats
                assert "mean" in split_ds.meta.stats[feature]
                assert "std" in split_ds.meta.stats[feature]

def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging three datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    datasets = [sample_dataset]

    for i in range(2):
        dataset = empty_lerobot_dataset_factory(
            root=tmp_path / f"test_dataset{i + 2}",
            features=features,
        )

        for ep_idx in range(2):
            for _ in range(10):
                frame = {
                    "action": np.random.randn(6).astype(np.float32),
                    "observation.state": np.random.randn(4).astype(np.float32),
                    "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                    "task": f"task_{ep_idx}",
                }
                dataset.add_frame(frame)
            dataset.save_episode()

        datasets.append(dataset)

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            datasets,
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 9
        assert merged.meta.total_frames == 90

def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test that statistics are computed for merged datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in merged.meta.stats
            assert "mean" in merged.meta.stats[feature]
            assert "std" in merged.meta.stats[feature]

def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
    """Test that adding a feature preserves existing stats."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]

def test_remove_feature_updates_stats(sample_dataset, tmp_path):
    """Test that removing a feature removes it from stats."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        dataset_without_reward = remove_feature(
            dataset_with_reward,
            feature_names="reward",
            output_dir=tmp_path / "without_reward",
        )

        if dataset_without_reward.meta.stats:
            assert "reward" not in dataset_without_reward.meta.stats

def test_delete_consecutive_episodes(sample_dataset, tmp_path):
    """Test deleting consecutive episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 2, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 2
        assert new_dataset.meta.total_frames == 20

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1]

def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
    """Test deleting first and last episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0, 4],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1, 2]

def test_split_all_episodes_assigned(sample_dataset, tmp_path):
    """Test that all episodes can be explicitly assigned to splits."""
    splits = {
        "split1": [0, 1],
        "split2": [2, 3],
        "split3": [4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        total_episodes = sum(ds.meta.total_episodes for ds in result.values())
        assert total_episodes == sample_dataset.meta.total_episodes