Dataset tools (#2100)
* feat(dataset-tools): add dataset utilities and example script
  - Introduced dataset tools for LeRobotDataset, including functions for deleting episodes, splitting datasets, adding/removing features, and merging datasets.
  - Added an example script demonstrating the usage of these utilities.
  - Implemented comprehensive tests for all new functionality to ensure reliability and correctness.
* style fixes
* move example to dataset dir
* missing license
* fixes, mostly paths
* clean comments
* move tests to functions instead of class-based
* fix video editing: decode, delete frames, and re-encode video; copy unchanged video and parquet files to avoid recreating the entire dataset
* fortify tooling tests
* fix type issue resulting from saving numpy arrays with shape (3, 1, 1)
* added lerobot_edit_dataset
* revert changes in examples; remove hardcoded split names
* update comment
* fix comment; add lerobot-edit-dataset shortcut
* apply suggestion from @Copilot
* style nit after Copilot review
* fix: bug in dataset root when editing the dataset in place (without setting new_repo_id)
* fix bug in aggregate.py when accumulating video timestamps; add tests to fortify aggregate videos
* added missing output repo id
* migrate delete_episode to using PyAV instead of decoding, writing frames to disk, and encoding again
* added "modified" suffix in case repo_id is not set in delete_episode
* adding docs for dataset tools
* bump av version and add back time_base assignment
* linter
* modified push_to_hub logic in lerobot_edit_dataset
* fix(progress bar): fix the progress bar issue in dataset tools
* chore(concatenate): remove no-longer-needed concatenate_datasets usage
* fix(file sizes forwarding): forward file and chunk sizes in metadata info when splitting and aggregating datasets
* style fix
* refactor(aggregate): fix video indexing and timestamp bugs in dataset merging. There were three critical bugs in aggregate.py that prevented correct dataset merging:
  1. Video file indices: changed from += to = assignment to correctly reference merged video files
  2. Video timestamps: implemented per-source-file offset tracking to maintain continuous timestamps when merging split datasets (was causing non-monotonic timestamp warnings)
  3. File rotation offsets: store timestamp offsets after the rotation decision to prevent out-of-bounds frame access (was causing "Invalid frame index" errors with small file size limits)
  Changes:
  - Updated update_meta_data() to apply per-source-file timestamp offsets
  - Updated aggregate_videos() to track offsets correctly during file rotation
  - Added get_video_duration_in_s import for duration calculation
* improved docs for split dataset and added a check for the case where the split size results in zero episodes
* chore(docs): update merge documentation details

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jack Vial <vialjack@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
@@ -25,6 +25,8 @@
       title: Using LeRobotDataset
     - local: porting_datasets_v3
       title: Porting Large Datasets
+    - local: using_dataset_tools
+      title: Using the Dataset Tools
   title: "Datasets"
 - sections:
   - local: act
docs/source/using_dataset_tools.mdx (new file, 102 lines)
@@ -0,0 +1,102 @@
# Using Dataset Tools

This guide covers the dataset tools utilities available in LeRobot for modifying and editing existing datasets.

## Overview

LeRobot provides several utilities for manipulating datasets:

1. **Delete Episodes** - Remove specific episodes from a dataset
2. **Split Dataset** - Divide a dataset into multiple smaller datasets
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
4. **Add Features** - Add new features to a dataset
5. **Remove Features** - Remove features from a dataset

The core implementation is in `lerobot.datasets.dataset_tools`.
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
## Command-Line Tool: lerobot-edit-dataset

`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.

Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.

### Usage Examples

#### Delete Episodes

Remove specific episodes from a dataset. This is useful for filtering out undesired data.

```bash
# Delete episodes 0, 2, and 5 (modifies original dataset)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]"

# Delete episodes and save to a new dataset (preserves original dataset)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --new_repo_id lerobot/pusht_after_deletion \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]"
```
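Under the hood, deleting episodes means the surviving episodes must be re-indexed so the dataset stays contiguous. A minimal sketch of that bookkeeping (a hypothetical helper for illustration, not the actual `delete_episodes` implementation, which also rewrites the parquet and video files):

```python
def reindex_after_deletion(episode_indices: list[int], deleted: set[int]) -> dict[int, int]:
    """Map each surviving old episode index to its new contiguous index."""
    kept = [idx for idx in sorted(episode_indices) if idx not in deleted]
    return {old: new for new, old in enumerate(kept)}


# Deleting episodes 0, 2, and 5 from a 6-episode dataset keeps 1, 3, 4,
# which are renumbered to 0, 1, 2.
print(reindex_after_deletion(list(range(6)), deleted={0, 2, 5}))  # {1: 0, 3: 1, 4: 2}
```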
#### Split Dataset

Divide a dataset into multiple subsets.
```bash
# Split by fractions (e.g. 80% train, 10% test, 10% val)
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type split \
    --operation.splits '{"train": 0.8, "test": 0.1, "val": 0.1}'

# Split by specific episode indices
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type split \
    --operation.splits '{"task1": [0, 1, 2, 3], "task2": [4, 5]}'
```

There are no constraints on the split names; they can be chosen freely by the user. Resulting datasets are saved under the repo id with the split name appended, e.g. `lerobot/pusht_train`, `lerobot/pusht_task1`, `lerobot/pusht_task2`.
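For fractional splits, each fraction must map to a whole number of episodes, and this PR adds a check for the case where a split size would come out to zero episodes. A simplified sketch of that arithmetic (assumed behavior for illustration, not the actual `split_dataset` code):

```python
def fractions_to_episode_counts(total_episodes: int, splits: dict[str, float]) -> dict[str, int]:
    """Convert split fractions into whole episode counts, flooring each count
    and rejecting any split that would end up empty."""
    counts = {name: int(total_episodes * frac) for name, frac in splits.items()}
    for name, count in counts.items():
        if count == 0:
            raise ValueError(f"Split '{name}' would contain zero episodes; use a larger fraction")
    return counts


print(fractions_to_episode_counts(10, {"train": 0.8, "val": 0.2}))  # {'train': 8, 'val': 2}
```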
#### Merge Datasets

Combine multiple datasets into a single dataset.

```bash
# Merge train and validation splits back into one dataset
lerobot-edit-dataset \
    --repo_id lerobot/pusht_merged \
    --operation.type merge \
    --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
```
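Merging requires identical features across all inputs, and episode indices from each source are offset so the concatenated episodes stay contiguous. A self-contained sketch of these two invariants (illustrative helpers, not lerobot's actual implementation):

```python
def check_merge_compatible(features_per_dataset: list[dict]) -> None:
    """Precondition for merging: every dataset must expose the same feature keys."""
    reference = features_per_dataset[0]
    for i, features in enumerate(features_per_dataset[1:], start=1):
        if features.keys() != reference.keys():
            raise ValueError(f"Dataset {i} has mismatched features")


def merged_episode_offsets(episode_counts: list[int]) -> list[int]:
    """Episode-index offset of each source dataset in the merged result,
    given that episodes are concatenated in the order of repo_ids."""
    offsets, running = [], 0
    for count in episode_counts:
        offsets.append(running)
        running += count
    return offsets


print(merged_episode_offsets([50, 20, 30]))  # [0, 50, 70]
```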
#### Remove Features

Remove features from a dataset.

```bash
# Remove a camera feature
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --operation.type remove_feature \
    --operation.feature_names "['observation.images.top']"
```

### Push to Hub

Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:

```bash
lerobot-edit-dataset \
    --repo_id lerobot/pusht \
    --new_repo_id lerobot/pusht_after_deletion \
    --operation.type delete_episodes \
    --operation.episode_indices "[0, 2, 5]" \
    --push_to_hub
```

The Python API also includes a tool for adding features to a dataset, which is not yet exposed through `lerobot-edit-dataset`.
examples/dataset/use_dataset_tools.py (new file, 117 lines)
@@ -0,0 +1,117 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example script demonstrating dataset tools utilities.

This script shows how to:
1. Delete episodes from a dataset
2. Split a dataset into train/val sets
3. Add/remove features
4. Merge datasets

Usage:
    python examples/dataset/use_dataset_tools.py
"""

import numpy as np

from lerobot.datasets.dataset_tools import (
    add_feature,
    delete_episodes,
    merge_datasets,
    remove_feature,
    split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset


def main():
    dataset = LeRobotDataset("lerobot/pusht")

    print(f"Original dataset: {dataset.meta.total_episodes} episodes, {dataset.meta.total_frames} frames")
    print(f"Features: {list(dataset.meta.features.keys())}")

    print("\n1. Deleting episodes 0 and 2...")
    filtered_dataset = delete_episodes(dataset, episode_indices=[0, 2], repo_id="lerobot/pusht_filtered")
    print(f"Filtered dataset: {filtered_dataset.meta.total_episodes} episodes")

    print("\n2. Splitting dataset into train/val...")
    splits = split_dataset(
        dataset,
        splits={"train": 0.8, "val": 0.2},
    )
    print(f"Train split: {splits['train'].meta.total_episodes} episodes")
    print(f"Val split: {splits['val'].meta.total_episodes} episodes")

    print("\n3. Adding a reward feature...")

    reward_values = np.random.randn(dataset.meta.total_frames).astype(np.float32)
    dataset_with_reward = add_feature(
        dataset,
        feature_name="reward",
        feature_values=reward_values,
        feature_info={
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        repo_id="lerobot/pusht_with_reward",
    )

    def compute_success(row_dict, episode_index, frame_index):
        # Dummy heuristic labeling the last 10 frames of each episode as
        # successful. With the hardcoded episode_length of 10, every frame is
        # labeled 1.0; substitute the real episode length in practice.
        episode_length = 10
        return float(frame_index >= episode_length - 10)

    dataset_with_success = add_feature(
        dataset_with_reward,
        feature_name="success",
        feature_values=compute_success,
        feature_info={
            "dtype": "float32",
            "shape": (1,),
            "names": None,
        },
        repo_id="lerobot/pusht_with_reward_and_success",
    )

    print(f"New features: {list(dataset_with_success.meta.features.keys())}")

    print("\n4. Removing the success feature...")
    dataset_cleaned = remove_feature(
        dataset_with_success, feature_names="success", repo_id="lerobot/pusht_cleaned"
    )
    print(f"Features after removal: {list(dataset_cleaned.meta.features.keys())}")

    print("\n5. Merging train and val splits back together...")
    merged = merge_datasets([splits["train"], splits["val"]], output_repo_id="lerobot/pusht_merged")
    print(f"Merged dataset: {merged.meta.total_episodes} episodes")

    print("\n6. Complex workflow example...")

    if len(dataset.meta.camera_keys) > 1:
        camera_to_remove = dataset.meta.camera_keys[0]
        print(f"Removing camera: {camera_to_remove}")
        dataset_no_cam = remove_feature(
            dataset, feature_names=camera_to_remove, repo_id="pusht_no_first_camera"
        )
        print(f"Remaining cameras: {dataset_no_cam.meta.camera_keys}")

    print("\nDone! Check ~/.cache/huggingface/lerobot/ for the created datasets.")


if __name__ == "__main__":
    main()
@@ -67,7 +67,7 @@ dependencies = [
     "cmake>=3.29.0.1,<4.2.0",
     "einops>=0.8.0,<0.9.0",
     "opencv-python-headless>=4.9.0,<4.13.0",
-    "av>=14.2.0,<16.0.0",
+    "av>=15.0.0,<16.0.0",
     "jsonlines>=4.0.0,<5.0.0",
     "packaging>=24.2,<26.0",
     "pynput>=1.7.7,<1.9.0",
@@ -175,6 +175,7 @@ lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"
 lerobot-info="lerobot.scripts.lerobot_info:main"
 lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
 lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
+lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"

 # ---------------- Tool Configurations ----------------
 [tool.setuptools.packages.find]
@@ -39,7 +39,7 @@ from lerobot.datasets.utils import (
     write_stats,
     write_tasks,
 )
-from lerobot.datasets.video_utils import concatenate_video_files
+from lerobot.datasets.video_utils import concatenate_video_files, get_video_duration_in_s


 def validate_all_metadata(all_metadata: list[LeRobotDatasetMetadata]):
@@ -130,10 +130,34 @@ def update_meta_data(
     df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
     df["data/file_index"] = df["data/file_index"] + data_idx["file"]
     for key, video_idx in videos_idx.items():
-        df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
-        df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
-        df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
-        df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
+        # Store original video file indices before updating
+        orig_chunk_col = f"videos/{key}/chunk_index"
+        orig_file_col = f"videos/{key}/file_index"
+        df["_orig_chunk"] = df[orig_chunk_col].copy()
+        df["_orig_file"] = df[orig_file_col].copy()
+
+        # Update chunk and file indices to point to destination
+        df[orig_chunk_col] = video_idx["chunk"]
+        df[orig_file_col] = video_idx["file"]
+
+        # Apply per-source-file timestamp offsets
+        src_to_offset = video_idx.get("src_to_offset", {})
+        if src_to_offset:
+            # Apply offset based on original source file
+            for idx in df.index:
+                src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
+                offset = src_to_offset.get(src_key, 0)
+                df.at[idx, f"videos/{key}/from_timestamp"] += offset
+                df.at[idx, f"videos/{key}/to_timestamp"] += offset
+        else:
+            # Fallback to simple offset (for backward compatibility)
+            df[f"videos/{key}/from_timestamp"] = (
+                df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
+            )
+            df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
+
+        # Clean up temporary columns
+        df = df.drop(columns=["_orig_chunk", "_orig_file"])
+
     df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
     df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
@@ -193,6 +217,9 @@ def aggregate_datasets(
         robot_type=robot_type,
         features=features,
         root=aggr_root,
+        chunks_size=chunk_size,
+        data_files_size_in_mb=data_files_size_in_mb,
+        video_files_size_in_mb=video_files_size_in_mb,
     )

     logging.info("Find all tasks")
@@ -236,6 +263,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
     Returns:
         dict: Updated videos_idx with current chunk and file indices.
     """
+    for key in videos_idx:
+        videos_idx[key]["episode_duration"] = 0
+        # Track offset for each source (chunk, file) pair
+        videos_idx[key]["src_to_offset"] = {}
+
     for key, video_idx in videos_idx.items():
         unique_chunk_file_pairs = {
             (chunk, file)
@@ -249,6 +281,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu

         chunk_idx = video_idx["chunk"]
         file_idx = video_idx["file"]
+        current_offset = video_idx["latest_duration"]

         for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
             src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
@@ -263,21 +296,24 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
                 file_index=file_idx,
             )

-            # If a new file is created, we don't want to increment the latest_duration
-            update_latest_duration = False
+            src_duration = get_video_duration_in_s(src_path)
+
             if not dst_path.exists():
-                # First write to this destination file
+                # Store offset before incrementing
+                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
                 dst_path.parent.mkdir(parents=True, exist_ok=True)
                 shutil.copy(str(src_path), str(dst_path))
-                continue  # not accumulating further, already copied the file in place
+                videos_idx[key]["episode_duration"] += src_duration
+                current_offset += src_duration
+                continue

-            # Check file sizes before appending
             src_size = get_video_size_in_mb(src_path)
             dst_size = get_video_size_in_mb(dst_path)

             if dst_size + src_size >= video_files_size_in_mb:
-                # Rotate to a new chunk/file
+                # Rotate to a new file, this source becomes start of new destination
+                # So its offset should be 0
+                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
                 dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
                     video_key=key,
@@ -286,25 +322,22 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
                 )
                 dst_path.parent.mkdir(parents=True, exist_ok=True)
                 shutil.copy(str(src_path), str(dst_path))
+                # Reset offset for next file
+                current_offset = src_duration
             else:
-                # Get the timestamps shift for this video
-                timestamps_shift_s = dst_meta.info["total_frames"] / dst_meta.info["fps"]
+                # Append to existing video file - use current accumulated offset
+                videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset

-                # Append to existing video file
                 concatenate_video_files(
                     [dst_path, src_path],
                     dst_path,
                 )
-                # Update the latest_duration when appending (shifts timestamps!)
-                update_latest_duration = not update_latest_duration
+                current_offset += src_duration
+
+            videos_idx[key]["episode_duration"] += src_duration

-        # Update the videos_idx with the final chunk and file indices for this key
         videos_idx[key]["chunk"] = chunk_idx
         videos_idx[key]["file"] = file_idx

-        if update_latest_duration:
-            videos_idx[key]["latest_duration"] += timestamps_shift_s
-
     return videos_idx
@@ -389,9 +422,6 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
         videos_idx,
     )

-    for k in videos_idx:
-        videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
-
     meta_idx = append_or_create_parquet_file(
         df,
         src_path,
@@ -403,6 +433,10 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx):
         aggr_root=dst_meta.root,
     )

+    # Increment latest_duration by the total duration added from this source dataset
+    for k in videos_idx:
+        videos_idx[k]["latest_duration"] += videos_idx[k]["episode_duration"]
+
     return meta_idx
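The offset bookkeeping in the hunks above is hard to follow in diff form. Below is a simplified, self-contained sketch of the `src_to_offset` logic from `aggregate_videos`: each source video either appends to the current destination file (its frames' timestamps shift by the accumulated duration) or starts a fresh destination file (offset 0). For simplicity the sketch rotates on accumulated duration, whereas the real code checks file sizes in MB; all names and values here are illustrative.

```python
def plan_video_aggregation(
    src_durations: dict[tuple, float],
    dst_initial_duration: float,
    max_file_duration: float,
) -> dict[tuple, float]:
    """Assign each source (chunk, file) pair a timestamp offset in the
    aggregated output, rotating to a new destination file when full."""
    src_to_offset = {}
    current_offset = dst_initial_duration
    for src_id, duration in src_durations.items():
        if current_offset + duration >= max_file_duration:
            # Rotate: this source becomes the start of a new destination file,
            # so its frames keep their original timestamps (offset 0).
            src_to_offset[src_id] = 0
            current_offset = duration
        else:
            # Append: timestamps shift by whatever is already in the file.
            src_to_offset[src_id] = current_offset
            current_offset += duration
    return src_to_offset


plan = plan_video_aggregation({("c0", "f0"): 10.0, ("c0", "f1"): 10.0}, 85.0, 100.0)
print(plan)  # {('c0', 'f0'): 85.0, ('c0', 'f1'): 0}
```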
src/lerobot/datasets/dataset_tools.py (new file, 1004 lines)
File diff suppressed because it is too large.
@@ -438,6 +438,9 @@ class LeRobotDatasetMetadata:
         robot_type: str | None = None,
         root: str | Path | None = None,
         use_videos: bool = True,
+        chunks_size: int | None = None,
+        data_files_size_in_mb: int | None = None,
+        video_files_size_in_mb: int | None = None,
     ) -> "LeRobotDatasetMetadata":
         """Creates metadata for a LeRobotDataset."""
         obj = cls.__new__(cls)
@@ -452,7 +455,16 @@ class LeRobotDatasetMetadata:
         obj.tasks = None
         obj.episodes = None
         obj.stats = None
-        obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, features, use_videos, robot_type)
+        obj.info = create_empty_dataset_info(
+            CODEBASE_VERSION,
+            fps,
+            features,
+            use_videos,
+            robot_type,
+            chunks_size,
+            data_files_size_in_mb,
+            video_files_size_in_mb,
+        )
         if len(obj.video_keys) > 0 and not use_videos:
             raise ValueError()
         write_json(obj.info, obj.root / INFO_PATH)
@@ -30,7 +30,7 @@ import pandas
 import pandas as pd
 import pyarrow.parquet as pq
 import torch
-from datasets import Dataset, concatenate_datasets
+from datasets import Dataset
 from datasets.table import embed_table_storage
 from huggingface_hub import DatasetCard, DatasetCardData, HfApi
 from huggingface_hub.errors import RevisionNotFoundError
@@ -44,7 +44,7 @@ from lerobot.datasets.backward_compatibility import (
     ForwardCompatibilityError,
 )
 from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_STR
-from lerobot.utils.utils import is_valid_numpy_dtype_string
+from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_string

 DEFAULT_CHUNK_SIZE = 1000  # Max number of files per chunk
 DEFAULT_DATA_FILE_SIZE_IN_MB = 100  # Max size per file
@@ -123,8 +123,9 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
         raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")

     # TODO(rcadene): set num_proc to accelerate conversion to pyarrow
-    datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
-    return concatenate_datasets(datasets)
+    with SuppressProgressBars():
+        datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
+    return datasets


 def get_parquet_num_frames(parquet_path: str | Path) -> int:
@@ -452,6 +452,9 @@ def concatenate_video_files(
                 template=input_stream, opaque=True
             )

+            # set the time base to the input stream time base (missing in the codec context)
+            stream_map[input_stream.index].time_base = input_stream.time_base
+
             # Demux + remux packets (no re-encode)
             for packet in input_container.demux():
                 # Skip packets from un-mapped streams
src/lerobot/scripts/lerobot_edit_dataset.py (new file, 286 lines)
@@ -0,0 +1,286 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Edit LeRobot datasets using various transformation tools.
|
||||||
|
|
||||||
|
This script allows you to delete episodes, split datasets, merge datasets,
|
||||||
|
and remove features. When new_repo_id is specified, creates a new dataset.
|
||||||
|
|
||||||
|
Usage Examples:
|
||||||
|
|
||||||
|
Delete episodes 0, 2, and 5 from a dataset:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type delete_episodes \
|
||||||
|
--operation.episode_indices "[0, 2, 5]"
|
||||||
|
|
||||||
|
Delete episodes and save to a new dataset:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--new_repo_id lerobot/pusht_filtered \
|
||||||
|
--operation.type delete_episodes \
|
||||||
|
--operation.episode_indices "[0, 2, 5]"
|
||||||
|
|
||||||
|
Split dataset by fractions:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type split \
|
||||||
|
--operation.splits '{"train": 0.8, "val": 0.2}'
|
||||||
|
|
||||||
|
Split dataset by episode indices:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type split \
|
||||||
|
--operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}'
|
||||||
|
|
||||||
|
Split into more than two splits:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type split \
|
||||||
|
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||||
|
|
||||||
|
Merge multiple datasets:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht_merged \
|
||||||
|
--operation.type merge \
|
||||||
|
--operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']"
|
||||||
|
|
||||||
|
Remove camera feature:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--repo_id lerobot/pusht \
|
||||||
|
--operation.type remove_feature \
|
||||||
|
--operation.feature_names "['observation.images.top']"
|
||||||
|
|
||||||
|
Using JSON config file:
|
||||||
|
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||||
|
--config_path path/to/edit_config.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
import shutil
from dataclasses import dataclass
from pathlib import Path

from lerobot.configs import parser
from lerobot.datasets.dataset_tools import (
    delete_episodes,
    merge_datasets,
    remove_feature,
    split_dataset,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import HF_LEROBOT_HOME
from lerobot.utils.utils import init_logging

@dataclass
class DeleteEpisodesConfig:
    type: str = "delete_episodes"
    episode_indices: list[int] | None = None


@dataclass
class SplitConfig:
    type: str = "split"
    splits: dict[str, float | list[int]] | None = None


@dataclass
class MergeConfig:
    type: str = "merge"
    repo_ids: list[str] | None = None


@dataclass
class RemoveFeatureConfig:
    type: str = "remove_feature"
    feature_names: list[str] | None = None


@dataclass
class EditDatasetConfig:
    repo_id: str
    operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
    root: str | None = None
    new_repo_id: str | None = None
    push_to_hub: bool = False

def get_output_path(repo_id: str, new_repo_id: str | None, root: Path | None) -> tuple[str, Path]:
    if new_repo_id:
        output_repo_id = new_repo_id
        output_dir = root / new_repo_id if root else HF_LEROBOT_HOME / new_repo_id
    else:
        output_repo_id = repo_id
        dataset_path = root / repo_id if root else HF_LEROBOT_HOME / repo_id
        old_path = Path(str(dataset_path) + "_old")

        if dataset_path.exists():
            if old_path.exists():
                shutil.rmtree(old_path)
            shutil.move(str(dataset_path), str(old_path))

        output_dir = dataset_path

    return output_repo_id, output_dir

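When no `new_repo_id` is given, the edit happens "in place": the existing dataset directory is rotated aside to `<path>_old` so the edited copy can be written back to the original location. The same rotation logic, isolated as a runnable sketch (`rotate_in_place` is our name, not from the diff):

```python
import shutil
from pathlib import Path


def rotate_in_place(dataset_path: Path) -> Path:
    """Move an existing dataset directory to `<path>_old`, mirroring the
    in-place branch of get_output_path above. The original path is then
    free to receive the edited dataset."""
    old_path = Path(str(dataset_path) + "_old")
    if dataset_path.exists():
        if old_path.exists():
            shutil.rmtree(old_path)  # discard a stale backup from a previous edit
        shutil.move(str(dataset_path), str(old_path))
    return dataset_path
```

The handlers below then point the source dataset's `root` at the `_old` copy and write the edited version to the original path.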
def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
    if not isinstance(cfg.operation, DeleteEpisodesConfig):
        raise ValueError("Operation config must be DeleteEpisodesConfig")

    if not cfg.operation.episode_indices:
        raise ValueError("episode_indices must be specified for delete_episodes operation")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
    output_repo_id, output_dir = get_output_path(
        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
    )

    if cfg.new_repo_id is None:
        dataset.root = Path(str(dataset.root) + "_old")

    logging.info(f"Deleting episodes {cfg.operation.episode_indices} from {cfg.repo_id}")
    new_dataset = delete_episodes(
        dataset,
        episode_indices=cfg.operation.episode_indices,
        output_dir=output_dir,
        repo_id=output_repo_id,
    )

    logging.info(f"Dataset saved to {output_dir}")
    logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {output_repo_id}")
        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()

def handle_split(cfg: EditDatasetConfig) -> None:
    if not isinstance(cfg.operation, SplitConfig):
        raise ValueError("Operation config must be SplitConfig")

    if not cfg.operation.splits:
        raise ValueError(
            "splits dict must be specified with split names as keys and fractions/episode lists as values"
        )

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)

    logging.info(f"Splitting dataset {cfg.repo_id} with splits: {cfg.operation.splits}")
    split_datasets = split_dataset(dataset, splits=cfg.operation.splits)

    for split_name, split_ds in split_datasets.items():
        split_repo_id = f"{cfg.repo_id}_{split_name}"
        logging.info(
            f"{split_name}: {split_ds.meta.total_episodes} episodes, {split_ds.meta.total_frames} frames"
        )

        if cfg.push_to_hub:
            logging.info(f"Pushing {split_name} split to hub as {split_repo_id}")
            LeRobotDataset(split_ds.repo_id, root=split_ds.root).push_to_hub()

def handle_merge(cfg: EditDatasetConfig) -> None:
    if not isinstance(cfg.operation, MergeConfig):
        raise ValueError("Operation config must be MergeConfig")

    if not cfg.operation.repo_ids:
        raise ValueError("repo_ids must be specified for merge operation")

    if not cfg.repo_id:
        raise ValueError("repo_id must be specified as the output repository for merged dataset")

    logging.info(f"Loading {len(cfg.operation.repo_ids)} datasets to merge")
    datasets = [LeRobotDataset(repo_id, root=cfg.root) for repo_id in cfg.operation.repo_ids]

    output_dir = Path(cfg.root) / cfg.repo_id if cfg.root else HF_LEROBOT_HOME / cfg.repo_id

    logging.info(f"Merging datasets into {cfg.repo_id}")
    merged_dataset = merge_datasets(
        datasets,
        output_repo_id=cfg.repo_id,
        output_dir=output_dir,
    )

    logging.info(f"Merged dataset saved to {output_dir}")
    logging.info(
        f"Episodes: {merged_dataset.meta.total_episodes}, Frames: {merged_dataset.meta.total_frames}"
    )

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {cfg.repo_id}")
        LeRobotDataset(merged_dataset.repo_id, root=output_dir).push_to_hub()

def handle_remove_feature(cfg: EditDatasetConfig) -> None:
    if not isinstance(cfg.operation, RemoveFeatureConfig):
        raise ValueError("Operation config must be RemoveFeatureConfig")

    if not cfg.operation.feature_names:
        raise ValueError("feature_names must be specified for remove_feature operation")

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
    output_repo_id, output_dir = get_output_path(
        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
    )

    if cfg.new_repo_id is None:
        dataset.root = Path(str(dataset.root) + "_old")

    logging.info(f"Removing features {cfg.operation.feature_names} from {cfg.repo_id}")
    new_dataset = remove_feature(
        dataset,
        feature_names=cfg.operation.feature_names,
        output_dir=output_dir,
        repo_id=output_repo_id,
    )

    logging.info(f"Dataset saved to {output_dir}")
    logging.info(f"Remaining features: {list(new_dataset.meta.features.keys())}")

    if cfg.push_to_hub:
        logging.info(f"Pushing to hub as {output_repo_id}")
        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()

@parser.wrap()
def edit_dataset(cfg: EditDatasetConfig) -> None:
    operation_type = cfg.operation.type

    if operation_type == "delete_episodes":
        handle_delete_episodes(cfg)
    elif operation_type == "split":
        handle_split(cfg)
    elif operation_type == "merge":
        handle_merge(cfg)
    elif operation_type == "remove_feature":
        handle_remove_feature(cfg)
    else:
        raise ValueError(
            f"Unknown operation type: {operation_type}\n"
            "Available operations: delete_episodes, split, merge, remove_feature"
        )


def main() -> None:
    init_logging()
    edit_dataset()


if __name__ == "__main__":
    main()
@@ -27,6 +27,7 @@ from statistics import mean
 import numpy as np
 import torch
+from datasets.utils.logging import disable_progress_bar, enable_progress_bar


 def inside_slurm():
@@ -247,6 +248,25 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
     return days, hours, minutes, seconds


+class SuppressProgressBars:
+    """
+    Context manager to suppress progress bars.
+
+    Example
+    --------
+    ```python
+    with SuppressProgressBars():
+        ...  # code that would normally show progress bars
+    ```
+    """
+
+    def __enter__(self):
+        disable_progress_bar()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        enable_progress_bar()
+
+
 class TimerManager:
     """
     Lightweight utility to measure elapsed time.
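`SuppressProgressBars` uses the plain `__enter__`/`__exit__` protocol; because `__exit__` runs even when the body raises, progress bars are always re-enabled. A self-contained toy with the same shape, where a class-level flag stands in for the `datasets` library's global progress-bar state (the class name and flag are ours, for illustration):

```python
class SuppressFlag:
    """Toy analogue of SuppressProgressBars: clears a class-level flag on
    entry and restores it on exit, even if the body raises."""

    enabled = True  # stands in for the library's global progress-bar state

    def __enter__(self):
        SuppressFlag.enabled = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        SuppressFlag.enabled = True
        return False  # don't swallow exceptions raised in the body
```

Returning a falsy value from `__exit__` is what lets exceptions from the body propagate while the cleanup still runs.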
@@ -181,6 +181,54 @@ def assert_dataset_iteration_works(aggr_ds):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def assert_video_timestamps_within_bounds(aggr_ds):
|
||||||
|
"""Test that all video timestamps are within valid bounds for their respective video files.
|
||||||
|
|
||||||
|
This catches bugs where timestamps point to frames beyond the actual video length,
|
||||||
|
which would cause "Invalid frame index" errors during data loading.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from torchcodec.decoders import VideoDecoder
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
|
|
||||||
|
for ep_idx in range(aggr_ds.num_episodes):
|
||||||
|
ep = aggr_ds.meta.episodes[ep_idx]
|
||||||
|
|
||||||
|
for vid_key in aggr_ds.meta.video_keys:
|
||||||
|
from_ts = ep[f"videos/{vid_key}/from_timestamp"]
|
||||||
|
to_ts = ep[f"videos/{vid_key}/to_timestamp"]
|
||||||
|
video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
|
|
||||||
|
if not video_path.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
from_frame_idx = round(from_ts * aggr_ds.fps)
|
||||||
|
to_frame_idx = round(to_ts * aggr_ds.fps)
|
||||||
|
|
||||||
|
try:
|
||||||
|
decoder = VideoDecoder(str(video_path))
|
||||||
|
num_frames = len(decoder)
|
||||||
|
|
||||||
|
# Verify timestamps don't exceed video bounds
|
||||||
|
assert from_frame_idx >= 0, (
|
||||||
|
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
|
||||||
|
)
|
||||||
|
assert from_frame_idx < num_frames, (
|
||||||
|
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
|
||||||
|
)
|
||||||
|
assert to_frame_idx <= num_frames, (
|
||||||
|
f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
|
||||||
|
)
|
||||||
|
assert from_frame_idx < to_frame_idx, (
|
||||||
|
f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise AssertionError(
|
||||||
|
f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
||||||
"""Test basic aggregation functionality with standard parameters."""
|
"""Test basic aggregation functionality with standard parameters."""
|
||||||
ds_0_num_frames = 400
|
ds_0_num_frames = 400
|
||||||
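The bounds check converts second-based timestamps to frame indices with `round(ts * fps)`. The same arithmetic in isolation (the function name is ours, not from the diff):

```python
def ts_to_frame_range(from_ts: float, to_ts: float, fps: float) -> tuple[int, int]:
    """Convert a [from_ts, to_ts) window in seconds to a half-open
    frame-index range at a fixed fps, as the check above does."""
    return round(from_ts * fps), round(to_ts * fps)
```

At 30 fps, an episode spanning 0.0–2.0 s maps to frames [0, 60), so a valid video file must hold at least 60 decodable frames; a per-source timestamp offset that is off by even one frame duration shifts this whole range out of bounds.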
@@ -227,6 +275,7 @@ def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
|
|||||||
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
assert_metadata_consistency(aggr_ds, ds_0, ds_1)
|
||||||
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
|
||||||
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
|
||||||
|
assert_video_timestamps_within_bounds(aggr_ds)
|
||||||
assert_dataset_iteration_works(aggr_ds)
|
assert_dataset_iteration_works(aggr_ds)
|
||||||
|
|
||||||
|
|
||||||
@@ -277,6 +326,7 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
     assert_metadata_consistency(aggr_ds, ds_0, ds_1)
     assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
     assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
+    assert_video_timestamps_within_bounds(aggr_ds)
     assert_dataset_iteration_works(aggr_ds)

     # Check that multiple files were actually created due to small size limits
@@ -290,3 +340,43 @@ def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
|
|||||||
if video_dir.exists():
|
if video_dir.exists():
|
||||||
video_files = list(video_dir.rglob("*.mp4"))
|
video_files = list(video_dir.rglob("*.mp4"))
|
||||||
assert len(video_files) > 1, "Small file size limits should create multiple video files"
|
assert len(video_files) > 1, "Small file size limits should create multiple video files"
|
||||||
|
|
||||||
|
|
||||||
|
def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
|
||||||
|
"""Regression test for video timestamp bug when merging datasets.
|
||||||
|
|
||||||
|
This test specifically checks that video timestamps are correctly calculated
|
||||||
|
and accumulated when merging multiple datasets.
|
||||||
|
"""
|
||||||
|
datasets = []
|
||||||
|
for i in range(3):
|
||||||
|
ds = lerobot_dataset_factory(
|
||||||
|
root=tmp_path / f"regression_{i}",
|
||||||
|
repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
|
||||||
|
total_episodes=2,
|
||||||
|
total_frames=100,
|
||||||
|
)
|
||||||
|
datasets.append(ds)
|
||||||
|
|
||||||
|
aggregate_datasets(
|
||||||
|
repo_ids=[ds.repo_id for ds in datasets],
|
||||||
|
roots=[ds.root for ds in datasets],
|
||||||
|
aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
|
||||||
|
aggr_root=tmp_path / "regression_aggr",
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
|
||||||
|
aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")
|
||||||
|
|
||||||
|
assert_video_timestamps_within_bounds(aggr_ds)
|
||||||
|
|
||||||
|
for i in range(len(aggr_ds)):
|
||||||
|
item = aggr_ds[i]
|
||||||
|
for key in aggr_ds.meta.video_keys:
|
||||||
|
assert key in item, f"Video key {key} missing from item {i}"
|
||||||
|
assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"
|
||||||
|

tests/datasets/test_dataset_tools.py (new file, +891 lines)
@@ -0,0 +1,891 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for dataset tools utilities."""

from unittest.mock import patch

import numpy as np
import pytest
import torch

from lerobot.datasets.dataset_tools import (
    add_feature,
    delete_episodes,
    merge_datasets,
    remove_feature,
    split_dataset,
)
@pytest.fixture
def sample_dataset(tmp_path, empty_lerobot_dataset_factory):
    """Create a sample dataset for testing."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset",
        features=features,
    )

    for ep_idx in range(5):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset.add_frame(frame)
        dataset.save_episode()

    return dataset
def test_delete_single_episode(sample_dataset, tmp_path):
    """Test deleting a single episode."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 4
        assert new_dataset.meta.total_frames == 40

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2, 3}

        assert len(new_dataset) == 40


def test_delete_multiple_episodes(sample_dataset, tmp_path):
    """Test deleting multiple episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = {int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]}
        assert episode_indices == {0, 1, 2}


def test_delete_invalid_episodes(sample_dataset, tmp_path):
    """Test error handling for invalid episode indices."""
    with pytest.raises(ValueError, match="Invalid episode indices"):
        delete_episodes(
            sample_dataset,
            episode_indices=[10, 20],
            output_dir=tmp_path / "filtered",
        )


def test_delete_all_episodes(sample_dataset, tmp_path):
    """Test error when trying to delete all episodes."""
    with pytest.raises(ValueError, match="Cannot delete all episodes"):
        delete_episodes(
            sample_dataset,
            episode_indices=list(range(5)),
            output_dir=tmp_path / "filtered",
        )


def test_delete_empty_list(sample_dataset, tmp_path):
    """Test error when no episodes specified."""
    with pytest.raises(ValueError, match="No episodes to delete"):
        delete_episodes(
            sample_dataset,
            episode_indices=[],
            output_dir=tmp_path / "filtered",
        )
def test_split_by_episodes(sample_dataset, tmp_path):
    """Test splitting dataset by specific episode indices."""
    splits = {
        "train": [0, 1, 2],
        "val": [3, 4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            if "train" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_train")
            elif "val" in repo_id:
                return str(tmp_path / f"{sample_dataset.repo_id}_val")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val"}

        assert result["train"].meta.total_episodes == 3
        assert result["train"].meta.total_frames == 30

        assert result["val"].meta.total_episodes == 2
        assert result["val"].meta.total_frames == 20

        train_episodes = {int(idx.item()) for idx in result["train"].hf_dataset["episode_index"]}
        assert train_episodes == {0, 1, 2}

        val_episodes = {int(idx.item()) for idx in result["val"].hf_dataset["episode_index"]}
        assert val_episodes == {0, 1}


def test_split_by_fractions(sample_dataset, tmp_path):
    """Test splitting dataset by fractions."""
    splits = {
        "train": 0.6,
        "val": 0.4,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 2


def test_split_overlapping_episodes(sample_dataset, tmp_path):
    """Test error when episodes appear in multiple splits."""
    splits = {
        "train": [0, 1, 2],
        "val": [2, 3, 4],
    }

    with pytest.raises(ValueError, match="Episodes cannot appear in multiple splits"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)


def test_split_invalid_fractions(sample_dataset, tmp_path):
    """Test error when fractions sum to more than 1."""
    splits = {
        "train": 0.7,
        "val": 0.5,
    }

    with pytest.raises(ValueError, match="Split fractions must sum to <= 1.0"):
        split_dataset(sample_dataset, splits=splits, output_dir=tmp_path)


def test_split_empty(sample_dataset, tmp_path):
    """Test error with empty splits."""
    with pytest.raises(ValueError, match="No splits provided"):
        split_dataset(sample_dataset, splits={}, output_dir=tmp_path)
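The fraction tests expect 0.6/0.4 of 5 episodes to yield 3 and 2. A plausible fraction-to-count allocation consistent with those expectations (`allocate` is our name; the exact rounding inside `split_dataset` is not shown in this diff and may differ):

```python
def allocate(num_episodes: int, fractions: list[float]) -> list[int]:
    """Hypothetical allocation: each split gets the floor of its share.
    Consistent with the test expectations above, but only a sketch of
    what split_dataset might do internally."""
    return [int(num_episodes * f) for f in fractions]
```

Note that a small fraction can floor to zero episodes, e.g. `allocate(5, [0.1, 0.9])` gives a zero-episode first split; this is the degenerate case the tool explicitly checks for.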
def test_merge_two_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging two datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 8  # 5 + 3
        assert merged.meta.total_frames == 80  # 50 + 30

        episode_indices = sorted({int(idx.item()) for idx in merged.hf_dataset["episode_index"]})
        assert episode_indices == list(range(8))


def test_merge_empty_list(tmp_path):
    """Test error when merging empty list."""
    with pytest.raises(ValueError, match="No datasets to merge"):
        merge_datasets([], output_repo_id="merged", output_dir=tmp_path)
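The merge test expects episodes 0–4 and 0–2 from the two sources to come out as contiguous indices 0–7. The index bookkeeping this implies can be sketched as a running-offset shift (a simplified model of what `merge_datasets` must do for episode indices; the real implementation also rewrites frame indices, stats, and video timestamps):

```python
def reindex_episodes(episode_counts: list[int]) -> list[int]:
    """Shift each source dataset's episode indices by the running total so
    the merged dataset has contiguous indices 0..N-1."""
    merged, offset = [], 0
    for count in episode_counts:
        merged.extend(range(offset, offset + count))
        offset += count
    return merged
```

For the test above, `reindex_episodes([5, 3])` covers exactly the indices 0 through 7 that the assertion checks.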
def test_add_feature_with_values(sample_dataset, tmp_path):
    """Test adding a feature with pre-computed values."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features
        assert new_dataset.meta.features["reward"] == feature_info

        assert len(new_dataset) == num_frames
        sample_item = new_dataset[0]
        assert "reward" in sample_item
        assert isinstance(sample_item["reward"], torch.Tensor)


def test_add_feature_with_callable(sample_dataset, tmp_path):
    """Test adding a feature with a callable."""

    def compute_reward(frame_dict, episode_idx, frame_idx):
        return float(episode_idx * 10 + frame_idx)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=compute_reward,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert "reward" in new_dataset.meta.features

        items = [new_dataset[i] for i in range(10)]
        first_episode_items = [item for item in items if item["episode_index"] == 0]
        assert len(first_episode_items) == 10

        first_frame = first_episode_items[0]
        assert first_frame["frame_index"] == 0
        assert float(first_frame["reward"]) == 0.0


def test_add_existing_feature(sample_dataset, tmp_path):
    """Test error when adding an existing feature."""
    feature_info = {"dtype": "float32", "shape": (1,)}

    with pytest.raises(ValueError, match="Feature 'action' already exists"):
        add_feature(
            sample_dataset,
            feature_name="action",
            feature_values=np.zeros(50),
            feature_info=feature_info,
            output_dir=tmp_path / "modified",
        )


def test_add_feature_invalid_info(sample_dataset, tmp_path):
    """Test error with invalid feature info."""
    with pytest.raises(ValueError, match="feature_info must contain keys"):
        add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.zeros(50),
            feature_info={"dtype": "float32"},
            output_dir=tmp_path / "modified",
        )
def test_remove_single_feature(sample_dataset, tmp_path):
|
||||||
|
"""Test removing a single feature."""
|
||||||
|
feature_info = {"dtype": "float32", "shape": (1,), "names": None}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||||
|
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||||
|
):
|
||||||
|
mock_get_safe_version.return_value = "v3.0"
|
||||||
|
mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))
|
||||||
|
|
||||||
|
dataset_with_reward = add_feature(
|
||||||
|
sample_dataset,
|
||||||
|
feature_name="reward",
|
||||||
|
feature_values=np.random.randn(50, 1).astype(np.float32),
|
||||||
|
feature_info=feature_info,
|
||||||
|
output_dir=tmp_path / "with_reward",
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_without_reward = remove_feature(
|
||||||
|
dataset_with_reward,
|
||||||
|
feature_names="reward",
|
||||||
|
output_dir=tmp_path / "without_reward",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "reward" not in dataset_without_reward.meta.features
|
||||||
|
|
||||||
|
sample_item = dataset_without_reward[0]
|
||||||
|
assert "reward" not in sample_item
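The `side_effect` patching pattern repeated throughout these tests can be exercised in isolation with a plain `MagicMock`; the repo id and paths below are made up for illustration and are not part of LeRobot's API:

```python
from unittest.mock import MagicMock

# Stand-in for snapshot_download: instead of hitting the network, the mock
# echoes back the requested local_dir (or a fallback), mirroring the lambdas
# patched into the tests above.
mock_download = MagicMock(
    side_effect=lambda repo_id, **kwargs: str(kwargs.get("local_dir", "/tmp/fallback"))
)

print(mock_download("user/repo", local_dir="/tmp/cache"))  # -> /tmp/cache
print(mock_download("user/repo"))                          # -> /tmp/fallback
mock_download.assert_called_with("user/repo")              # call history is still recorded
```

Because the mock delegates to the lambda, the tests get deterministic local paths while still being able to assert on how the download function was called.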


def test_remove_multiple_features(sample_dataset, tmp_path):
    """Test removing multiple features at once."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = sample_dataset
        for feature_name in ["reward", "success"]:
            feature_info = {"dtype": "float32", "shape": (1,), "names": None}
            dataset = add_feature(
                dataset,
                feature_name=feature_name,
                feature_values=np.random.randn(dataset.meta.total_frames, 1).astype(np.float32),
                feature_info=feature_info,
                output_dir=tmp_path / f"with_{feature_name}",
            )

        dataset_clean = remove_feature(
            dataset,
            feature_names=["reward", "success"],
            output_dir=tmp_path / "clean",
        )

        assert "reward" not in dataset_clean.meta.features
        assert "success" not in dataset_clean.meta.features


def test_remove_nonexistent_feature(sample_dataset, tmp_path):
    """Test error when removing a non-existent feature."""
    with pytest.raises(ValueError, match="Feature 'nonexistent' not found"):
        remove_feature(
            sample_dataset,
            feature_names="nonexistent",
            output_dir=tmp_path / "modified",
        )


def test_remove_required_feature(sample_dataset, tmp_path):
    """Test error when trying to remove required features."""
    with pytest.raises(ValueError, match="Cannot remove required features"):
        remove_feature(
            sample_dataset,
            feature_names="timestamp",
            output_dir=tmp_path / "modified",
        )


def test_remove_camera_feature(sample_dataset, tmp_path):
    """Test removing a camera feature."""
    camera_keys = sample_dataset.meta.camera_keys
    if not camera_keys:
        pytest.skip("No camera keys in dataset")

    camera_to_remove = camera_keys[0]

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "without_camera")

        dataset_without_camera = remove_feature(
            sample_dataset,
            feature_names=camera_to_remove,
            output_dir=tmp_path / "without_camera",
        )

        assert camera_to_remove not in dataset_without_camera.meta.features
        assert camera_to_remove not in dataset_without_camera.meta.camera_keys

        sample_item = dataset_without_camera[0]
        assert camera_to_remove not in sample_item


def test_complex_workflow_integration(sample_dataset, tmp_path):
    """Test a complex workflow combining multiple operations."""
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info={"dtype": "float32", "shape": (1,), "names": None},
            output_dir=tmp_path / "step1",
        )

        dataset = delete_episodes(
            dataset,
            episode_indices=[2],
            output_dir=tmp_path / "step2",
        )

        splits = split_dataset(
            dataset,
            splits={"train": 0.75, "val": 0.25},
            output_dir=tmp_path / "step3",
        )

        merged = merge_datasets(
            list(splits.values()),
            output_repo_id="final_dataset",
            output_dir=tmp_path / "step4",
        )

        assert merged.meta.total_episodes == 4
        assert merged.meta.total_frames == 40
        assert "reward" in merged.meta.features

        assert len(merged) == 40
        sample_item = merged[0]
        assert "reward" in sample_item


def test_delete_episodes_preserves_stats(sample_dataset, tmp_path):
    """Test that deleting episodes preserves statistics correctly."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[2],
            output_dir=output_dir,
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]


def test_delete_episodes_preserves_tasks(sample_dataset, tmp_path):
    """Test that tasks are preserved correctly after deletion."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0],
            output_dir=output_dir,
        )

        assert new_dataset.meta.tasks is not None
        assert len(new_dataset.meta.tasks) == 2

        tasks_in_dataset = {str(item["task"]) for item in new_dataset}
        assert len(tasks_in_dataset) > 0


def test_split_three_ways(sample_dataset, tmp_path):
    """Test splitting the dataset into three splits."""
    splits = {
        "train": 0.6,
        "val": 0.2,
        "test": 0.2,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        assert set(result.keys()) == {"train", "val", "test"}
        assert result["train"].meta.total_episodes == 3
        assert result["val"].meta.total_episodes == 1
        assert result["test"].meta.total_episodes == 1

        total_frames = sum(ds.meta.total_frames for ds in result.values())
        assert total_frames == sample_dataset.meta.total_frames
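The expected episode counts above (3/1/1 from a 5-episode dataset) follow from rounding the fractions down and distributing any remainder. A rough, self-contained sketch of that arithmetic; the helper name is hypothetical and this is not LeRobot's actual implementation:

```python
import math

def fractional_split_sizes(total_episodes: int, fractions: dict[str, float]) -> dict[str, int]:
    """Hypothetical sketch: floor each fraction's share of episodes, then hand
    leftover episodes to the earliest splits so every episode is assigned once."""
    sizes = {name: math.floor(frac * total_episodes) for name, frac in fractions.items()}
    leftover = total_episodes - sum(sizes.values())
    for name in sizes:
        if leftover == 0:
            break
        sizes[name] += 1
        leftover -= 1
    return sizes

print(fractional_split_sizes(5, {"train": 0.6, "val": 0.2, "test": 0.2}))
# -> {'train': 3, 'val': 1, 'test': 1}
```

Note that a split fraction small enough to floor to zero episodes would produce an empty split, which is why the tooling checks for that case.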


def test_split_preserves_stats(sample_dataset, tmp_path):
    """Test that statistics are preserved when splitting."""
    splits = {"train": [0, 1, 2], "val": [3, 4]}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        for split_ds in result.values():
            assert split_ds.meta.stats is not None
            for feature in ["action", "observation.state"]:
                assert feature in split_ds.meta.stats
                assert "mean" in split_ds.meta.stats[feature]
                assert "std" in split_ds.meta.stats[feature]


def test_merge_three_datasets(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test merging three datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    datasets = [sample_dataset]

    for i in range(2):
        dataset = empty_lerobot_dataset_factory(
            root=tmp_path / f"test_dataset{i + 2}",
            features=features,
        )

        for ep_idx in range(2):
            for _ in range(10):
                frame = {
                    "action": np.random.randn(6).astype(np.float32),
                    "observation.state": np.random.randn(4).astype(np.float32),
                    "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                    "task": f"task_{ep_idx}",
                }
                dataset.add_frame(frame)
            dataset.save_episode()

        datasets.append(dataset)

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            datasets,
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.total_episodes == 9
        assert merged.meta.total_frames == 90


def test_merge_preserves_stats(sample_dataset, tmp_path, empty_lerobot_dataset_factory):
    """Test that statistics are computed for merged datasets."""
    features = {
        "action": {"dtype": "float32", "shape": (6,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (4,), "names": None},
        "observation.images.top": {"dtype": "image", "shape": (224, 224, 3), "names": None},
    }

    dataset2 = empty_lerobot_dataset_factory(
        root=tmp_path / "test_dataset2",
        features=features,
    )

    for ep_idx in range(3):
        for _ in range(10):
            frame = {
                "action": np.random.randn(6).astype(np.float32),
                "observation.state": np.random.randn(4).astype(np.float32),
                "observation.images.top": np.random.randint(0, 255, size=(224, 224, 3), dtype=np.uint8),
                "task": f"task_{ep_idx % 2}",
            }
            dataset2.add_frame(frame)
        dataset2.save_episode()

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "merged_dataset")

        merged = merge_datasets(
            [sample_dataset, dataset2],
            output_repo_id="merged_dataset",
            output_dir=tmp_path / "merged_dataset",
        )

        assert merged.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in merged.meta.stats
            assert "mean" in merged.meta.stats[feature]
            assert "std" in merged.meta.stats[feature]
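The merged mean/std checked above can be derived from the per-dataset statistics alone, using standard pooled-statistics identities. A plain-Python sketch of that math, not LeRobot's aggregation code:

```python
import math

def pooled_mean_std(parts: list[tuple[int, float, float]]) -> tuple[float, float]:
    """Combine (count, mean, std) triples from several datasets into the
    mean/std of their concatenation, without revisiting the raw frames."""
    n_total = sum(n for n, _, _ in parts)
    mean = sum(n * m for n, m, _ in parts) / n_total
    # Pool E[x^2] from each part's variance plus squared mean, then subtract mean^2.
    ex2 = sum(n * (s * s + m * m) for n, m, s in parts) / n_total
    return mean, math.sqrt(ex2 - mean * mean)

# Two toy "datasets" whose concatenation is [0, 2, 4, 6]: mean 3.0, std sqrt(5)
m, s = pooled_mean_std([(2, 1.0, 1.0), (2, 5.0, 1.0)])
print(m, s)
```

This is why merging only needs each source dataset's frame counts and stats metadata, not a second pass over the data.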


def test_add_feature_preserves_existing_stats(sample_dataset, tmp_path):
    """Test that adding a feature preserves existing stats."""
    num_frames = sample_dataset.meta.total_frames
    reward_values = np.random.randn(num_frames, 1).astype(np.float32)

    feature_info = {
        "dtype": "float32",
        "shape": (1,),
        "names": None,
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "with_reward")

        new_dataset = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=reward_values,
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        assert new_dataset.meta.stats is not None
        for feature in ["action", "observation.state"]:
            assert feature in new_dataset.meta.stats
            assert "mean" in new_dataset.meta.stats[feature]
            assert "std" in new_dataset.meta.stats[feature]


def test_remove_feature_updates_stats(sample_dataset, tmp_path):
    """Test that removing a feature removes it from stats."""
    feature_info = {"dtype": "float32", "shape": (1,), "names": None}

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(kwargs.get("local_dir", tmp_path))

        dataset_with_reward = add_feature(
            sample_dataset,
            feature_name="reward",
            feature_values=np.random.randn(50, 1).astype(np.float32),
            feature_info=feature_info,
            output_dir=tmp_path / "with_reward",
        )

        dataset_without_reward = remove_feature(
            dataset_with_reward,
            feature_names="reward",
            output_dir=tmp_path / "without_reward",
        )

        if dataset_without_reward.meta.stats:
            assert "reward" not in dataset_without_reward.meta.stats


def test_delete_consecutive_episodes(sample_dataset, tmp_path):
    """Test deleting consecutive episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[1, 2, 3],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 2
        assert new_dataset.meta.total_frames == 20

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1]
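The contiguity assertion above reflects the renumbering that deletion must perform: surviving episodes are compacted into a gap-free index range. A hypothetical sketch of that mapping step (the helper name is illustrative, not LeRobot's):

```python
def reindex_after_delete(total_episodes: int, deleted: set[int]) -> dict[int, int]:
    """Map each surviving old episode index to its new, gap-free index."""
    mapping = {}
    new_idx = 0
    for old_idx in range(total_episodes):
        if old_idx in deleted:
            continue
        mapping[old_idx] = new_idx
        new_idx += 1
    return mapping

# Deleting episodes 1-3 from a 5-episode dataset leaves old indices [0, 4],
# renumbered to [0, 1]
print(reindex_after_delete(5, {1, 2, 3}))  # -> {0: 0, 4: 1}
```

The same mapping also has to be applied to per-frame `episode_index` columns and to video/parquet file references so metadata stays consistent.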


def test_delete_first_and_last_episodes(sample_dataset, tmp_path):
    """Test deleting first and last episodes."""
    output_dir = tmp_path / "filtered"

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(output_dir)

        new_dataset = delete_episodes(
            sample_dataset,
            episode_indices=[0, 4],
            output_dir=output_dir,
        )

        assert new_dataset.meta.total_episodes == 3
        assert new_dataset.meta.total_frames == 30

        episode_indices = sorted({int(idx.item()) for idx in new_dataset.hf_dataset["episode_index"]})
        assert episode_indices == [0, 1, 2]


def test_split_all_episodes_assigned(sample_dataset, tmp_path):
    """Test that all episodes can be explicitly assigned to splits."""
    splits = {
        "split1": [0, 1],
        "split2": [2, 3],
        "split3": [4],
    }

    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"

        def mock_snapshot(repo_id, **kwargs):
            for split_name in splits:
                if split_name in repo_id:
                    return str(tmp_path / f"{sample_dataset.repo_id}_{split_name}")
            return str(kwargs.get("local_dir", tmp_path))

        mock_snapshot_download.side_effect = mock_snapshot

        result = split_dataset(
            sample_dataset,
            splits=splits,
            output_dir=tmp_path,
        )

        total_episodes = sum(ds.meta.total_episodes for ds in result.values())
        assert total_episodes == sample_dataset.meta.total_episodes
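When splits are given as explicit episode lists, the invariant the final assertion checks (every episode assigned, none twice) can be validated up front. A hypothetical checker, not part of the LeRobot API:

```python
def splits_partition_episodes(total_episodes: int, splits: dict[str, list[int]]) -> bool:
    """True iff the split lists partition range(total_episodes) exactly:
    no episode is missing and no episode appears in two splits."""
    assigned = [ep for eps in splits.values() for ep in eps]
    return len(assigned) == len(set(assigned)) and sorted(assigned) == list(range(total_episodes))

print(splits_partition_episodes(5, {"split1": [0, 1], "split2": [2, 3], "split3": [4]}))  # True
print(splits_partition_episodes(5, {"a": [0, 1], "b": [1, 2]}))  # False: overlap and gaps
```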