Feat/expand add features (#2202)
* make add_feature take multiple features at a time and rename to add_features * - New function: modify_features that was a combination of remove features and add features. - This function is important for when we want to add a feature and remove another so we can do it in one time to avoid copying and creating the dataset multiple times
This commit is contained in:
@@ -22,9 +22,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_feature,
|
||||
add_features,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
@@ -292,7 +293,7 @@ def test_merge_empty_list(tmp_path):
|
||||
merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
|
||||
|
||||
|
||||
def test_add_feature_with_values(sample_dataset, tmp_path):
|
||||
def test_add_features_with_values(sample_dataset, tmp_path):
|
||||
"""Test adding a feature with pre-computed values."""
|
||||
num_frames = sample_dataset.meta.total_frames
|
||||
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
|
||||
@@ -302,6 +303,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
features = {
|
||||
"reward": (reward_values, feature_info),
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
@@ -310,11 +314,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
|
||||
new_dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=reward_values,
|
||||
feature_info=feature_info,
|
||||
new_dataset = add_features(
|
||||
dataset=sample_dataset,
|
||||
features=features,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
@@ -327,7 +329,7 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
|
||||
assert isinstance(sample_item["reward"], torch.Tensor)
|
||||
|
||||
|
||||
def test_add_feature_with_callable(sample_dataset, tmp_path):
|
||||
def test_add_features_with_callable(sample_dataset, tmp_path):
|
||||
"""Test adding a feature with a callable."""
|
||||
|
||||
def compute_reward(frame_dict, episode_idx, frame_idx):
|
||||
@@ -338,7 +340,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
features = {
|
||||
"reward": (compute_reward, feature_info),
|
||||
}
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
@@ -346,11 +350,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
|
||||
new_dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=compute_reward,
|
||||
feature_info=feature_info,
|
||||
new_dataset = add_features(
|
||||
dataset=sample_dataset,
|
||||
features=features,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
@@ -368,31 +370,88 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
|
||||
def test_add_existing_feature(sample_dataset, tmp_path):
|
||||
"""Test error when adding an existing feature."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,)}
|
||||
features = {
|
||||
"action": (np.zeros(50), feature_info),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="Feature 'action' already exists"):
|
||||
add_feature(
|
||||
sample_dataset,
|
||||
feature_name="action",
|
||||
feature_values=np.zeros(50),
|
||||
feature_info=feature_info,
|
||||
add_features(
|
||||
dataset=sample_dataset,
|
||||
features=features,
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
|
||||
def test_add_feature_invalid_info(sample_dataset, tmp_path):
|
||||
"""Test error with invalid feature info."""
|
||||
with pytest.raises(ValueError, match="feature_info must contain keys"):
|
||||
add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.zeros(50),
|
||||
feature_info={"dtype": "float32"},
|
||||
with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
|
||||
add_features(
|
||||
dataset=sample_dataset,
|
||||
features={
|
||||
"reward": (np.zeros(50), {"dtype": "float32"}),
|
||||
},
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
|
||||
def test_remove_single_feature(sample_dataset, tmp_path):
|
||||
"""Test removing a single feature."""
|
||||
def test_modify_features_add_and_remove(sample_dataset, tmp_path):
|
||||
"""Test modifying features by adding and removing simultaneously."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "modified")
|
||||
|
||||
# First add a feature we'll later remove
|
||||
dataset_with_reward = add_features(
|
||||
sample_dataset,
|
||||
features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
# Now use modify_features to add "success" and remove "reward" in one pass
|
||||
modified_dataset = modify_features(
|
||||
dataset_with_reward,
|
||||
add_features={
|
||||
"success": (np.random.randn(50, 1).astype(np.float32), feature_info),
|
||||
},
|
||||
remove_features="reward",
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
assert "success" in modified_dataset.meta.features
|
||||
assert "reward" not in modified_dataset.meta.features
|
||||
assert len(modified_dataset) == 50
|
||||
|
||||
|
||||
def test_modify_features_only_add(sample_dataset, tmp_path):
|
||||
"""Test that modify_features works with only add_features."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "modified")
|
||||
|
||||
modified_dataset = modify_features(
|
||||
sample_dataset,
|
||||
add_features={
|
||||
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
|
||||
},
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
assert len(modified_dataset) == 50
|
||||
|
||||
|
||||
def test_modify_features_only_remove(sample_dataset, tmp_path):
|
||||
"""Test that modify_features works with only remove_features."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
@@ -402,11 +461,46 @@ def test_remove_single_feature(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
dataset_with_reward = add_feature(
|
||||
dataset_with_reward = add_features(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||
feature_info=feature_info,
|
||||
features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
modified_dataset = modify_features(
|
||||
dataset_with_reward,
|
||||
remove_features="reward",
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
assert "reward" not in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_modify_features_no_changes(sample_dataset, tmp_path):
|
||||
"""Test error when modify_features is called with no changes."""
|
||||
with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
|
||||
modify_features(
|
||||
sample_dataset,
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
|
||||
def test_remove_single_feature(sample_dataset, tmp_path):
|
||||
"""Test removing a single feature."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features = {
|
||||
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
|
||||
}
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
dataset_with_reward = add_features(
|
||||
dataset=sample_dataset,
|
||||
features=features,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
@@ -432,20 +526,19 @@ def test_remove_multiple_features(sample_dataset, tmp_path):
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
dataset = sample_dataset
|
||||
features = {}
|
||||
for feature_name in ["reward", "success"]:
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
dataset = add_feature(
|
||||
dataset,
|
||||
feature_name=feature_name,
|
||||
feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
|
||||
feature_info=feature_info,
|
||||
output_dir=tmp_path / f"with_{feature_name}",
|
||||
features[feature_name] = (
|
||||
np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
|
||||
feature_info,
|
||||
)
|
||||
|
||||
dataset_with_features = add_features(
|
||||
dataset, features=features, output_dir=tmp_path / "with_features"
|
||||
)
|
||||
dataset_clean = remove_feature(
|
||||
dataset,
|
||||
feature_names=["reward", "success"],
|
||||
output_dir=tmp_path / "clean",
|
||||
dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
|
||||
)
|
||||
|
||||
assert "reward" not in dataset_clean.meta.features
|
||||
@@ -509,11 +602,14 @@ def test_complex_workflow_integration(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
dataset = add_feature(
|
||||
dataset = add_features(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||
feature_info={"dtype": "float32", "shape": (1,), "names": None},
|
||||
features={
|
||||
"reward": (
|
||||
np.random.randn(50, 1).astype(np.float32),
|
||||
{"dtype": "float32", "shape": (1,), "names": None},
|
||||
)
|
||||
},
|
||||
output_dir=tmp_path / "step1",
|
||||
)
|
||||
|
||||
@@ -753,7 +849,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f
|
||||
assert "std" in merged.meta.stats[feature]
|
||||
|
||||
|
||||
def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
|
||||
def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
|
||||
"""Test that adding a feature preserves existing stats."""
|
||||
num_frames = sample_dataset.meta.total_frames
|
||||
reward_values = np.random.randn(num_frames, 1).astype(np.float32)
|
||||
@@ -763,6 +859,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
features = {
|
||||
"reward": (reward_values, feature_info),
|
||||
}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
@@ -771,11 +870,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "with_reward")
|
||||
|
||||
new_dataset = add_feature(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=reward_values,
|
||||
feature_info=feature_info,
|
||||
new_dataset = add_features(
|
||||
dataset=sample_dataset,
|
||||
features=features,
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
@@ -797,11 +894,11 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||
|
||||
dataset_with_reward = add_feature(
|
||||
dataset_with_reward = add_features(
|
||||
sample_dataset,
|
||||
feature_name="reward",
|
||||
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||
feature_info=feature_info,
|
||||
features={
|
||||
"reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
|
||||
},
|
||||
output_dir=tmp_path / "with_reward",
|
||||
)
|
||||
|
||||
@@ -893,3 +990,60 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path):
|
||||
|
||||
total_episodes = sum(ds.meta.total_episodes for ds in result.values())
|
||||
assert total_episodes == sample_dataset.meta.total_episodes
|
||||
|
||||
|
||||
def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
"""Test that modifying features preserves chunk_idx and file_idx from source dataset."""
|
||||
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
|
||||
def mock_snapshot(repo_id, **kwargs):
|
||||
return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))
|
||||
|
||||
mock_snapshot_download.side_effect = mock_snapshot
|
||||
|
||||
# First split the dataset to create a non-zero starting chunk/file structure
|
||||
splits = split_dataset(
|
||||
sample_dataset,
|
||||
splits={"train": [0, 1, 2], "val": [3, 4]},
|
||||
output_dir=tmp_path / "splits",
|
||||
)
|
||||
|
||||
train_dataset = splits["train"]
|
||||
|
||||
# Get original chunk/file indices from first episode
|
||||
if train_dataset.meta.episodes is None:
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
|
||||
train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
|
||||
original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
|
||||
original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]
|
||||
|
||||
# Now add a feature to the split dataset
|
||||
modified_dataset = add_features(
|
||||
train_dataset,
|
||||
features={
|
||||
"reward": (
|
||||
np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
|
||||
feature_info,
|
||||
),
|
||||
},
|
||||
output_dir=tmp_path / "modified",
|
||||
)
|
||||
|
||||
# Check that chunk/file indices are preserved
|
||||
if modified_dataset.meta.episodes is None:
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
|
||||
modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
|
||||
new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
|
||||
new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]
|
||||
|
||||
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
|
||||
assert new_file_indices == original_file_indices, "File indices should be preserved"
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
Reference in New Issue
Block a user