Dataset v3 (#1412)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com> Co-authored-by: Remi Cadene <re.cadene@gmail.com> Co-authored-by: Tavish <tavish9.chen@gmail.com> Co-authored-by: fracapuano <francesco.capuano@huggingface.co> Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
@@ -47,38 +47,22 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
|
||||
)
|
||||
|
||||
# save 2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
i = dataset.meta.episodes["dataset_from_index"][0]
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
# save 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(dataset.meta.episodes["dataset_to_index"][0] - dataset.meta.episodes["dataset_from_index"][0]) / 2
|
||||
)
|
||||
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
|
||||
|
||||
# save 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
i = dataset.meta.episodes["dataset_to_index"][0]
|
||||
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
|
||||
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
|
||||
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently can't because our test dataset only contains the first episode
|
||||
|
||||
# # save 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
|
||||
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
|
||||
|
||||
# # save 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
# # save 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
|
||||
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for dataset in [
|
||||
|
||||
292
tests/datasets/test_aggregate.py
Normal file
292
tests/datasets/test_aggregate.py
Normal file
@@ -0,0 +1,292 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
|
||||
|
||||
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
    """Check that the aggregated dataset reports the expected episode and frame totals.

    Args:
        aggr_ds: Object exposing ``num_episodes`` and ``num_frames``.
        expected_episodes: Episode total the aggregated dataset must report.
        expected_frames: Frame total the aggregated dataset must report.

    Raises:
        AssertionError: If either total differs from the expected value.
    """
    episode_count = aggr_ds.num_episodes
    assert episode_count == expected_episodes, f"Expected {expected_episodes} episodes, got {episode_count}"

    frame_count = aggr_ds.num_frames
    assert frame_count == expected_frames, f"Expected {expected_frames} frames, got {frame_count}"
|
||||
|
||||
|
||||
def _assert_items_match(aggr_item, source_item, keys_to_ignore, label, source_name):
    """Assert that every non-ignored key of ``source_item`` has a matching value in ``aggr_item``."""
    for key, expected in source_item.items():
        if key in keys_to_ignore:
            continue
        assert key in aggr_item, f"{label} key '{key}' is missing from the aggregated dataset"
        actual = aggr_item[key]
        mismatch_msg = f"{label} key '{key}' doesn't match between aggregated and {source_name}"
        # Tensors are compared with a small tolerance, everything else with plain equality.
        if torch.is_tensor(actual) and torch.is_tensor(expected):
            assert torch.allclose(actual, expected, atol=1e-6), mismatch_msg
        else:
            assert actual == expected, mismatch_msg


def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
    """Test that the content of both datasets is preserved correctly in the aggregated dataset.

    The first and last items of each source dataset are compared with the items found
    at the corresponding positions of the aggregated dataset (``ds_0`` first, ``ds_1``
    right after it).

    Args:
        aggr_ds: Aggregated dataset, indexable by integer position.
        ds_0: First source dataset; must support ``len()`` and integer indexing.
        ds_1: Second source dataset; must support integer indexing.

    Raises:
        AssertionError: If a compared key is missing from the aggregated item or its value differs.
    """
    # These keys are skipped when comparing items.
    keys_to_ignore = ["episode_index", "index", "timestamp"]
    ds_0_len = len(ds_0)

    # (aggregated item, source item, label used in failure messages, source name)
    checks = [
        (aggr_ds[0], ds_0[0], "First item", "ds_0"),
        (aggr_ds[ds_0_len - 1], ds_0[-1], "Last ds_0 item", "ds_0"),
        (aggr_ds[ds_0_len], ds_1[0], "First ds_1 item", "ds_1"),
        (aggr_ds[-1], ds_1[-1], "Last item", "ds_1"),
    ]
    for aggr_item, source_item, label, source_name in checks:
        _assert_items_match(aggr_item, source_item, keys_to_ignore, label, source_name)
|
||||
|
||||
|
||||
def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
    """Verify that fps, robot type, features and tasks of the aggregate agree with both sources.

    Args:
        aggr_ds: Aggregated dataset.
        ds_0: First source dataset.
        ds_1: Second source dataset.

    Raises:
        AssertionError: If fps, robot type or features differ between the three datasets,
            or if the aggregated task index is not the union of the source task indices.
    """
    all_datasets = (aggr_ds, ds_0, ds_1)

    # Basic info: fps and robot type must be shared by all three datasets.
    fps_aggr, fps_0, fps_1 = (ds.fps for ds in all_datasets)
    assert fps_aggr == fps_0 == fps_1, "FPS should be the same across all datasets"

    robot_aggr, robot_0, robot_1 = (ds.meta.info["robot_type"] for ds in all_datasets)
    assert robot_aggr == robot_0 == robot_1, "Robot type should be the same"

    # Feature definitions must be identical.
    feats_aggr, feats_0, feats_1 = (ds.features for ds in all_datasets)
    assert feats_aggr == feats_0 == feats_1, "Features should be the same"

    # The aggregated task index must be the union of the source task indices.
    expected_tasks = set(ds_0.meta.tasks.index).union(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"
|
||||
|
||||
|
||||
def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
    """Test that episode indices are correctly updated after aggregation.

    Frames coming from ``ds_0`` must carry an episode index in ``[0, ds_0.num_episodes)``;
    frames coming from ``ds_1`` must carry one in
    ``[ds_0.num_episodes, ds_0.num_episodes + ds_1.num_episodes)``.

    Args:
        aggr_ds: Aggregated dataset; items are mappings with an ``"episode_index"`` entry.
        ds_0: First source dataset (needs ``len()`` and ``num_episodes``).
        ds_1: Second source dataset (needs ``len()`` and ``num_episodes``).

    Raises:
        AssertionError: If any frame has an episode index outside its expected range.
    """
    ds_0_frames = len(ds_0)
    # (first position, end position, first valid episode index, end of valid episode indices)
    sections = [
        (0, ds_0_frames, 0, ds_0.num_episodes),
        (ds_0_frames, ds_0_frames + len(ds_1), ds_0.num_episodes, ds_0.num_episodes + ds_1.num_episodes),
    ]

    for start, stop, min_ep_idx, end_ep_idx in sections:
        for i in range(start, stop):
            # Fetch the item once; the value is reused in the failure message.
            ep_idx = aggr_ds[i]["episode_index"]
            assert min_ep_idx <= ep_idx < end_ep_idx, (
                f"Episode index {ep_idx} at position {i} should be in [{min_ep_idx}, {end_ep_idx})"
            )
|
||||
|
||||
|
||||
def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that video frames are correctly preserved and frame indices are updated.

    For every position of the aggregated dataset, the ``"index"`` entry must equal the
    position, and every feature whose dtype is ``"video"`` must be close (per
    ``torch.allclose``) to the frame of the source dataset it came from.

    Args:
        aggr_ds: Aggregated dataset; ``aggr_ds.meta.info["features"]`` maps feature names
            to specs that contain a ``"dtype"`` entry.
        ds_0: First source dataset, occupying positions ``[0, len(ds_0))`` of the aggregate.
        ds_1: Second source dataset, occupying the following ``len(ds_1)`` positions.

    Raises:
        AssertionError: If a frame index or a video frame does not match.
    """
    features = aggr_ds.meta.info["features"]
    video_keys = [key for key, spec in features.items() if spec["dtype"] == "video"]

    offset = 0
    for source_name, source_ds in (("ds_0", ds_0), ("ds_1", ds_1)):
        source_len = len(source_ds)
        for local_idx in range(source_len):
            position = offset + local_idx
            # Fetch each item once and reuse it for every check at this position.
            aggr_item = aggr_ds[position]
            assert aggr_item["index"] == position, (
                f"Frame index at position {position} should be {position}, but got {aggr_item['index']}"
            )
            if not video_keys:
                # Nothing to compare, so the source item is not loaded at all.
                continue
            source_item = source_ds[local_idx]
            for key in video_keys:
                assert torch.allclose(aggr_item[key], source_item[key]), (
                    f"Visual frames at position {position} should be equal between aggregated "
                    f"and {source_name} (key '{key}')"
                )
        offset += source_len
|
||||
|
||||
|
||||
def assert_dataset_iteration_works(aggr_ds):
    """Walk through the whole dataset once; any error raised while fetching an item propagates."""
    samples = iter(aggr_ds)
    exhausted = object()  # unique sentinel marking the end of iteration
    while next(samples, exhausted) is not exhausted:
        continue
|
||||
|
||||
|
||||
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Aggregate two datasets of different sizes with default settings and validate the merged result."""
    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"

    # One (suffix, episode count, frame count) entry per source dataset.
    source_specs = ((0, 10, 400), (1, 25, 800))
    ds_0, ds_1 = (
        lerobot_dataset_factory(
            root=tmp_path / f"test_{suffix}",
            repo_id=f"{DUMMY_REPO_ID}_{suffix}",
            total_episodes=num_episodes,
            total_frames=num_frames,
        )
        for suffix, num_episodes, num_frames in source_specs
    )

    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Patch the version lookup and the snapshot download so loading never reaches the Hub.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version", return_value="v3.0"),
        patch("lerobot.datasets.lerobot_dataset.snapshot_download", return_value=str(aggr_root)),
    ):
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Totals are taken from what the source datasets themselves report.
    total_episodes = ds_0.num_episodes + ds_1.num_episodes
    total_frames = ds_0.num_frames + ds_1.num_frames
    assert_episode_and_frame_counts(aggr_ds, total_episodes, total_frames)

    # Checks that compare the aggregate against both sources.
    for check in (
        assert_dataset_content_integrity,
        assert_metadata_consistency,
        assert_episode_indices_updated_correctly,
        assert_video_frames_integrity,
    ):
        check(aggr_ds, ds_0, ds_1)

    assert_dataset_iteration_works(aggr_ds)
|
||||
|
||||
|
||||
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    """Test aggregation with small file size limits to force file rotation/sharding."""
    ds_0_num_episodes = ds_1_num_episodes = 10
    ds_0_num_frames = ds_1_num_frames = 400
    aggr_repo_id = f"{DUMMY_REPO_ID}_small_aggr"
    aggr_root = tmp_path / "small_aggr"

    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "small_0",
        repo_id=f"{DUMMY_REPO_ID}_small_0",
        total_episodes=ds_0_num_episodes,
        total_frames=ds_0_num_frames,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "small_1",
        repo_id=f"{DUMMY_REPO_ID}_small_1",
        total_episodes=ds_1_num_episodes,
        total_frames=ds_1_num_frames,
    )

    # Use the configurable size limits to force file rotation
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
        # Tiny file size to trigger new file instantiation
        data_files_size_in_mb=0.01,
        video_files_size_in_mb=0.1,
    )

    # Mock the revision to prevent Hub calls during dataset loading
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Verify aggregation worked correctly despite file size constraints
    expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes
    expected_total_frames = ds_0_num_frames + ds_1_num_frames

    assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)

    # Check that multiple files were actually created due to small size limits.
    # The data directory is asserted rather than tested with `if`, so a missing
    # directory fails the test instead of silently skipping the sharding check.
    data_dir = aggr_root / "data"
    assert data_dir.is_dir(), f"Aggregated data directory {data_dir} should exist"
    parquet_files = list(data_dir.rglob("*.parquet"))
    assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"

    # NOTE(review): the videos check stays conditional, as in the original; confirm whether
    # the dataset factory always produces video features before making it unconditional.
    video_dir = aggr_root / "videos"
    if video_dir.exists():
        video_files = list(video_dir.rglob("*.mp4"))
        assert len(video_files) > 1, "Small file size limits should create multiple video files"
