#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import torch

from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID


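# The assert_* functions below are shared assertion helpers rather than pytest
# test cases; they are called by the test_* functions at the bottom of this file
# on datasets created with the `lerobot_dataset_factory` fixture.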
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
    """Test that the total numbers of episodes and frames are correctly aggregated."""
    assert aggr_ds.num_episodes == expected_episodes, (
        f"Expected {expected_episodes} episodes, got {aggr_ds.num_episodes}"
    )
    assert aggr_ds.num_frames == expected_frames, (
        f"Expected {expected_frames} frames, got {aggr_ds.num_frames}"
    )


def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
    """Test that the content of both datasets is preserved correctly in the aggregated dataset."""
    keys_to_ignore = ["episode_index", "index", "timestamp"]

    def assert_items_match(aggr_item, ref_item, label, source_name):
        # Compare all keys except the ones that are expected to be updated by aggregation
        for key in ref_item:
            if key not in keys_to_ignore:
                # Handle both tensor and non-tensor data
                if torch.is_tensor(aggr_item[key]) and torch.is_tensor(ref_item[key]):
                    assert torch.allclose(aggr_item[key], ref_item[key], atol=1e-6), (
                        f"{label} key '{key}' doesn't match between aggregated and {source_name}"
                    )
                else:
                    assert aggr_item[key] == ref_item[key], (
                        f"{label} key '{key}' doesn't match between aggregated and {source_name}"
                    )

    # First part of the aggregated dataset corresponds to ds_0:
    # the first item (index 0) should match ds_0[0] ...
    assert_items_match(aggr_ds[0], ds_0[0], "First item", "ds_0")
    # ... and the last item of the ds_0 part (index len(ds_0) - 1) should match ds_0[-1]
    assert_items_match(aggr_ds[len(ds_0) - 1], ds_0[-1], "Last ds_0 item", "ds_0")

    # Second part of the aggregated dataset corresponds to ds_1:
    # the first item of the ds_1 part (index len(ds_0)) should match ds_1[0] ...
    assert_items_match(aggr_ds[len(ds_0)], ds_1[0], "First ds_1 item", "ds_1")
    # ... and the very last item should match ds_1[-1]
    assert_items_match(aggr_ds[-1], ds_1[-1], "Last item", "ds_1")


def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
    """Test that metadata is correctly aggregated."""
    # Test basic info
    assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
    assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
        "Robot type should be the same"
    )

    # Test features are the same
    assert aggr_ds.features == ds_0.features == ds_1.features, "Features should be the same"

    # Test tasks aggregation
    expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"


def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
    """Test that episode indices are correctly updated after aggregation."""
    # ds_0 episodes should have episode_index 0 to ds_0.num_episodes - 1
    for i in range(len(ds_0)):
        assert aggr_ds[i]["episode_index"] < ds_0.num_episodes, (
            f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be < {ds_0.num_episodes}"
        )

    def ds1_episodes_condition(ep_idx):
        return (ep_idx >= ds_0.num_episodes) and (ep_idx < ds_0.num_episodes + ds_1.num_episodes)

    # ds_1 episodes should have episode_index ds_0.num_episodes to total_episodes - 1
    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
        assert ds1_episodes_condition(aggr_ds[i]["episode_index"]), (
            f"Episode index {aggr_ds[i]['episode_index']} at position {i} should be in "
            f"[{ds_0.num_episodes}, {ds_0.num_episodes + ds_1.num_episodes})"
        )


def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that video frames are correctly preserved and frame indices are updated."""

    def visual_frames_equal(frame1, frame2):
        return torch.allclose(frame1, frame2)

    video_keys = [
        key for key, feature in aggr_ds.meta.info["features"].items() if feature["dtype"] == "video"
    ]

    # Test the section corresponding to the first dataset (ds_0)
    for i in range(len(ds_0)):
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in video_keys:
            assert visual_frames_equal(aggr_ds[i][key], ds_0[i][key]), (
                f"Visual frames at position {i} should be equal between aggregated and ds_0"
            )

    # Test the section corresponding to the second dataset (ds_1)
    for i in range(len(ds_0), len(ds_0) + len(ds_1)):
        # The frame index in the aggregated dataset should also match its position.
        assert aggr_ds[i]["index"] == i, (
            f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}"
        )
        for key in video_keys:
            assert visual_frames_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), (
                f"Visual frames at position {i} should be equal between aggregated and ds_1"
            )


def assert_dataset_iteration_works(aggr_ds):
    """Test that we can iterate through the entire dataset without errors."""
    for _ in aggr_ds:
        pass


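# Episode video boundaries are stored as timestamps and converted to frame
# indices with round(timestamp * fps). As an illustrative example (numbers are
# hypothetical): at fps=30, an episode spanning from_timestamp=2.0 s to
# to_timestamp=3.5 s inside a merged video file covers frames
# round(2.0 * 30) = 60 up to round(3.5 * 30) = 105 (exclusive upper bound),
# so both indices must stay within the decoded video's frame count.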
def assert_video_timestamps_within_bounds(aggr_ds):
    """Test that all video timestamps are within valid bounds for their respective video files.

    This catches bugs where timestamps point to frames beyond the actual video length,
    which would cause "Invalid frame index" errors during data loading.
    """
    try:
        from torchcodec.decoders import VideoDecoder
    except ImportError:
        # Skip this check when torchcodec is not available.
        return

    for ep_idx in range(aggr_ds.num_episodes):
        ep = aggr_ds.meta.episodes[ep_idx]

        for vid_key in aggr_ds.meta.video_keys:
            from_ts = ep[f"videos/{vid_key}/from_timestamp"]
            to_ts = ep[f"videos/{vid_key}/to_timestamp"]
            video_path = aggr_ds.root / aggr_ds.meta.get_video_file_path(ep_idx, vid_key)

            if not video_path.exists():
                continue

            from_frame_idx = round(from_ts * aggr_ds.fps)
            to_frame_idx = round(to_ts * aggr_ds.fps)

            try:
                decoder = VideoDecoder(str(video_path))
                num_frames = len(decoder)

                # Verify timestamps don't exceed video bounds
                assert from_frame_idx >= 0, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) < 0"
                )
                assert from_frame_idx < num_frames, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= video frames ({num_frames})"
                )
                assert to_frame_idx <= num_frames, (
                    f"Episode {ep_idx}, {vid_key}: to_frame_idx ({to_frame_idx}) > video frames ({num_frames})"
                )
                assert from_frame_idx < to_frame_idx, (
                    f"Episode {ep_idx}, {vid_key}: from_frame_idx ({from_frame_idx}) >= to_frame_idx ({to_frame_idx})"
                )
            except Exception as e:
                raise AssertionError(
                    f"Failed to verify timestamps for episode {ep_idx}, {vid_key}: {e}"
                ) from e


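# The tests below build small synthetic datasets with the `lerobot_dataset_factory`
# fixture, merge them with `aggregate_datasets`, and reload the merged result as a
# LeRobotDataset while mocking `get_safe_version` and `snapshot_download` so that
# no Hugging Face Hub calls are made.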
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Test basic aggregation functionality with standard parameters."""
    ds_0_num_frames = 400
    ds_1_num_frames = 800
    ds_0_num_episodes = 10
    ds_1_num_episodes = 25

    # Create two datasets with different numbers of frames and episodes
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "test_0",
        repo_id=f"{DUMMY_REPO_ID}_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "test_1",
        repo_id=f"{DUMMY_REPO_ID}_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
    )

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_aggr",
        aggr_root=tmp_path / "test_aggr",
    )

    # Mock the revision to prevent Hub calls during dataset loading
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "test_aggr")
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_aggr", root=tmp_path / "test_aggr")

    # Run all assertion functions
    expected_total_episodes = ds_0.num_episodes + ds_1.num_episodes
    expected_total_frames = ds_0.num_frames + ds_1.num_frames

    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_dataset_iteration_works(aggr_ds)


def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    """Test aggregation with small file size limits to force file rotation/sharding."""
    ds_0_num_episodes = ds_1_num_episodes = 10
    ds_0_num_frames = ds_1_num_frames = 400

    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "small_0",
        repo_id=f"{DUMMY_REPO_ID}_small_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "small_1",
        repo_id=f"{DUMMY_REPO_ID}_small_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
    )

    # Use the configurable file size parameters to force file rotation
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_small_aggr",
        aggr_root=tmp_path / "small_aggr",
        # Tiny file sizes to trigger rotation to new data/video files
        data_files_size_in_mb=0.01,
        video_files_size_in_mb=0.1,
    )

    # Mock the revision to prevent Hub calls during dataset loading
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "small_aggr")
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=tmp_path / "small_aggr")

    # Verify aggregation worked correctly despite the file size constraints
    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
    expected_total_frames = ds_0_num_frames + ds_1_num_frames

    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_video_timestamps_within_bounds(aggr_ds)
    assert_dataset_iteration_works(aggr_ds)

    # Check that multiple files were actually created due to the small size limits
    data_dir = tmp_path / "small_aggr" / "data"
    video_dir = tmp_path / "small_aggr" / "videos"

    if data_dir.exists():
        parquet_files = list(data_dir.rglob("*.parquet"))
        assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"

    if video_dir.exists():
        video_files = list(video_dir.rglob("*.mp4"))
        assert len(video_files) > 1, "Small file size limits should create multiple video files"


def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory):
    """Regression test for the video timestamp bug when merging datasets.

    This test specifically checks that video timestamps are correctly calculated
    and accumulated when merging multiple datasets.
    """
    datasets = []
    for i in range(3):
        ds = lerobot_dataset_factory(
            root=tmp_path / f"regression_{i}",
            repo_id=f"{DUMMY_REPO_ID}_regression_{i}",
            total_episodes=2,
            total_frames=100,
        )
        datasets.append(ds)

    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in datasets],
        roots=[ds.root for ds in datasets],
        aggr_repo_id=f"{DUMMY_REPO_ID}_regression_aggr",
        aggr_root=tmp_path / "regression_aggr",
    )

    # Mock the revision to prevent Hub calls during dataset loading
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(tmp_path / "regression_aggr")
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_regression_aggr", root=tmp_path / "regression_aggr")

    assert_video_timestamps_within_bounds(aggr_ds)

    # Every frame should load without "Invalid frame index" errors and expose all video keys.
    for i in range(len(aggr_ds)):
        item = aggr_ds[i]
        for key in aggr_ds.meta.video_keys:
            assert key in item, f"Video key {key} missing from item {i}"
            assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}"