Feat/expand add features (#2202)
* Make add_feature accept multiple features at a time and rename it to add_features.
* New function: modify_features, which combines remove_feature and add_features. This matters when we want to add one feature and remove another: doing both in a single pass avoids copying and rebuilding the dataset multiple times.
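In short, the new API looks like this (a minimal sketch based on the signatures added in this PR; the repo IDs, feature values, and the LeRobotDataset constructor call are illustrative placeholders):

    import numpy as np
    from lerobot.datasets.dataset_tools import add_features, modify_features
    from lerobot.datasets.lerobot_dataset import LeRobotDataset

    dataset = LeRobotDataset("lerobot/pusht")  # placeholder repo_id

    # Add several features while copying the dataset only once.
    reward = np.zeros(dataset.meta.total_frames, dtype=np.float32)
    info = {"dtype": "float32", "shape": (1,), "names": None}
    with_extras = add_features(
        dataset,
        features={
            "reward": (reward, info),
            # A callable receives (frame_dict, episode_index, frame_index) and returns the value.
            "success": (lambda frame, ep_idx, frame_idx: 0.0, info),
        },
        repo_id="lerobot/pusht_with_extras",  # placeholder
    )

    # Add and remove features in the same single pass.
    modified = modify_features(
        with_extras,
        add_features={"discount": (np.full(with_extras.meta.total_frames, 0.99, dtype=np.float32), info)},
        remove_features="reward",
        repo_id="lerobot/pusht_modified",  # placeholder
    )
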
@@ -30,9 +30,10 @@ Usage:
import numpy as np

from lerobot.datasets.dataset_tools import (
    add_feature,
    add_features,
    delete_episodes,
    merge_datasets,
    modify_features,
    remove_feature,
    split_dataset,
)
@@ -57,50 +58,56 @@ def main():
    print(f"Train split: {splits['train'].meta.total_episodes} episodes")
    print(f"Val split: {splits['val'].meta.total_episodes} episodes")

    print("\n3. Adding a reward feature...")
    print("\n3. Adding features...")

    reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
    dataset_with_reward = add_feature(
        dataset,
        feature_name="reward",
        feature_values=reward_values,
        feature_info={
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        repo_id="lerobot/pusht_with_reward",
    )

    def compute_success(row_dict, episode_index, frame_index):
        episode_length = 10
        return float(frame_index >= episode_length - 10)

    dataset_with_success = add_feature(
        dataset_with_reward,
        feature_name="success",
        feature_values=compute_success,
        feature_info={
            "dtype": "float32",
            "shape": (1,),
            "names": None,
    dataset_with_features = add_features(
        dataset,
        features={
            "reward": (
                reward_values,
                {"dtype": "float32", "shape": (1,), "names": None},
            ),
            "success": (
                compute_success,
                {"dtype": "float32", "shape": (1,), "names": None},
            ),
        },
        repo_id="lerobot/pusht_with_reward_and_success",
        repo_id="lerobot/pusht_with_features",
    )

    print(f"New features: {list(dataset_with_success.meta.features.keys())}")
    print(f"New features: {list(dataset_with_features.meta.features.keys())}")

    print("\n4. Removing the success feature...")
    dataset_cleaned = remove_feature(
        dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
        dataset_with_features, feature_names="success", repo_id="lerobot/pusht_cleaned"
    )
    print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")

    print("\n5. Merging train and val splits back together...")
    print("\n5. Using modify_features to add and remove features simultaneously...")
    dataset_modified = modify_features(
        dataset_with_features,
        add_features={
            "discount": (
                np.ones(dataset.meta.total_frames, dtype=np.float32) * 0.99,
                {"dtype": "float32", "shape": (1,), "names": None},
            ),
        },
        remove_features="reward",
        repo_id="lerobot/pusht_modified",
    )
    print(f"Modified features: {list(dataset_modified.meta.features.keys())}")

    print("\n6. Merging train and val splits back together...")
    merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
    print(f"Merged dataset: {merged.meta.total_episodes} episodes")

    print("\n6. Complex workflow example...")
    print("\n7. Complex workflow example...")

    if len(dataset.meta.camera_keys) > 1:
        camera_to_remove = dataset.meta.camera_keys[0]

@@ -28,8 +28,10 @@ import shutil
from collections.abc import Callable
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from tqdm import tqdm

@@ -43,7 +45,6 @@ from lerobot.datasets.utils import (
    DEFAULT_EPISODES_PATH,
    get_parquet_file_size_in_mb,
    load_episodes,
    to_parquet_with_hf_images,
    update_chunk_file_indices,
    write_info,
    write_stats,
@@ -268,39 +269,79 @@ def merge_datasets(
    return merged_dataset


def add_feature(
def modify_features(
    dataset: LeRobotDataset,
    feature_name: str,
    feature_values: np.ndarray | torch.Tensor | Callable,
    feature_info: dict,
    add_features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]] | None = None,
    remove_features: str | list[str] | None = None,
    output_dir: str | Path | None = None,
    repo_id: str | None = None,
) -> LeRobotDataset:
    """Add a new feature to a LeRobotDataset.
    """Modify a LeRobotDataset by adding and/or removing features in a single pass.

    This is the most efficient way to modify features, as it only copies the dataset once
    regardless of how many features are being added or removed.

    Args:
        dataset: The source LeRobotDataset.
        feature_name: Name of the new feature.
        feature_values: Either:
            - Array/tensor of shape (num_frames, ...) with values for each frame
            - Callable that takes (frame_dict, episode_index, frame_index) and returns feature value
        feature_info: Dictionary with feature metadata (dtype, shape, names).
        add_features: Optional dict mapping feature names to (feature_values, feature_info) tuples.
        remove_features: Optional feature name(s) to remove. Can be a single string or list.
        output_dir: Directory to save the new dataset. If None, uses default location.
        repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.

    Returns:
        New dataset with features modified.

    Example:
        new_dataset = modify_features(
            dataset,
            add_features={
                "reward": (reward_array, {"dtype": "float32", "shape": [1], "names": None}),
            },
            remove_features=["old_feature"],
            output_dir="./output",
        )
    """
    if feature_name in dataset.meta.features:
        raise ValueError(f"Feature '{feature_name}' already exists in dataset")
    if add_features is None and remove_features is None:
        raise ValueError("Must specify at least one of add_features or remove_features")

    remove_features_list: list[str] = []
    if remove_features is not None:
        remove_features_list = [remove_features] if isinstance(remove_features, str) else remove_features

    if add_features:
        required_keys = {"dtype", "shape"}
        for feature_name, (_, feature_info) in add_features.items():
            if feature_name in dataset.meta.features:
                raise ValueError(f"Feature '{feature_name}' already exists in dataset")

            if not required_keys.issubset(feature_info.keys()):
                raise ValueError(f"feature_info for '{feature_name}' must contain keys: {required_keys}")

    if remove_features_list:
        for name in remove_features_list:
            if name not in dataset.meta.features:
                raise ValueError(f"Feature '{name}' not found in dataset")

        required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
        if any(name in required_features for name in remove_features_list):
            raise ValueError(f"Cannot remove required features: {required_features}")

    if repo_id is None:
        repo_id = f"{dataset.repo_id}_modified"
    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id

    required_keys = {"dtype", "shape"}
    if not required_keys.issubset(feature_info.keys()):
        raise ValueError(f"feature_info must contain keys: {required_keys}")

    new_features = dataset.meta.features.copy()
    new_features[feature_name] = feature_info

    if remove_features_list:
        for name in remove_features_list:
            new_features.pop(name, None)

    if add_features:
        for feature_name, (_, feature_info) in add_features.items():
            new_features[feature_name] = feature_info

    video_keys_to_remove = [name for name in remove_features_list if name in dataset.meta.video_keys]
    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]

    new_meta = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
@@ -308,17 +349,18 @@ def add_feature(
        features=new_features,
        robot_type=dataset.meta.robot_type,
        root=output_dir,
        use_videos=len(dataset.meta.video_keys) > 0,
        use_videos=len(remaining_video_keys) > 0,
    )

    _copy_data_with_feature_changes(
        dataset=dataset,
        new_meta=new_meta,
        add_features={feature_name: (feature_values, feature_info)},
        add_features=add_features,
        remove_features=remove_features_list if remove_features_list else None,
    )

    if dataset.meta.video_keys:
        _copy_videos(dataset, new_meta)
    if new_meta.video_keys:
        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove if video_keys_to_remove else None)

    new_dataset = LeRobotDataset(
        repo_id=repo_id,
@@ -331,6 +373,46 @@ def add_feature(
    return new_dataset


def add_features(
    dataset: LeRobotDataset,
    features: dict[str, tuple[np.ndarray | torch.Tensor | Callable, dict]],
    output_dir: str | Path | None = None,
    repo_id: str | None = None,
) -> LeRobotDataset:
    """Add multiple features to a LeRobotDataset in a single pass.

    This is more efficient than calling add_feature() multiple times, as it only
    copies the dataset once regardless of how many features are being added.

    Args:
        dataset: The source LeRobotDataset.
        features: Dictionary mapping feature names to (feature_values, feature_info) tuples.
        output_dir: Directory to save the new dataset. If None, uses default location.
        repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.

    Returns:
        New dataset with all features added.

    Example:
        features = {
            "task_embedding": (task_emb_array, {"dtype": "float32", "shape": [384], "names": None}),
            "cam1_embedding": (cam1_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
            "cam2_embedding": (cam2_emb_array, {"dtype": "float32", "shape": [768], "names": None}),
        }
        new_dataset = add_features(dataset, features, output_dir="./output", repo_id="my_dataset")
    """
    if not features:
        raise ValueError("No features provided")

    return modify_features(
        dataset=dataset,
        add_features=features,
        remove_features=None,
        output_dir=output_dir,
        repo_id=repo_id,
    )


def remove_feature(
    dataset: LeRobotDataset,
    feature_names: str | list[str],
@@ -345,56 +427,17 @@ def remove_feature(
        output_dir: Directory to save the new dataset. If None, uses default location.
        repo_id: Repository ID for the new dataset. If None, appends "_modified" to original.

    Returns:
        New dataset with features removed.
    """
    if isinstance(feature_names, str):
        feature_names = [feature_names]

    for name in feature_names:
        if name not in dataset.meta.features:
            raise ValueError(f"Feature '{name}' not found in dataset")

    required_features = {"timestamp", "frame_index", "episode_index", "index", "task_index"}
    if any(name in required_features for name in feature_names):
        raise ValueError(f"Cannot remove required features: {required_features}")

    if repo_id is None:
        repo_id = f"{dataset.repo_id}_modified"
    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id

    new_features = {k: v for k, v in dataset.meta.features.items() if k not in feature_names}

    video_keys_to_remove = [name for name in feature_names if name in dataset.meta.video_keys]

    remaining_video_keys = [k for k in dataset.meta.video_keys if k not in video_keys_to_remove]

    new_meta = LeRobotDatasetMetadata.create(
        repo_id=repo_id,
        fps=dataset.meta.fps,
        features=new_features,
        robot_type=dataset.meta.robot_type,
        root=output_dir,
        use_videos=len(remaining_video_keys) > 0,
    )

    _copy_data_with_feature_changes(
    return modify_features(
        dataset=dataset,
        new_meta=new_meta,
        add_features=None,
        remove_features=feature_names,
    )

    if new_meta.video_keys:
        _copy_videos(dataset, new_meta, exclude_keys=video_keys_to_remove)

    new_dataset = LeRobotDataset(
        output_dir=output_dir,
        repo_id=repo_id,
        root=output_dir,
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
    )

    return new_dataset


def _fractions_to_episode_indices(
    total_episodes: int,
@@ -501,10 +544,7 @@ def _copy_and_reindex_data(
        dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        dst_path.parent.mkdir(parents=True, exist_ok=True)

        if len(dst_meta.image_keys) > 0:
            to_parquet_with_hf_images(df, dst_path)
        else:
            df.to_parquet(dst_path, index=False)
        _write_parquet(df, dst_path, dst_meta)

    for ep_old_idx in episodes_to_keep:
        ep_new_idx = episode_mapping[ep_old_idx]
@@ -862,6 +902,25 @@ def _copy_and_reindex_episodes_metadata(
    write_stats(filtered_stats, dst_meta.root)


def _write_parquet(df: pd.DataFrame, path: Path, meta: LeRobotDatasetMetadata) -> None:
    """Write DataFrame to parquet

    This ensures images are properly embedded and the file can be loaded correctly by HF datasets.
    """
    from lerobot.datasets.utils import embed_images, get_hf_features_from_features

    hf_features = get_hf_features_from_features(meta.features)
    ep_dataset = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=hf_features, split="train")

    if len(meta.image_keys) > 0:
        ep_dataset = embed_images(ep_dataset)

    table = ep_dataset.with_format("arrow")[:]
    writer = pq.ParquetWriter(path, schema=table.schema, compression="snappy", use_dictionary=True)
    writer.write_table(table)
    writer.close()


def _save_data_chunk(
    df: pd.DataFrame,
    meta: LeRobotDatasetMetadata,
@@ -877,10 +936,7 @@ def _save_data_chunk(
    path = meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
    path.parent.mkdir(parents=True, exist_ok=True)

    if len(meta.image_keys) > 0:
        to_parquet_with_hf_images(df, path)
    else:
        df.to_parquet(path, index=False)
    _write_parquet(df, path, meta)

    episode_metadata = {}
    for ep_idx in df["episode_index"].unique():
@@ -906,19 +962,34 @@ def _copy_data_with_feature_changes(
    remove_features: list[str] | None = None,
) -> None:
    """Copy data while adding or removing features."""
    file_paths = set()
    if dataset.meta.episodes is None:
        dataset.meta.episodes = load_episodes(dataset.meta.root)

    # Map file paths to episode indices to extract chunk/file indices
    file_to_episodes: dict[Path, set[int]] = {}
    for ep_idx in range(dataset.meta.total_episodes):
        file_paths.add(dataset.meta.get_data_file_path(ep_idx))
        file_path = dataset.meta.get_data_file_path(ep_idx)
        if file_path not in file_to_episodes:
            file_to_episodes[file_path] = set()
        file_to_episodes[file_path].add(ep_idx)

    frame_idx = 0

    for src_path in tqdm(sorted(file_paths), desc="Processing data files"):
    for src_path in tqdm(sorted(file_to_episodes.keys()), desc="Processing data files"):
        df = pd.read_parquet(dataset.root / src_path).reset_index(drop=True)

        # Get chunk_idx and file_idx from the source file's first episode
        episodes_in_file = file_to_episodes[src_path]
        first_ep_idx = min(episodes_in_file)
        src_ep = dataset.meta.episodes[first_ep_idx]
        chunk_idx = src_ep["data/chunk_index"]
        file_idx = src_ep["data/file_index"]

        if remove_features:
            df = df.drop(columns=remove_features, errors="ignore")

        if add_features:
            end_idx = frame_idx + len(df)
            for feature_name, (values, _) in add_features.items():
                if callable(values):
                    feature_values = []
@@ -931,15 +1002,18 @@ def _copy_data_with_feature_changes(
                        feature_values.append(value)
                    df[feature_name] = feature_values
                else:
                    end_idx = frame_idx + len(df)
                    feature_slice = values[frame_idx:end_idx]
                    if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
                        df[feature_name] = feature_slice.flatten()
                    else:
                        df[feature_name] = feature_slice
                    frame_idx = end_idx
            frame_idx = end_idx

        _save_data_chunk(df, new_meta)
        # Write using the preserved chunk_idx and file_idx from source
        dst_path = new_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        dst_path.parent.mkdir(parents=True, exist_ok=True)

        _write_parquet(df, dst_path, new_meta)

    _copy_episodes_metadata_and_stats(dataset, new_meta)

@@ -22,9 +22,10 @@ import pytest
import torch

from lerobot.datasets.dataset_tools import (
    add_feature,
    add_features,
    delete_episodes,
    merge_datasets,
    modify_features,
    remove_feature,
    split_dataset,
)
@@ -292,7 +293,7 @@ def test_merge_empty_list(tmp_path):
        merge_datasets([], output_repo_id="merged", output_dir=tmp_path)


def test_add_feature_with_values(sample_dataset, tmp_path):
def test_add_features_with_values(sample_dataset, tmp_path):
    """Test adding a feature with pre-computed values."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -302,6 +303,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
        "shape": (1,),
        "names": None,
    }
    features = {
        "reward": (reward_values, feature_info),
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -310,11 +314,9 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

@@ -327,7 +329,7 @@ def test_add_feature_with_values(sample_dataset, tmp_path):
        assert isinstance(sample_item["reward"], torch.Tensor)


def test_add_feature_with_callable(sample_dataset, tmp_path):
def test_add_features_with_callable(sample_dataset, tmp_path):
    """Test adding a feature with a callable."""

    def compute_reward(frame_dict, episode_idx, frame_idx):
@@ -338,7 +340,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
        "shape": (1,),
        "names": None,
    }

    features = {
        "reward": (compute_reward, feature_info),
    }
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
@@ -346,11 +350,9 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=compute_reward,
            feature_info=feature_info,
        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

@@ -368,31 +370,88 @@ def test_add_feature_with_callable(sample_dataset, tmp_path):
def test_add_existing_feature(sample_dataset, tmp_path):
    """Test error when adding an existing feature."""
    feature_info = {"dtype": "float32", "shape": (1,)}
    features = {
        "action": (np.zeros(50), feature_info),
    }

    with pytest.raises(ValueError, match="Feature 'action' already exists"):
        add_feature(
            sample_dataset,
            feature_name="action",
            feature_values=np.zeros(50),
            feature_info=feature_info,
        add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "modified",
        )


def test_add_feature_invalid_info(sample_dataset, tmp_path):
    """Test error with invalid feature info."""
    with pytest.raises(ValueError, match="feature_info must contain keys"):
        add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.zeros(50),
            feature_info={"dtype": "float32"},
    with pytest.raises(ValueError, match="feature_info for 'reward' must contain keys"):
        add_features(
            dataset=sample_dataset,
            features={
                "reward": (np.zeros(50), {"dtype": "float32"}),
            },
            output_dir=tmp_path / "modified",
        )


def test_remove_single_feature(sample_dataset, tmp_path):
    """Test removing a single feature."""
def test_modify_features_add_and_remove(sample_dataset, tmp_path):
    """Test modifying features by adding and removing simultaneously."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "modified")

        # First add a feature we'll later remove
        dataset_with_reward = add_features(
            sample_dataset,
            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
            output_dir=tmp_path / "with_reward",
        )

        # Now use modify_features to add "success" and remove "reward" in one pass
        modified_dataset = modify_features(
            dataset_with_reward,
            add_features={
                "success": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            remove_features="reward",
            output_dir=tmp_path / "modified",
        )

        assert "success" in modified_dataset.meta.features
        assert "reward" not in modified_dataset.meta.features
        assert len(modified_dataset) == 50


def test_modify_features_only_add(sample_dataset, tmp_path):
    """Test that modify_features works with only add_features."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "modified")

        modified_dataset = modify_features(
            sample_dataset,
            add_features={
                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            output_dir=tmp_path / "modified",
        )

        assert "reward" in modified_dataset.meta.features
        assert len(modified_dataset) == 50


def test_modify_features_only_remove(sample_dataset, tmp_path):
    """Test that modify_features works with only remove_features."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
@@ -402,11 +461,46 @@ def test_remove_single_feature(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_feature(
        dataset_with_reward = add_features(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info=feature_info,
            features={"reward": (np.random.randn(50, 1).astype(np.float32), feature_info)},
            output_dir=tmp_path / "with_reward",
        )

        modified_dataset = modify_features(
            dataset_with_reward,
            remove_features="reward",
            output_dir=tmp_path / "modified",
        )

        assert "reward" not in modified_dataset.meta.features


def test_modify_features_no_changes(sample_dataset, tmp_path):
    """Test error when modify_features is called with no changes."""
    with pytest.raises(ValueError, match="Must specify at least one of add_features or remove_features"):
        modify_features(
            sample_dataset,
            output_dir=tmp_path / "modified",
        )


def test_remove_single_feature(sample_dataset, tmp_path):
    """Test removing a single feature."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}
    features = {
        "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
    }
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

@@ -432,20 +526,19 @@ def test_remove_multiple_features(sample_dataset, tmp_path):
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = sample_dataset
        features = {}
        for feature_name in ["reward", "success"]:
            feature_info = {"dtype": "float32", "shape": (1,), "names": None}
            dataset = add_feature(
                dataset,
                feature_name=feature_name,
                feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
                feature_info=feature_info,
                output_dir=tmp_path / f"with_{feature_name}",
            features[feature_name] = (
                np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
                feature_info,
            )

        dataset_with_features = add_features(
            dataset, features=features, output_dir=tmp_path / "with_features"
        )
        dataset_clean = remove_feature(
            dataset,
            feature_names=["reward", "success"],
            output_dir=tmp_path / "clean",
            dataset_with_features, feature_names=["reward", "success"], output_dir=tmp_path / "clean"
        )

        assert "reward" not in dataset_clean.meta.features
@@ -509,11 +602,14 @@ def test_complex_workflow_integration(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = add_feature(
        dataset = add_features(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info={"dtype": "float32", "shape": (1,), "names": None},
            features={
                "reward": (
                    np.random.randn(50, 1).astype(np.float32),
                    {"dtype": "float32", "shape": (1,), "names": None},
                )
            },
            output_dir=tmp_path / "step1",
        )

@@ -753,7 +849,7 @@ def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_f
        assert "std" in merged.meta.stats[feature]


def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
def test_add_features_preserves_existing_stats(sample_dataset, tmp_path):
    """Test that adding a feature preserves existing stats."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)
@@ -763,6 +859,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
        "shape": (1,),
        "names": None,
    }
    features = {
        "reward": (reward_values, feature_info),
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
@@ -771,11 +870,9 @@ def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
        new_dataset = add_features(
            dataset=sample_dataset,
            features=features,
            output_dir=tmp_path / "with_reward",
        )

@@ -797,11 +894,11 @@ def test_remove_feature_updates_stats(sample_dataset, tmp_path):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_feature(
        dataset_with_reward = add_features(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info=feature_info,
            features={
                "reward": (np.random.randn(50, 1).astype(np.float32), feature_info),
            },
            output_dir=tmp_path / "with_reward",
        )

@@ -893,3 +990,60 @@ def test_split_all_episodes_assigned(sample_dataset, tmp_path):

    total_episodes = sum(ds.meta.total_episodes for ds in result.values())
    assert total_episodes == sample_dataset.meta.total_episodes


def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
    """Test that modifying features preserves chunk_idx and file_idx from source dataset."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            return str(kwargs.get("local_dir", tmp_path / repo_id.split("/")[-1]))

        mock_snapshot_download.side_effect = mock_snapshot

        # First split the dataset to create a non-zero starting chunk/file structure
        splits = split_dataset(
            sample_dataset,
            splits={"train": [0, 1, 2], "val": [3, 4]},
            output_dir=tmp_path / "splits",
        )

        train_dataset = splits["train"]

        # Get original chunk/file indices from first episode
        if train_dataset.meta.episodes is None:
            from lerobot.datasets.utils import load_episodes

            train_dataset.meta.episodes = load_episodes(train_dataset.meta.root)
        original_chunk_indices = [ep["data/chunk_index"] for ep in train_dataset.meta.episodes]
        original_file_indices = [ep["data/file_index"] for ep in train_dataset.meta.episodes]

        # Now add a feature to the split dataset
        modified_dataset = add_features(
            train_dataset,
            features={
                "reward": (
                    np.random.randn(train_dataset.meta.total_frames, 1).astype(np.float32),
                    feature_info,
                ),
            },
            output_dir=tmp_path / "modified",
        )

        # Check that chunk/file indices are preserved
        if modified_dataset.meta.episodes is None:
            from lerobot.datasets.utils import load_episodes

            modified_dataset.meta.episodes = load_episodes(modified_dataset.meta.root)
        new_chunk_indices = [ep["data/chunk_index"] for ep in modified_dataset.meta.episodes]
        new_file_indices = [ep["data/file_index"] for ep in modified_dataset.meta.episodes]

        assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
        assert new_file_indices == original_file_indices, "File indices should be preserved"
        assert "reward" in modified_dataset.meta.features