|
||||
@@ -13,10 +13,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
|
||||
@@ -37,13 +35,19 @@ from lerobot.datasets.lerobot_dataset import (
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
unflatten_dict,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -69,12 +73,17 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
||||
objects have the same sets of attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
robot = make_robot_from_config(MockRobotConfig())
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", True)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
|
||||
dataset_create = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create
|
||||
)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
@@ -99,13 +108,41 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
|
||||
# and test the small resulting function that validates the features
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
    """Check that ``LeRobotDataset.create`` raises ``ValueError`` for a feature key containing '/' ('a/b')."""
    import shutil

    from lerobot.constants import HF_LEROBOT_HOME

    dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
    # Make sure the directory does not exist before creating the dataset. rmtree is used
    # because Path.rmdir() raises OSError when a previous run left files behind.
    if dataset_dir.exists():
        shutil.rmtree(dataset_dir)

    with pytest.raises(ValueError):
        LeRobotDataset.create(
            repo_id="lerobot/test/with/slash",
            fps=30,
            features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
        )
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
    """A frame without a 'task' entry must be rejected with a feature-mismatch ``ValueError``."""
    state_spec = {"dtype": "float32", "shape": (1,), "names": None}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features={"state": state_spec})

    expected_error = "Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
    frame_without_task = {"state": torch.randn(1)}
    with pytest.raises(ValueError, match=expected_error):
        dataset.add_frame(frame_without_task)
