* Make `add_feature` accept multiple features at a time and rename it to `add_features`.
* New function: `modify_features`, which combines removing and adding features. This is useful when we want to add one feature and remove another: doing both in a single pass avoids copying and rebuilding the dataset multiple times.
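As a rough usage sketch (based only on the calls exercised in the tests below; the wrapper function and the `reward`/`success`/`done` feature names are made up for illustration):

```python
from pathlib import Path

import numpy as np

from lerobot.datasets.dataset_tools import add_features, modify_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset


def add_then_swap_features(dataset: LeRobotDataset, output_dir: Path) -> LeRobotDataset:
    """Hypothetical helper: add two features at once, then swap one out in a single rewrite."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
    num_frames = dataset.meta.total_frames

    # add_features maps each new feature name to (precomputed values or a callable, feature_info).
    with_extras = add_features(
        dataset,
        features={
            "reward": (np.zeros((num_frames, 1), dtype=np.float32), feature_info),
            "success": (np.zeros((num_frames, 1), dtype=np.float32), feature_info),
        },
        output_dir=output_dir / "with_extras",
    )

    # modify_features adds and removes in one pass, so the dataset is copied only once.
    return modify_features(
        with_extras,
        add_features={"done": (np.zeros((num_frames, 1), dtype=np.float32), feature_info)},
        remove_features="reward",
        output_dir=output_dir / "modified",
    )
```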
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dataset tools utilities."""

from unittest.mock import patch

import numpy as np
import pytest
import torch

from lerobot.datasets.dataset_tools import (
    add_features,
    delete_episodes,
    merge_datasets,
    modify_features,
    remove_feature,
    split_dataset,
)


@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
    """Create a sample dataset for testing."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset",
        features=features,
    )

    for ep_idx in range(5):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset.add_frame(frame)
        dataset.save_episode()

    dataset.finalize()
    return dataset


def test_delete_single_episode(sample_dataset, tmp_path):
    """Test deleting a single episode."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        # Patch hub helpers so the rewritten dataset is resolved from the local output_dir
        # instead of the Hugging Face Hub.
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 4
        assert new_dataset.meta.total_frames == 40

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2, 3}

        assert len(new_dataset) == 40


def test_delete_multiple_episodes(sample_dataset, tmp_path):
    """Test deleting multiple episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2}


def test_delete_invalid_episodes(sample_dataset, tmp_path):
    """Test error handling for invalid episode indices."""
    with pytest.raises(ValueError, match="Invalid episode indices"):
        delete_episodes(
            sample_dataset,
            episode_indices=[10, 20],
            output_dir=tmp_path / "filtered",
        )


def test_delete_all_episodes(sample_dataset, tmp_path):
    """Test error when trying to delete all episodes."""
    with pytest.raises(ValueError, match="Cannot delete all episodes"):
        delete_episodes(
            sample_dataset,
            episode_indices=list(range(5)),
            output_dir=tmp_path / "filtered",
        )


def test_delete_empty_list(sample_dataset, tmp_path):
    """Test error when no episodes specified."""
    with pytest.raises(ValueError, match="No episodes to delete"):
        delete_episodes(
            sample_dataset,
            episode_indices=[],
            output_dir=tmp_path / "filtered",
        )


def test_split_by_episodes(sample_dataset, tmp_path):
    """Test splitting dataset by specific episode indices."""
    splits = {
        "train": [0, 1, 2],
        "val": [3, 4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            if "train" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_train")
            elif "val" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_val")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val"}

        assert result["train"].meta.total_episodes == 3
        assert result["train"].meta.total_frames == 30

        assert result["val"].meta.total_episodes == 2
        assert result["val"].meta.total_frames == 20

        train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
        assert train_episodes == {0, 1, 2}

        val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
        assert val_episodes == {0, 1}


def test_split_by_fractions(sample_dataset, tmp_path):
    """Test splitting dataset by fractions."""
    splits = {
        "train": 0.6,
        "val": 0.4,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 2


def test_split_overlapping_episodes(sample_dataset, tmp_path):
    """Test error when episodes appear in multiple splits."""
    splits = {
        "train": [0, 1, 2],
        "val": [2, 3, 4],
    }

    with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)


def test_split_invalid_fractions(sample_dataset, tmp_path):
    """Test error when fractions sum to more than 1."""
    splits = {
        "train": 0.7,
        "val": 0.5,
    }

    with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)


def test_split_empty(sample_dataset, tmp_path):
    """Test error with empty splits."""
    with pytest.raises(ValueError, match="No splits provided"):
        split_dataset(sample_dataset, splits={}, output_dir=tmp_path)


def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging two datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()
    dataset2.finalize()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 8  # 5 + 3
        assert merged.meta.total_frames == 80  # 50 + 30

        episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
        assert episode_indices == list(range(8))


def test_merge_empty_list(tmp_path):
    """Test error when merging empty list."""
    with pytest.raises(ValueError, match="No datasets to merge"):
        merge_datasets([], output_repo_id="merged", output_dir=tmp_path)


def test_add_features_with_values(sample_dataset, tmp_path):
    """Test adding a feature with pre-computed values."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }
    features = {
        "reward": (reward_values, feature_info),
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features
        assert new_dataset.meta.features["reward"] == feature_info

        assert len(new_dataset) == num_frames
        sample_item = new_dataset[0]
        assert "reward" in sample_item
        assert isinstance(sample_item["reward"], torch.Tensor)


def test_add_features_with_callable(sample_dataset, tmp_path):
    """Test adding a feature with a callable."""

    def compute_reward(frame_dict, episode_idx, frame_idx):
        return float(episode_idx * 10 + frame_idx)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }
    features = {
        "reward": (compute_reward, feature_info),
    }
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features

        items = [new_dataset[i] for i in range(10)]
        first_episode_items = [item for item in items if item["episode_index"] == 0]
        assert len(first_episode_items) == 10

        first_frame = first_episode_items[0]
        assert first_frame["frame_index"] == 0
        assert float(first_frame["reward"]) == 0.0


def test_add_existing_feature(sample_dataset, tmp_path):
    """Test error when adding an existing feature."""
    feature_info = {"dtype": "float32", "shape": (1,)}
    features = {
        "action": (np.zeros(50), feature_info),
    }

    with pytest.raises(ValueError, match="Feature 'action' already exists"):
        add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "modified",
        )


def test_add_feature_invalid_info(sample_dataset, tmp_path):
    """Test error with invalid feature info."""
    with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
        add_features(
            dataset=sample_dataset,
            features={
                "reward": (np.zeros(50), {"dtype": "float32"}),
            },
            output_dir=tmp_path / "modified",
        )


def test_modify_features_add_and_remove(sample_dataset, tmp_path):
    """Test modifying features by adding and removing simultaneously."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "modified")

        # First add a feature we'll later remove
        dataset_with_reward = add_features(
            sample_dataset,
            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
            output_dir=tmp_path / "with_reward",
        )

        # Now use modify_features to add "success" and remove "reward" in one pass
        modified_dataset = modify_features(
            dataset_with_reward,
            add_features={
                "success": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            remove_features="reward",
            output_dir=tmp_path / "modified",
        )

        assert "success" in modified_dataset.meta.features
        assert "reward" not in modified_dataset.meta.features
        assert len(modified_dataset) == 50


def test_modify_features_only_add(sample_dataset, tmp_path):
    """Test that modify_features works with only add_features."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "modified")

        modified_dataset = modify_features(
            sample_dataset,
            add_features={
                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            output_dir=tmp_path / "modified",
        )

        assert "reward" in modified_dataset.meta.features
        assert len(modified_dataset) == 50


def test_modify_features_only_remove(sample_dataset, tmp_path):
    """Test that modify_features works with only remove_features."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_features(
            sample_dataset,
            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
            output_dir=tmp_path / "with_reward",
        )

        modified_dataset = modify_features(
            dataset_with_reward,
            remove_features="reward",
            output_dir=tmp_path / "modified",
        )

        assert "reward" not in modified_dataset.meta.features


def test_modify_features_no_changes(sample_dataset, tmp_path):
    """Test error when modify_features is called with no changes."""
    with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
        modify_features(
            sample_dataset,
            output_dir=tmp_path / "modified",
        )


def test_remove_single_feature(sample_dataset, tmp_path):
    """Test removing a single feature."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
    features = {
        "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
    }
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

        dataset_without_reward = remove_feature(
            dataset_with_reward,
            feature_names="reward",
            output_dir=tmp_path / "without_reward",
        )

        assert "reward" not in dataset_without_reward.meta.features

        sample_item = dataset_without_reward[0]
        assert "reward" not in sample_item


def test_remove_multiple_features(sample_dataset, tmp_path):
    """Test removing multiple features at once."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = sample_dataset
        features = {}
        for feature_name in ["reward", "success"]:
            feature_info = {"dtype": "float32", "shape": (1,), "names": None}
            features[feature_name] = (
                np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
                feature_info,
            )

        dataset_with_features = add_features(
            dataset, features=features, output_dir=tmp_path / "with_features"
        )
        dataset_clean = remove_feature(
            dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
        )

        assert "reward" not in dataset_clean.meta.features
        assert "success" not in dataset_clean.meta.features


def test_remove_nonexistent_feature(sample_dataset, tmp_path):
    """Test error when removing non-existent feature."""
    with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
        remove_feature(
            sample_dataset,
            feature_names="nonexistent",
            output_dir=tmp_path / "modified",
        )


def test_remove_required_feature(sample_dataset, tmp_path):
    """Test error when trying to remove required features."""
    with pytest.raises(ValueError, match="Cannot remove required features"):
        remove_feature(
            sample_dataset,
            feature_names="timestamp",
            output_dir=tmp_path / "modified",
        )


def test_remove_camera_feature(sample_dataset, tmp_path):
    """Test removing a camera feature."""
    camera_keys = sample_dataset.meta.camera_keys
    if not camera_keys:
        pytest.skip("No camera keys in dataset")

    camera_to_remove = camera_keys[0]

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "without_camera")

        dataset_without_camera = remove_feature(
            sample_dataset,
            feature_names=camera_to_remove,
            output_dir=tmp_path / "without_camera",
        )

        assert camera_to_remove not in dataset_without_camera.meta.features
        assert camera_to_remove not in dataset_without_camera.meta.camera_keys

        sample_item = dataset_without_camera[0]
        assert camera_to_remove not in sample_item


def test_complex_workflow_integration(sample_dataset, tmp_path):
    """Test a complex workflow combining multiple operations."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = add_features(
            sample_dataset,
            features={
                "reward": (
                    np.random.randn(50, 1).astype(np.float32),
                    {"dtype": "float32", "shape": (1,), "names": None},
                )
            },
            output_dir=tmp_path / "step1",
        )

        dataset = delete_episodes(
            dataset,
            episode_indices=[2],
            output_dir=tmp_path / "step2",
        )

        splits = split_dataset(
            dataset,
            splits={"train": 0.75, "val": 0.25},
            output_dir=tmp_path / "step3",
        )

        merged = merge_datasets(
            list(splits.values()),
            output_repo_id="final_dataset",
            output_dir=tmp_path / "step4",
        )

        assert merged.meta.total_episodes == 4
        assert merged.meta.total_frames == 40
        assert "reward" in merged.meta.features

        assert len(merged) == 40
        sample_item = merged[0]
        assert "reward" in sample_item


def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
    """Test that deleting episodes preserves statistics correctly."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]


def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
    """Test that tasks are preserved correctly after deletion."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0],
            output_dir=output_dir,
        )

        assert new_dataset.meta.tasks is not None
        assert len(new_dataset.meta.tasks) == 2

        tasks_in_dataset = {str(item["task"]) for item in new_dataset}
        assert len(tasks_in_dataset) > 0


def test_split_three_ways(sample_dataset, tmp_path):
    """Test splitting dataset into three splits."""
    splits = {
        "train": 0.6,
        "val": 0.2,
        "test": 0.2,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val", "test"}
        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 1
        assert result["test"].meta.total_episodes == 1

        total_frames = sum(ds.meta.total_frames for ds in result.values())
        assert total_frames == sample_dataset.meta.total_frames


def test_split_preserves_stats(sample_dataset, tmp_path):
    """Test that statistics are preserved when splitting."""
    splits = {"train": [0, 1, 2], "val": [3, 4]}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        for split_ds in result.values():
            assert split_ds.meta.stats is not None
            for feature in ["action", "observation.state"]:
                assert feature in split_ds.meta.stats
                assert "mean" in split_ds.meta.stats[feature]
                assert "std" in split_ds.meta.stats[feature]


def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging three datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    datasets = [sample_dataset]

    for i in range(2):
        dataset = empty_lerobot_dataset_factory(
            root=tmp_path / f"test_dataset{i + 2}",
            features=features,
        )

        for ep_idx in range(2):
            for _ in range(10):
                frame = {
                    "action": np.random.randn(6).astype(np.float32),
                    "observation.state": np.random.randn(4).astype(np.float32),
                    "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                    "task": f"task_{ep_idx}",
                }
                dataset.add_frame(frame)
            dataset.save_episode()
        dataset.finalize()

        datasets.append(dataset)

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            datasets,
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 9
        assert merged.meta.total_frames == 90


def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test that statistics are computed for merged datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()
    dataset2.finalize()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in merged.meta.stats
            assert "mean" in merged.meta.stats[feature]
            assert "std" in merged.meta.stats[feature]


def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
    """Test that adding a feature preserves existing stats."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }
    features = {
        "reward": (reward_values, feature_info),
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]


def test_remove_feature_updates_stats(sample_dataset, tmp_path):
    """Test that removing a feature removes it from stats."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_features(
            sample_dataset,
            features={
                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            output_dir=tmp_path / "with_reward",
        )

        dataset_without_reward = remove_feature(
            dataset_with_reward,
            feature_names="reward",
            output_dir=tmp_path / "without_reward",
        )

        if dataset_without_reward.meta.stats:
            assert "reward" not in dataset_without_reward.meta.stats


def test_delete_consecutive_episodes(sample_dataset, tmp_path):
    """Test deleting consecutive episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 2, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 2
        assert new_dataset.meta.total_frames == 20

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1]


def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
    """Test deleting first and last episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0, 4],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1, 2]


def test_split_all_episodes_assigned(sample_dataset, tmp_path):
    """Test that all episodes can be explicitly assigned to splits."""
    splits = {
        "split1": [0, 1],
        "split2": [2, 3],
        "split3": [4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        total_episodes = sum(ds.meta.total_episodes for ds in result.values())
        assert total_episodes == sample_dataset.meta.total_episodes


def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
    """Test that modifying features preserves chunk_idx and file_idx from source dataset."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))

        mock_snapshot_download.side_effect = mock_snapshot

        # First split the dataset to create a non-zero starting chunk/file structure
        splits = split_dataset(
            sample_dataset,
            splits={"train": [0, 1, 2], "val": [3, 4]},
            output_dir=tmp_path / "splits",
        )

        train_dataset = splits["train"]

        # Get original chunk/file indices from first episode
        if train_dataset.meta.episodes is None:
            from lerobot.datasets.utils import load_episodes

            train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
        original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
        original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]

        # Now add a feature to the split dataset
        modified_dataset = add_features(
            train_dataset,
            features={
                "reward": (
                    np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
                    feature_info,
                ),
            },
            output_dir=tmp_path / "modified",
        )

        # Check that chunk/file indices are preserved
        if modified_dataset.meta.episodes is None:
            from lerobot.datasets.utils import load_episodes

            modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
        new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
        new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]

        assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
        assert new_file_indices == original_file_indices, "File indices should be preserved"
        assert "reward" in modified_dataset.meta.features