|
||||
|
||||
|
||||
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
with pytest.raises(
|
||||
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
|
||||
):
|
||||
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
|
||||
dataset.add_frame({"task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -114,7 +151,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
|
||||
with pytest.raises(
|
||||
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -123,7 +160,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
|
||||
with pytest.raises(
|
||||
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -133,7 +170,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
|
||||
):
|
||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -145,7 +182,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
|
||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": 1.0}, task="Dummy task")
|
||||
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -155,7 +192,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
||||
ValueError,
|
||||
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
|
||||
):
|
||||
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
|
||||
@@ -167,13 +204,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
|
||||
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
|
||||
),
|
||||
):
|
||||
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
|
||||
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert len(dataset) == 1
|
||||
@@ -185,7 +222,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2])
|
||||
@@ -194,7 +231,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4])
|
||||
@@ -203,7 +240,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
|
||||
@@ -212,7 +249,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
|
||||
@@ -221,7 +258,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
|
||||
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
|
||||
@@ -230,7 +267,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
|
||||
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["state"].ndim == 0
|
||||
@@ -239,7 +276,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
|
||||
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
|
||||
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["caption"] == "Dummy caption"
|
||||
@@ -254,7 +291,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
|
||||
),
|
||||
):
|
||||
c, h, w = DUMMY_CHW
|
||||
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
|
||||
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
|
||||
|
||||
|
||||
def test_add_frame_image_wrong_range(image_dataset):
|
||||
@@ -267,14 +304,14 @@ def test_add_frame_image_wrong_range(image_dataset):
|
||||
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
|
||||
"""
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
|
||||
with pytest.raises(FileNotFoundError):
|
||||
dataset.save_episode()
|
||||
|
||||
|
||||
def test_add_frame_image(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
@@ -282,7 +319,7 @@ def test_add_frame_image(image_dataset):
|
||||
|
||||
def test_add_frame_image_h_w_c(image_dataset):
|
||||
dataset = image_dataset
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
|
||||
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
@@ -291,7 +328,7 @@ def test_add_frame_image_h_w_c(image_dataset):
|
||||
def test_add_frame_image_uint8(image_dataset):
|
||||
dataset = image_dataset
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": image}, task="Dummy task")
|
||||
dataset.add_frame({"image": image, "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
@@ -300,7 +337,7 @@ def test_add_frame_image_uint8(image_dataset):
|
||||
def test_add_frame_image_pil(image_dataset):
|
||||
dataset = image_dataset
|
||||
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
|
||||
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
|
||||
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||
@@ -319,6 +356,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
# TODO(rcadene):
|
||||
# - [ ] fix code so that old test_factory + backward pass
|
||||
# - [ ] write new unit tests to test save_episode + getitem
|
||||
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
|
||||
# - [ ]
|
||||
# - [ ] remove old tests
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
@@ -338,9 +382,8 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
|
||||
env=make_env_config(env_name),
|
||||
policy=make_policy_config(policy_name, push_to_hub=False),
|
||||
policy=make_policy_config(policy_name),
|
||||
)
|
||||
cfg.validate()
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
@@ -427,30 +470,6 @@ def test_multidataset_frames():
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
    """Round-trip a nested stats dict through flatten_dict/unflatten_dict and expect it unchanged."""
    stat_names = ("min", "max", "mean", "std")
    nested = {
        "obs": dict(zip(stat_names, range(0, 4))),
        "action": dict(zip(stat_names, range(4, 8))),
    }

    snapshot = deepcopy(nested)
    round_tripped = unflatten_dict(flatten_dict(nested))

    # Nested dicts are compared through their key-sorted JSON serialization.
    expected_json = json.dumps(snapshot, sort_keys=True)
    actual_json = json.dumps(round_tripped, sort_keys=True)
    assert expected_json == actual_json, f"{snapshot} != {round_tripped}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -497,38 +516,22 @@ def test_backward_compatibility(repo_id):
|
||||
)
|
||||
|
||||
# test 2 first frames of first episode
|
||||
i = dataset.episode_data_index["from"][0].item()
|
||||
i = dataset.meta.episodes[0]["dataset_from_index"]
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 frames at the middle of first episode
|
||||
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
|
||||
i = int(
|
||||
(dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
|
||||
)
|
||||
load_and_compare(i)
|
||||
load_and_compare(i + 1)
|
||||
|
||||
# test 2 last frames of first episode
|
||||
i = dataset.episode_data_index["to"][0].item()
|
||||
i = dataset.meta.episodes[0]["dataset_to_index"]
|
||||
load_and_compare(i - 2)
|
||||
load_and_compare(i - 1)
|
||||
|
||||
# TODO(rcadene): Enable testing on second and last episode
|
||||
# We currently can't because our test dataset only contains the first episode
|
||||
|
||||
# # test 2 first frames of second episode
|
||||
# i = dataset.episode_data_index["from"][1].item()
|
||||
# load_and_compare(i)
|
||||
# load_and_compare(i + 1)
|
||||
|
||||
# # test 2 last frames of second episode
|
||||
# i = dataset.episode_data_index["to"][1].item()
|
||||
# load_and_compare(i - 2)
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
# # test 2 last frames of last episode
|
||||
# i = dataset.episode_data_index["to"][-1].item()
|
||||
# load_and_compare(i - 2)
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
|
||||
@pytest.mark.skip("Requires internet access")
|
||||
def test_create_branch():
|
||||
@@ -556,18 +559,499 @@ def test_create_branch():
|
||||
api.delete_repo(repo_id, repo_type=repo_type)
|
||||
|
||||
|
||||
def test_dataset_feature_with_forward_slash_raises_error():
|
||||
# make sure dir does not exist
|
||||
from lerobot.constants import HF_LEROBOT_HOME
|
||||
def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
|
||||
"""Test the _check_cached_episodes_sufficient method of LeRobotDataset."""
|
||||
# Create a dataset with 5 episodes (0-4)
|
||||
dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "test",
|
||||
total_episodes=5,
|
||||
total_frames=200,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
|
||||
# make sure does not exist
|
||||
if dataset_dir.exists():
|
||||
dataset_dir.rmdir()
|
||||
# Test hf_dataset is None
|
||||
dataset.hf_dataset = None
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LeRobotDataset.create(
|
||||
repo_id="lerobot/test/with/slash",
|
||||
fps=30,
|
||||
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
|
||||
# Test hf_dataset is empty
|
||||
import datasets
|
||||
|
||||
empty_features = get_hf_features_from_features(dataset.features)
|
||||
dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
{key: [] for key in empty_features}, features=empty_features
|
||||
)
|
||||
dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
assert dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Restore the original dataset for remaining tests
|
||||
dataset.hf_dataset = dataset.load_hf_dataset()
|
||||
|
||||
# Test all episodes requested (self.episodes = None) and all are available
|
||||
dataset.episodes = None
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test specific episodes requested that are all available
|
||||
dataset.episodes = [0, 2, 4]
|
||||
assert dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test request episodes that don't exist in the cached dataset
|
||||
# Create a dataset with only episodes 0, 1, 2
|
||||
limited_dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "limited",
|
||||
total_episodes=3,
|
||||
total_frames=120,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
# Request episodes that include non-existent ones
|
||||
limited_dataset.episodes = [0, 1, 2, 3, 4]
|
||||
assert limited_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
|
||||
# First create the full dataset structure
|
||||
sparse_dataset = lerobot_dataset_factory(
|
||||
root=tmp_path / "sparse",
|
||||
total_episodes=5,
|
||||
total_frames=200,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
# Manually filter hf_dataset to only include episodes 0, 2, 4
|
||||
episode_indices = sparse_dataset.hf_dataset["episode_index"]
|
||||
mask = torch.zeros(len(episode_indices), dtype=torch.bool)
|
||||
for ep in [0, 2, 4]:
|
||||
mask |= torch.tensor(episode_indices) == ep
|
||||
|
||||
# Create a filtered dataset
|
||||
filtered_data = {}
|
||||
# Find image keys by checking features
|
||||
image_keys = [key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"]
|
||||
|
||||
for key in sparse_dataset.hf_dataset.column_names:
|
||||
values = sparse_dataset.hf_dataset[key]
|
||||
# Filter values based on mask
|
||||
filtered_values = [val for i, val in enumerate(values) if mask[i]]
|
||||
|
||||
# Convert float32 image tensors back to uint8 numpy arrays for HuggingFace dataset
|
||||
if key in image_keys and len(filtered_values) > 0:
|
||||
# Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC)
|
||||
filtered_values = [
|
||||
(val.permute(1, 2, 0).numpy() * 255).astype(np.uint8) for val in filtered_values
|
||||
]
|
||||
|
||||
filtered_data[key] = filtered_values
|
||||
|
||||
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
|
||||
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
|
||||
)
|
||||
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
|
||||
# Test requesting all episodes when only some are cached
|
||||
sparse_dataset.episodes = None
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
# Test requesting only the available episodes
|
||||
sparse_dataset.episodes = [0, 2, 4]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is True
|
||||
|
||||
# Test requesting a mix of available and unavailable episodes
|
||||
sparse_dataset.episodes = [0, 1, 2]
|
||||
assert sparse_dataset._check_cached_episodes_sufficient() is False
|
||||
|
||||
|
||||
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
    """Check reading, updating and validating chunk settings through the dataset metadata."""
    features = {
        feature_key: {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
        }
        for feature_key in ("observation.state", "action")
    }

    # The factory is called without chunk arguments, so the defaults are expected first.
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)

    def assert_settings(chunks_size, data_size, video_size):
        """Fetch the current chunk settings once and compare all three values."""
        current = dataset.meta.get_chunk_settings()
        assert current["chunks_size"] == chunks_size
        assert current["data_files_size_in_mb"] == data_size
        assert current["video_files_size_in_mb"] == video_size

    assert_settings(DEFAULT_CHUNK_SIZE, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_VIDEO_FILE_SIZE_IN_MB)

    # Update the three settings in a single call.
    new_chunks_size, new_data_size, new_video_size = 2000, 200, 1000
    dataset.meta.update_chunk_settings(
        chunks_size=new_chunks_size,
        data_files_size_in_mb=new_data_size,
        video_files_size_in_mb=new_video_size,
    )
    assert_settings(new_chunks_size, new_data_size, new_video_size)

    # Update one setting at a time; the other two must keep their previous value.
    dataset.meta.update_chunk_settings(chunks_size=1500)
    assert_settings(1500, new_data_size, new_video_size)

    dataset.meta.update_chunk_settings(data_files_size_in_mb=150)
    assert_settings(1500, 150, new_video_size)

    dataset.meta.update_chunk_settings(video_files_size_in_mb=800)
    assert_settings(1500, 150, 800)

    # The info file must exist under the dataset root.
    assert (dataset.root / "meta" / "info.json").exists()

    # The metadata attributes must agree with the settings dict.
    assert dataset.meta.chunks_size == 1500
    assert dataset.meta.data_files_size_in_mb == 150
    assert dataset.meta.video_files_size_in_mb == 800

    # Zero and negative values must be rejected for every setting.
    rejected_updates = (
        ("chunks_size", 0),
        ("chunks_size", -100),
        ("data_files_size_in_mb", 0),
        ("data_files_size_in_mb", -50),
        ("video_files_size_in_mb", 0),
        ("video_files_size_in_mb", -200),
    )
    for setting_name, bad_value in rejected_updates:
        with pytest.raises(ValueError, match=f"{setting_name} must be positive"):
            dataset.meta.update_chunk_settings(**{setting_name: bad_value})

    # Passing None for every setting must leave them unchanged.
    before_none = dataset.meta.get_chunk_settings()
    dataset.meta.update_chunk_settings(
        chunks_size=None, data_files_size_in_mb=None, video_files_size_in_mb=None
    )
    assert before_none == dataset.meta.get_chunk_settings()

    # Two consecutive reads must return equal settings.
    meta_settings = dataset.meta.get_chunk_settings()
    assert meta_settings == dataset.meta.get_chunk_settings()

    # One more update, read back directly from the settings dict.
    dataset.meta.update_chunk_settings(chunks_size=3000)
    assert dataset.meta.get_chunk_settings()["chunks_size"] == 3000
|
||||
|
||||
|
||||
def test_update_chunk_settings_video_dataset(tmp_path):
    """Check that video_files_size_in_mb can be updated on a dataset created with use_videos=True."""
    camera_feature = {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channels"],
    }
    action_feature = {"dtype": "float32", "shape": (6,), "names": [f"j{i}" for i in range(1, 7)]}
    features = {"observation.images.cam": camera_feature, "action": action_feature}

    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID,
        fps=30,
        features=features,
        root=tmp_path / "video_test",
        use_videos=True,
    )

    # Double whatever video file size the dataset starts with.
    doubled_video_size = dataset.meta.get_chunk_settings()["video_files_size_in_mb"] * 2
    dataset.meta.update_chunk_settings(video_files_size_in_mb=doubled_video_size)

    # Both the settings dict and the metadata attribute must report the new value.
    assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == doubled_video_size
    assert dataset.meta.video_files_size_in_mb == doubled_video_size
|
||||
|
||||
|
||||
def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
    """Verify that every frame of a reloaded multi-episode dataset reports the right episode_index."""
    from collections import Counter

    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    # Record three episodes of different lengths.
    num_episodes = 3
    frames_per_episode = [10, 15, 8]
    for episode_idx, n_frames in zip(range(num_episodes), frames_per_episode):
        for _ in range(n_frames):
            dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
        dataset.save_episode()

    # Load the dataset back from its root.
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Spot-check the first, middle and last frame of each episode.
    episode_start = 0
    for ep_idx, ep_length in enumerate(frames_per_episode):
        probes = (episode_start, episode_start + ep_length // 2, episode_start + ep_length - 1)
        for frame_idx in probes:
            actual_ep_idx = loaded_dataset[frame_idx]["episode_index"].item()
            assert actual_ep_idx == ep_idx, (
                f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}"
            )
        episode_start += ep_length

    # Count frames per episode over the whole dataset.
    distribution = Counter(loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset)))
    expected_dist = dict(zip(range(num_episodes), frames_per_episode))

    assert dict(distribution) == expected_dist, (
        f"Episode distribution {dict(distribution)} != expected {expected_dist}"
    )
|
||||
|
||||
|
||||
def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
    """Check per-episode metadata (index, length, tasks, dataset index range) after reloading."""
    features = {
        "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
        "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 4
    frames_per_episode = [20, 35, 10, 25]
    tasks = ["pick", "place", "pick", "place"]

    # Record each episode with its own length and task.
    for _, n_frames, task in zip(range(num_episodes), frames_per_episode, tasks):
        for _ in range(n_frames):
            dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": task})
        dataset.save_episode()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Dataset-level totals.
    assert loaded_dataset.meta.total_episodes == num_episodes
    assert loaded_dataset.meta.total_frames == sum(frames_per_episode)

    # Each episode must cover the half-open range [expected_from, expected_to).
    expected_from = 0
    for episode_idx, n_frames, task in zip(range(num_episodes), frames_per_episode, tasks):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        expected_to = expected_from + n_frames

        assert episode_metadata["episode_index"] == episode_idx
        assert episode_metadata["length"] == n_frames
        assert episode_metadata["tasks"] == [task]

        assert episode_metadata["dataset_from_index"] == expected_from
        assert episode_metadata["dataset_to_index"] == expected_to

        expected_from = expected_to
|
||||
|
||||
|
||||
def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Check that consecutive episodes tile the dataset index range without gaps or overlaps."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 5
    frames_per_episode = [12, 8, 20, 15, 5]

    for _, n_frames in zip(range(num_episodes), frames_per_episode):
        for _ in range(n_frames):
            dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
        dataset.save_episode()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # Walk the episodes in order, tracking where the previous one ended.
    cumulative_check = 0
    for episode_idx, expected_length in zip(range(num_episodes), frames_per_episode):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        from_idx = episode_metadata["dataset_from_index"]
        to_idx = episode_metadata["dataset_to_index"]

        # The episode must begin exactly where the previous one stopped.
        assert from_idx == cumulative_check, (
            f"Episode {episode_idx} starts at {from_idx}, expected {cumulative_check}"
        )

        # Its index span must equal the number of frames recorded for it.
        actual_length = to_idx - from_idx
        assert actual_length == expected_length, (
            f"Episode {episode_idx} length {actual_length} != expected {expected_length}"
        )

        cumulative_check = to_idx

    # The last episode must end at the total number of recorded frames.
    expected_total_frames = sum(frames_per_episode)
    assert cumulative_check == expected_total_frames, (
        f"Final frame count {cumulative_check} != expected {expected_total_frames}"
    )
|
||||
|
||||
|
||||
def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
    """Check that the reloaded dataset stores min/max/mean/std/count statistics for each feature."""
    features = {
        "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
        "action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 2
    frames_per_episode = [10, 10]

    # Frame values are plain functions of the frame index, so the recorded data is deterministic.
    torch.manual_seed(42)
    for _, n_frames in zip(range(num_episodes), frames_per_episode):
        for frame_idx in range(n_frames):
            frame = {
                "state": torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32),
                "action": torch.tensor([frame_idx * 0.05], dtype=torch.float32),
                "task": "stats_test",
            }
            dataset.add_frame(frame)
        dataset.save_episode()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    assert loaded_dataset.meta.stats is not None, "No statistics found"

    array_stats = ("min", "max", "mean", "std")
    for feature_name, feature_spec in features.items():
        assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
        feature_stats = loaded_dataset.meta.stats[feature_name]

        for stat_key in [*array_stats, "count"]:
            assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'"
            stat_value = feature_stats[stat_key]

            if stat_key == "count":
                # The count must equal the total number of recorded frames.
                assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'"
                continue

            # Shape check, only for values that expose a shape.
            if hasattr(stat_value, "shape"):
                expected_shape = feature_spec["shape"]
                assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], (
                    f"Wrong shape for {stat_key} of '{feature_name}'"
                )

            # NaN check, element-wise for iterables and direct for scalars.
            if hasattr(stat_value, "__iter__"):
                assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'"
            else:
                assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'"
|
||||
|
||||
|
||||
def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
    """Check frame_index and episode_index values around and inside each episode boundary."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    num_episodes = 3
    frames_per_episode = [7, 12, 5]

    # Each frame stores its own position within the episode as the state value.
    for episode_idx, n_frames in zip(range(num_episodes), frames_per_episode):
        for frame_idx in range(n_frames):
            dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
        dataset.save_episode()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    episode_start = 0
    for ep_idx, ep_length in enumerate(frames_per_episode):
        # The frame just before this episode must belong to the previous episode.
        if ep_idx > 0:
            prev_frame = loaded_dataset[episode_start - 1]
            assert prev_frame["episode_index"].item() == ep_idx - 1

        # The first frame of this episode must carry this episode's index.
        if episode_start < len(loaded_dataset):
            curr_frame = loaded_dataset[episode_start]
            assert curr_frame["episode_index"].item() == ep_idx

        # Inside the episode, frame_index must count up from zero.
        for offset in range(ep_length):
            global_idx = episode_start + offset
            if global_idx >= len(loaded_dataset):
                continue
            frame = loaded_dataset[global_idx]
            assert frame["frame_index"].item() == offset, f"Frame {global_idx} has wrong frame_index"
            assert frame["episode_index"].item() == ep_idx, (
                f"Frame {global_idx} has wrong episode_index"
            )

        episode_start += ep_length
|
||||
|
||||
|
||||
def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
    """Check that tasks are stored once each and that every frame maps back to its task."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)

    # Some tasks repeat across episodes; the distinct ones are pick, place and navigate (in no fixed order).
    tasks = ["pick", "place", "pick", "navigate", "place"]
    unique_tasks = list(set(tasks))
    frames_per_episode = [5, 8, 3, 10, 6]

    for task, n_frames in zip(tasks, frames_per_episode):
        for _ in range(n_frames):
            dataset.add_frame({"state": torch.randn(1), "task": task})
        dataset.save_episode()

    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)

    # The tasks metadata index must hold exactly the distinct task strings.
    stored_tasks = set(loaded_dataset.meta.tasks.index)
    assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}"

    # Episode metadata and every frame must agree with the task used while recording.
    episode_start = 0
    for episode_idx, (expected_task, n_frames) in enumerate(zip(tasks, frames_per_episode)):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        assert episode_metadata["tasks"] == [expected_task]

        for offset in range(n_frames):
            global_idx = episode_start + offset
            frame = loaded_dataset[global_idx]
            assert frame["task"] == expected_task, f"Frame {global_idx} has wrong task"

            # The frame's task_index must match the metadata lookup for the same task.
            expected_task_index = loaded_dataset.meta.get_task_index(expected_task)
            assert frame["task_index"].item() == expected_task_index

        episode_start += n_frames

    # The reported task count must equal the number of distinct tasks.
    assert loaded_dataset.meta.total_tasks == len(unique_tasks)
|
||||
|
||||
@@ -11,83 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from itertools import accumulate
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
def calculate_total_episode(
    hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> int:
    """Return the number of distinct ``episode_index`` values in ``hf_dataset``.

    Args:
        hf_dataset: Dataset exposing an ``episode_index`` column.
        raise_if_not_contiguous: When True, require the distinct indices to be exactly
            ``0, 1, ..., total_episodes - 1``.

    Returns:
        The count of distinct episode indices.

    Raises:
        ValueError: If ``raise_if_not_contiguous`` is True and the distinct indices are
            not the contiguous range starting at 0.
    """
    episode_indices = sorted(hf_dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    # After sorting, any difference from range(total_episodes) means a gap or a non-zero start.
    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
        raise ValueError("episode_index values are not sorted and contiguous.")
    return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
    """Compute, for each episode, the half-open frame range it occupies in ``hf_dataset``.

    Args:
        hf_dataset: Dataset with an ``episode_index`` column.

    Returns:
        A dict with two int64 arrays of equal length: ``"from"`` holds the first frame
        index of each episode and ``"to"`` holds one past its last frame index. Both
        arrays are empty when the dataset has no episodes.

    Raises:
        ValueError: Propagated from ``calculate_total_episode`` when the episode indices
            are not contiguous from 0.
    """
    table = hf_dataset.data.table
    total_episodes = calculate_total_episode(hf_dataset)

    # One filter per episode; the row count of the filtered table is the episode length.
    episode_lengths = [
        len(table.filter(pc.equal(table["episode_index"], ep_idx))) for ep_idx in range(total_episodes)
    ]

    ends = list(accumulate(episode_lengths))
    # Each episode starts where the previous one ended. With no episodes, "from" must be
    # empty too, otherwise it would hold a stray 0 and not match "to" in length.
    starts = [0, *ends[:-1]] if ends else []
    return {
        "from": np.array(starts, dtype=np.int64),
        "to": np.array(ends, dtype=np.int64),
    }
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
    """Module-scoped factory fixture for timestamps, episode indices and the episode data index.

    The returned callable builds a dataset with ``hf_dataset_factory(fps=fps)`` and returns
    ``(timestamps, episode_indices, episode_data_index)``, where the first two are numpy
    arrays stacked from the dataset columns and the third is the dict produced by
    ``calculate_episode_data_index``.
    """

    def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        hf_dataset = hf_dataset_factory(fps=fps)
        # Columns come back as lists of tensors; stack them into single arrays.
        timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
        episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
        episode_data_index = calculate_episode_data_index(hf_dataset)
        return timestamps, episode_indices, episode_data_index

    return _create_synced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
    """Module-scoped factory fixture yielding timestamps with one entry pushed past the tolerance.

    The returned callable takes the output of ``synced_timestamps_factory(fps=fps)`` and adds
    ``1.1 * tolerance_s`` to the timestamp at position 30, then returns
    ``(timestamps, episode_indices, episode_data_index)``.
    """

    def _create_unsynced_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 1.1  # Modify a single timestamp just outside tolerance
        return timestamps, episode_indices, episode_data_index

    return _create_unsynced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
    """Module-scoped factory fixture yielding timestamps with one entry nudged within the tolerance.

    The returned callable takes the output of ``synced_timestamps_factory(fps=fps)`` and adds
    ``0.9 * tolerance_s`` to the timestamp at position 30, then returns
    ``(timestamps, episode_indices, episode_data_index)``.
    """

    def _create_slightly_off_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 0.9  # Modify a single timestamp just inside tolerance
        return timestamps, episode_indices, episode_data_index

    return _create_slightly_off_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(
|
||||
@@ -136,78 +68,6 @@ def delta_indices_factory():
|
||||
return _delta_indices
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_timestamps_factory):
    """check_timestamps_sync must return True for the synced factory's timestamps."""
    fps, tolerance_s = 30, 1e-4
    timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps)

    is_synced = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=episode_indices,
        episode_data_index=episode_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
    """check_timestamps_sync must raise ValueError for the unsynced factory's timestamps."""
    fps, tolerance_s = 30, 1e-4
    timestamps, episode_indices, episode_data_index = unsynced_timestamps_factory(fps, tolerance_s)

    with pytest.raises(ValueError):
        check_timestamps_sync(
            timestamps=timestamps,
            episode_indices=episode_indices,
            episode_data_index=episode_data_index,
            fps=fps,
            tolerance_s=tolerance_s,
        )
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
    """With raise_value_error=False, check_timestamps_sync must return False instead of raising."""
    fps, tolerance_s = 30, 1e-4
    timestamps, episode_indices, episode_data_index = unsynced_timestamps_factory(fps, tolerance_s)

    is_synced = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=episode_indices,
        episode_data_index=episode_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
        raise_value_error=False,
    )

    assert is_synced is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
    """check_timestamps_sync must return True when the single offset stays inside the tolerance."""
    fps, tolerance_s = 30, 1e-4
    timestamps, episode_indices, episode_data_index = slightly_off_timestamps_factory(fps, tolerance_s)

    is_synced = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=episode_indices,
        episode_data_index=episode_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
    """check_timestamps_sync must return True for a dataset holding a single timestamp."""
    fps, tolerance_s = 30, 1e-4

    # One frame in one episode: the episode spans the half-open range [0, 1).
    timestamps = np.array([0.0])
    episode_indices = np.array([0])
    episode_data_index = {"to": np.array([1]), "from": np.array([0])}

    is_synced = check_timestamps_sync(
        timestamps=timestamps,
        episode_indices=episode_indices,
        episode_data_index=episode_data_index,
        fps=fps,
        tolerance_s=tolerance_s,
    )

    assert is_synced is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
|
||||
@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
|
||||
assert sampler.indices == [1, 4, 5]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [1, 4, 5]
|
||||
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
|
||||
assert sampler.indices == [0, 3, 4]
|
||||
assert len(sampler) == 3
|
||||
assert list(sampler) == [0, 3, 4]
|
||||
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
|
||||
sampler = EpisodeAwareSampler(
|
||||
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
|
||||
)
|
||||
assert sampler.indices == [0, 1, 3, 4, 5]
|
||||
assert len(sampler) == 5
|
||||
assert list(sampler) == [0, 1, 3, 4, 5]
|
||||
@@ -80,11 +82,11 @@ def test_shuffle():
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert list(sampler) == [0, 1, 2, 3, 4, 5]
|
||||
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
|
||||
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
|
||||
assert sampler.indices == [0, 1, 2, 3, 4, 5]
|
||||
assert len(sampler) == 6
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
@@ -14,12 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
|
||||
from lerobot.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
unflatten_dict,
|
||||
)
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
|
||||
episode_data_index = calculate_episode_data_index(dataset)
|
||||
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_flatten_unflatten_dict():
    """Check that unflatten_dict(flatten_dict(x)) gives back an equal nested dict."""
    nested = {
        "obs": {"min": 0, "max": 1, "mean": 2, "std": 3},
        "action": {"min": 4, "max": 5, "mean": 6, "std": 7},
    }
    # Snapshot taken before the round trip, so the comparison does not rely on
    # flatten_dict leaving its argument untouched.
    expected = deepcopy(nested)

    flat = flatten_dict(nested)
    round_tripped = unflatten_dict(flat)

    # Compare canonical JSON (sorted keys) so key ordering cannot affect the result.
    expected_json = json.dumps(expected, sort_keys=True)
    actual_json = json.dumps(round_tripped, sort_keys=True)
    assert expected_json == actual_json, f"{expected} != {round_tripped}"
|
||||
|
||||
4
tests/fixtures/constants.py
vendored
4
tests/fixtures/constants.py
vendored
@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
|
||||
304
tests/fixtures/dataset_factories.py
vendored
304
tests/fixtures/dataset_factories.py
vendored
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
import shutil
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
@@ -19,19 +20,25 @@ from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
flatten_dict,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
@@ -46,10 +53,9 @@ class LeRobotDatasetFactory(Protocol):
|
||||
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
|
||||
|
||||
|
||||
def get_task_index(tasks: pd.DataFrame, task: str) -> int:
    """Return the integer ``task_index`` stored for ``task``.

    Args:
        tasks: Table indexed by the task description, with a ``task_index``
            column (the layout built by ``tasks_factory``).
        task: Task description to look up in the index.

    Returns:
        The task index as a plain Python ``int``.

    Raises:
        KeyError: If ``task`` is not a label of ``tasks``.
    """
    # Single label-based lookup of the scalar cell; int() normalises the numpy
    # integer pandas returns into a builtin int.
    return int(tasks.loc[task, "task_index"])
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -62,15 +68,49 @@ def img_tensor_factory():
|
||||
|
||||
@pytest.fixture(scope="session")
def img_array_factory():
    """Session fixture returning a factory of dummy ``(height, width, channels)`` image arrays."""

    def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray:
        """Create a dummy image.

        Args:
            height: Number of rows.
            width: Number of columns.
            channels: Size of the last axis.
            dtype: Output dtype. Without ``content`` it must be an unsigned
                integer or floating dtype.
            content: Optional text drawn in black, centred on a white
                background. When ``None`` the image is random noise.

        Returns:
            Array of shape ``(height, width, channels)``.

        Raises:
            ValueError: If ``content`` is None and ``dtype`` is neither an
                unsigned integer nor a floating dtype.
        """
        if content is None:
            # Original random noise behavior
            if np.issubdtype(dtype, np.unsignedinteger):
                # Int array in [0, 255] range
                return np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
            if np.issubdtype(dtype, np.floating):
                # Float array in [0, 1) range
                return np.random.rand(height, width, channels).astype(dtype)
            raise ValueError(dtype)

        # Create image with text content using OpenCV (imported lazily: only
        # labelled images need it).
        import cv2

        # COLOR_BGR2GRAY requires a 3-channel input, so a single-channel image
        # is drawn on a 3-channel canvas and converted to grayscale afterwards.
        canvas_channels = 3 if channels == 1 else channels

        # Create white background
        img_array = np.full((height, width, canvas_channels), 255, dtype=np.uint8)

        # Font settings
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = max(0.5, height / 200)  # Scale font with image size
        font_color = (0, 0, 0)  # Black text
        thickness = max(1, int(height / 100))

        # Get text size to center it
        text_size = cv2.getTextSize(content, font, font_scale, thickness)[0]
        text_x = (width - text_size[0]) // 2
        text_y = (height + text_size[1]) // 2

        # Put text on image
        cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness)

        # Handle single channel case: collapse to grayscale, then restore the
        # trailing channel axis so the shape stays (height, width, 1).
        if channels == 1:
            img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
            img_array = img_array[:, :, np.newaxis]

        # Convert to target dtype (floats are rescaled from [0, 255] to [0, 1])
        if np.issubdtype(dtype, np.floating):
            img_array = img_array.astype(dtype) / 255.0
        else:
            img_array = img_array.astype(dtype)

        return img_array

    return _create_img_array
|
||||
@@ -117,9 +157,10 @@ def info_factory(features_factory):
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
data_path: str = DEFAULT_DATA_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
@@ -133,8 +174,9 @@ def info_factory(features_factory):
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": chunks_size,
|
||||
"data_files_size_in_mb": data_files_size_in_mb,
|
||||
"video_files_size_in_mb": video_files_size_in_mb,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
@@ -175,41 +217,26 @@ def stats_factory():
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_factory(stats_factory):
|
||||
def _create_episodes_stats(
|
||||
features: dict[str],
|
||||
total_episodes: int = 3,
|
||||
) -> dict:
|
||||
episodes_stats = {}
|
||||
for episode_index in range(total_episodes):
|
||||
episodes_stats[episode_index] = {
|
||||
"episode_index": episode_index,
|
||||
"stats": stats_factory(features),
|
||||
}
|
||||
return episodes_stats
|
||||
|
||||
return _create_episodes_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def tasks_factory():
    """Session fixture returning a builder of dummy task tables."""

    def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
        """Build a frame with one ``task_index`` column, indexed by task description."""
        task_indices = list(range(total_tasks))
        descriptions = [f"Perform action {i}." for i in task_indices]
        # The description is the row label; the integer id is the only column.
        return pd.DataFrame({"task_index": task_indices}, index=descriptions)

    return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
def episodes_factory(tasks_factory, stats_factory):
|
||||
def _create_episodes(
|
||||
features: dict[str],
|
||||
fps: int = DEFAULT_FPS,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
tasks: dict | None = None,
|
||||
video_keys: list[str] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
@@ -217,66 +244,142 @@ def episodes_factory(tasks_factory):
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
num_tasks_available = len(tasks)
|
||||
|
||||
if total_episodes < num_tasks_available and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
|
||||
num_tasks_available = len(tasks_list)
|
||||
# Create empty dictionaries with all keys
|
||||
d = {
|
||||
"episode_index": [],
|
||||
"meta/episodes/chunk_index": [],
|
||||
"meta/episodes/file_index": [],
|
||||
"data/chunk_index": [],
|
||||
"data/file_index": [],
|
||||
"dataset_from_index": [],
|
||||
"dataset_to_index": [],
|
||||
"tasks": [],
|
||||
"length": [],
|
||||
}
|
||||
if video_keys is not None:
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"] = []
|
||||
d[f"videos/{video_key}/file_index"] = []
|
||||
d[f"videos/{video_key}/from_timestamp"] = []
|
||||
d[f"videos/{video_key}/to_timestamp"] = []
|
||||
|
||||
episodes = {}
|
||||
remaining_tasks = tasks_list.copy()
|
||||
for stats_key in flatten_dict({"stats": stats_factory(features)}):
|
||||
d[stats_key] = []
|
||||
|
||||
num_frames = 0
|
||||
remaining_tasks = list(tasks.index)
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
||||
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
|
||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||
if remaining_tasks:
|
||||
for task in episode_tasks:
|
||||
remaining_tasks.remove(task)
|
||||
|
||||
episodes[ep_idx] = {
|
||||
"episode_index": ep_idx,
|
||||
"tasks": episode_tasks,
|
||||
"length": lengths[ep_idx],
|
||||
}
|
||||
d["episode_index"].append(ep_idx)
|
||||
# TODO(rcadene): remove heuristic of only one file
|
||||
d["meta/episodes/chunk_index"].append(0)
|
||||
d["meta/episodes/file_index"].append(0)
|
||||
d["data/chunk_index"].append(0)
|
||||
d["data/file_index"].append(0)
|
||||
d["dataset_from_index"].append(num_frames)
|
||||
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
|
||||
d["tasks"].append(episode_tasks)
|
||||
d["length"].append(lengths[ep_idx])
|
||||
|
||||
return episodes
|
||||
if video_keys is not None:
|
||||
for video_key in video_keys:
|
||||
d[f"videos/{video_key}/chunk_index"].append(0)
|
||||
d[f"videos/{video_key}/file_index"].append(0)
|
||||
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
|
||||
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
|
||||
|
||||
# Add stats columns like "stats/action/max"
|
||||
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
|
||||
d[stats_key].append(stats)
|
||||
|
||||
num_frames += lengths[ep_idx]
|
||||
|
||||
return Dataset.from_dict(d)
|
||||
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
    """Session fixture returning a helper that writes one dummy video per video feature."""

    def _create_video_directory(
        root: Path,
        info: dict | None = None,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
    ):
        """Render labelled frames for every video feature in ``info`` and encode them under ``root``.

        The ``total_*`` arguments are only used to build ``info`` when none is given.
        """
        if info is None:
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )

        # Select the video features up front, before any file is written.
        video_items = [(key, ft) for key, ft in info["features"].items() if ft["dtype"] == "video"]

        for key, ft in video_items:
            frame_height, frame_width = ft["shape"][0], ft["shape"][1]

            # Scratch directory holding the PNG frames of the current video.
            tmp_dir = root / "tmp_images"
            tmp_dir.mkdir(parents=True, exist_ok=True)

            # Each frame carries "<key>-<frame_index>" as text so it is identifiable.
            for frame_index in range(info["total_frames"]):
                frame = img_array_factory(
                    height=frame_height, width=frame_width, content=f"{key}-{frame_index}"
                )
                PIL.Image.fromarray(frame).save(tmp_dir / f"frame-{frame_index:06d}.png")

            # Every video is written as file 0 of chunk 0.
            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
            video_path.parent.mkdir(parents=True, exist_ok=True)
            # Use the global fps from info, not video-specific fps which might not exist
            encode_video_frames(tmp_dir, video_path, fps=info["fps"])
            # Drop the scratch frames before the next feature reuses the directory.
            shutil.rmtree(tmp_dir)

    return _create_video_directory
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
if features is None:
|
||||
features = features_factory()
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(features, fps)
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes.values():
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
|
||||
# TODO(rcadene): assign the tasks of the episode per chunks of frames
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
@@ -286,8 +389,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||
for _ in range(len(index_col))
|
||||
img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}")
|
||||
for i in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
@@ -314,7 +417,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
@@ -324,29 +426,29 @@ def lerobot_dataset_metadata_factory(
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
if info is None:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
@@ -366,7 +468,6 @@ def lerobot_dataset_metadata_factory(
|
||||
def lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
episodes_stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
@@ -380,50 +481,63 @@ def lerobot_dataset_factory(
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
use_videos: bool = True,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes_metadata: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
total_tasks=total_tasks,
|
||||
use_videos=use_videos,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
if not stats:
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
if episodes_metadata is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
|
||||
)
|
||||
|
||||
# Write data on disk
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
episodes=episodes_metadata,
|
||||
hf_dataset=hf_dataset,
|
||||
data_files_size_in_mb=data_files_size_in_mb,
|
||||
chunks_size=chunks_size,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info=info,
|
||||
stats=stats,
|
||||
episodes_stats=episodes_stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
episodes=episodes_metadata,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
|
||||
231
tests/fixtures/files.py
vendored
231
tests/fixtures/files.py
vendored
@@ -11,137 +11,166 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
get_hf_dataset_size_in_mb,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
|
||||
|
||||
def write_hf_dataset(
    hf_dataset: Dataset,
    local_dir: Path,
    data_file_size_mb: float | None = None,
    chunk_size: int | None = None,
):
    """
    Writes a Hugging Face Dataset to one or more Parquet files in a structured directory format.

    If the dataset size is within `data_file_size_mb`, it's saved as a single file.
    Otherwise, the dataset is split into multiple smaller Parquet files. Consecutive episodes are
    packed into the same file for as long as the shard stays within the size limit; a single episode
    that exceeds the limit on its own is still written whole (with a warning).
    The file and chunk indices are managed to organize the output files in a hierarchical structure,
    e.g., `data/chunk-000/file-000.parquet`, `data/chunk-000/file-001.parquet`, etc.
    This function ensures that episodes are not split across multiple files.

    Args:
        hf_dataset (Dataset): The Hugging Face Dataset to be written to disk.
        local_dir (Path): The root directory where the dataset files will be stored.
        data_file_size_mb (float, optional): Maximal size for the parquet data file, in MB. Defaults to DEFAULT_DATA_FILE_SIZE_IN_MB.
        chunk_size (int, optional): Maximal number of files within a chunk folder before creating another one. Defaults to DEFAULT_CHUNK_SIZE.
    """
    if data_file_size_mb is None:
        data_file_size_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE

    dataset_size_in_mb = get_hf_dataset_size_in_mb(hf_dataset)

    if dataset_size_in_mb <= data_file_size_mb:
        # If the dataset is small enough, write it to a single file.
        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
        path.parent.mkdir(parents=True, exist_ok=True)
        hf_dataset.to_parquet(path)
        return

    # If the dataset is too large, split it into smaller chunks, keeping episodes whole.
    # A new episode starts at every row where `episode_index` changes.
    episode_indices = np.array(hf_dataset["episode_index"])
    episode_boundaries = np.where(np.diff(episode_indices) != 0)[0] + 1
    episode_starts = np.concatenate(([0], episode_boundaries))
    episode_ends = np.concatenate((episode_boundaries, [len(hf_dataset)]))

    num_episodes = len(episode_starts)
    current_episode_idx = 0
    chunk_idx, file_idx = 0, 0

    while current_episode_idx < num_episodes:
        # A shard always contains at least the current episode.
        shard_start_row = episode_starts[current_episode_idx]
        shard_end_row = episode_ends[current_episode_idx]
        next_episode_to_try_idx = current_episode_idx + 1

        # Greedily append following episodes while the shard stays within the size limit.
        while next_episode_to_try_idx < num_episodes:
            potential_shard_end_row = episode_ends[next_episode_to_try_idx]
            dataset_shard_candidate = hf_dataset.select(range(shard_start_row, potential_shard_end_row))
            shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard_candidate)

            if shard_size_mb > data_file_size_mb:
                break

            shard_end_row = potential_shard_end_row
            next_episode_to_try_idx += 1

        dataset_shard = hf_dataset.select(range(shard_start_row, shard_end_row))

        # The shard still ends where the current episode ends, i.e. it holds exactly one episode:
        # that episode may be larger than the limit by itself.
        if shard_end_row == episode_ends[current_episode_idx]:
            shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard)
            if shard_size_mb > data_file_size_mb:
                # Read the episode id from the column extracted above rather than fetching a
                # whole dataset row only to report one value.
                logging.warning(
                    f"Episode with index {episode_indices[shard_start_row]} has size {shard_size_mb:.2f}MB, "
                    f"which is larger than data_file_size_mb ({data_file_size_mb}MB). "
                    "Writing it to a separate shard anyway to preserve episode integrity."
                )

        # Define the path for the current shard and ensure the directory exists.
        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Write the shard to a Parquet file.
        dataset_shard.to_parquet(path)

        # Update chunk and file indices for the next iteration.
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
        current_episode_idx = next_episode_to_try_idx
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_info(info_factory):
    """Session fixture returning a helper that writes dataset info via ``write_info``."""

    def _create_info(dir: Path, info: dict | None = None):
        # Use the factory's default info when the caller supplies none.
        payload = info_factory() if info is None else info
        write_info(payload, dir)

    return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_stats(stats_factory):
    """Session fixture returning a helper that writes dataset stats via ``write_stats``."""

    def _create_stats(dir: Path, stats: dict | None = None):
        # Use the factory's default stats when the caller supplies none.
        payload = stats_factory() if stats is None else stats
        write_stats(payload, dir)

    return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_stats_path(episodes_stats_factory):
|
||||
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory()
|
||||
fpath = dir / EPISODES_STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes_stats.values())
|
||||
return fpath
|
||||
|
||||
return _create_episodes_stats_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_tasks(tasks_factory):
    """Session fixture returning a helper that writes a tasks table via ``write_tasks``."""

    def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
        # Use the factory's default tasks when the caller supplies none.
        table = tasks_factory() if tasks is None else tasks
        write_tasks(table, dir)

    return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_episodes(episodes_factory):
    """Session fixture returning a helper that writes episodes metadata via ``write_episodes``."""

    def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
        # TODO(rcadene): add features, fps as arguments
        # Use the factory's default episodes when the caller supplies none.
        records = episodes_factory() if episodes is None else episodes
        write_episodes(records, dir)

    return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def create_hf_dataset(hf_dataset_factory):
    """Session fixture returning a helper that writes a dataset to disk via ``write_hf_dataset``."""

    def _create_hf_dataset(
        dir: Path,
        hf_dataset: datasets.Dataset | None = None,
        data_file_size_in_mb: float | None = None,
        chunk_size: int | None = None,
    ):
        # Use the factory's default dataset when the caller supplies none; the
        # size and chunk limits are forwarded unchanged, None included.
        dataset = hf_dataset_factory() if hf_dataset is None else hf_dataset
        write_hf_dataset(dataset, dir, data_file_size_in_mb, chunk_size)

    return _create_hf_dataset
|
||||
|
||||
134
tests/fixtures/hub.py
vendored
134
tests/fixtures/hub.py
vendored
@@ -14,15 +14,19 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
EPISODES_STATS_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_TASKS_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
)
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
@@ -30,17 +34,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info_factory,
|
||||
info_path,
|
||||
create_info,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
episodes_stats_factory,
|
||||
episodes_stats_path,
|
||||
create_stats,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
create_tasks,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
create_episodes,
|
||||
hf_dataset_factory,
|
||||
create_hf_dataset,
|
||||
create_videos,
|
||||
):
|
||||
"""
|
||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||
@@ -50,82 +53,93 @@ def mock_snapshot_download_factory(
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
episodes_stats: list[dict] | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
episodes: datasets.Dataset | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
):
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
if info is None:
|
||||
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not episodes_stats:
|
||||
episodes_stats = episodes_stats_factory(
|
||||
features=info["features"], total_episodes=info["total_episodes"]
|
||||
)
|
||||
if not tasks:
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
if episodes is None:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
)
|
||||
if not hf_dataset:
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
repo_id: str, # TODO(rcadene): repo_id should be used no?
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
if local_dir is None:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
all_files.extend(meta_files)
|
||||
all_files = [
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
# TODO(rcadene): remove naive chunk 0 file 0 ?
|
||||
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
|
||||
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
|
||||
]
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes.values():
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
|
||||
for key in video_keys:
|
||||
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
request_info = False
|
||||
request_tasks = False
|
||||
request_episodes = False
|
||||
request_stats = False
|
||||
request_data = False
|
||||
request_videos = False
|
||||
for rel_path in allowed_files:
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == EPISODES_STATS_PATH:
|
||||
_ = episodes_stats_path(local_dir, episodes_stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes)
|
||||
if rel_path.startswith("meta/info.json"):
|
||||
request_info = True
|
||||
elif rel_path.startswith("meta/stats"):
|
||||
request_stats = True
|
||||
elif rel_path.startswith("meta/tasks"):
|
||||
request_tasks = True
|
||||
elif rel_path.startswith("meta/episodes"):
|
||||
request_episodes = True
|
||||
elif rel_path.startswith("data/"):
|
||||
request_data = True
|
||||
elif rel_path.startswith("videos/"):
|
||||
request_videos = True
|
||||
else:
|
||||
pass
|
||||
raise ValueError(f"{rel_path} not supported.")
|
||||
|
||||
if request_info:
|
||||
create_info(local_dir, info)
|
||||
if request_stats:
|
||||
create_stats(local_dir, stats)
|
||||
if request_tasks:
|
||||
create_tasks(local_dir, tasks)
|
||||
if request_episodes:
|
||||
create_episodes(local_dir, episodes)
|
||||
if request_data:
|
||||
create_hf_dataset(local_dir, hf_dataset, data_files_size_in_mb, chunks_size)
|
||||
if request_videos:
|
||||
create_videos(root=local_dir, info=info)
|
||||
|
||||
return str(local_dir)
|
||||
|
||||
return _mock_snapshot_download
|
||||
|
||||
@@ -71,7 +71,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
|
||||
},
|
||||
}
|
||||
info = info_factory(
|
||||
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
|
||||
total_episodes=1,
|
||||
total_frames=1,
|
||||
total_tasks=1,
|
||||
camera_features=camera_features,
|
||||
motor_features=motor_features,
|
||||
)
|
||||
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
|
||||
return ds_meta
|
||||
@@ -140,7 +144,6 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
||||
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
|
||||
and for now we add tests as we see fit.
|
||||
"""
|
||||
|
||||
train_cfg = TrainPipelineConfig(
|
||||
# TODO(rcadene, aliberts): remove dataset download
|
||||
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from lerobot.calibrate import CalibrateConfig, calibrate
|
||||
from lerobot.record import DatasetRecordConfig, RecordConfig, record
|
||||
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
|
||||
@@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path):
|
||||
assert dataset.meta.total_tasks == 1
|
||||
|
||||
cfg.resume = True
|
||||
dataset = record(cfg)
|
||||
# Mock the revision to prevent Hub calls during resume
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "record")
|
||||
dataset = record(cfg)
|
||||
|
||||
assert dataset.meta.total_episodes == dataset.num_episodes == 2
|
||||
assert dataset.meta.total_frames == dataset.num_frames == 6
|
||||
@@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path):
|
||||
)
|
||||
|
||||
record(record_cfg)
|
||||
replay(replay_cfg)
|
||||
|
||||
# Mock the revision to prevent Hub calls during replay
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
|
||||
replay(replay_cfg)
|
||||
|
||||
@@ -384,7 +384,7 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
elif feature == "next.done":
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
elif feature == "observation.image":
|
||||
# Tenssor -> numpy is not precise, so we have some diff there
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||
elif feature == "observation.state":
|
||||
|
||||
Reference in New Issue
Block a user