Dataset v3 (#1412)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Tavish <tavish9.chen@gmail.com>
Co-authored-by: fracapuano <francesco.capuano@huggingface.co>
Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
Michel Aractingi
2025-09-15 09:53:30 +02:00
committed by GitHub
parent d602e8169c
commit f55c6e89f0
50 changed files with 4642 additions and 4092 deletions

View File

@@ -47,38 +47,22 @@ def save_dataset_to_safetensors(output_dir, repo_id="lerobot/pusht"):
)
# save 2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
i = dataset.meta.episodes["dataset_from_index"][0]
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(dataset.meta.episodes["dataset_to_index"][0] - dataset.meta.episodes["dataset_from_index"][0]) / 2
)
save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
save_file(dataset[i + 1], repo_dir / f"frame_{i + 1}.safetensors")
# save 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
i = dataset.meta.episodes["dataset_to_index"][0]
save_file(dataset[i - 2], repo_dir / f"frame_{i - 2}.safetensors")
save_file(dataset[i - 1], repo_dir / f"frame_{i - 1}.safetensors")
# TODO(rcadene): Enable testing on second and last episode
# We currently can't because our test dataset only contains the first episode
# # save 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# save_file(dataset[i], repo_dir / f"frame_{i}.safetensors")
# save_file(dataset[i + 1], repo_dir / f"frame_{i+1}.safetensors")
# # save 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
# # save 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# save_file(dataset[i - 2], repo_dir / f"frame_{i-2}.safetensors")
# save_file(dataset[i - 1], repo_dir / f"frame_{i-1}.safetensors")
if __name__ == "__main__":
for dataset in [

View File

@@ -0,0 +1,292 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
import torch
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from tests.fixtures.constants import DUMMY_REPO_ID
def assert_episode_and_frame_counts(aggr_ds, expected_episodes, expected_frames):
    """Test that total number of episodes and frames are correctly aggregated."""
    actual_episodes = aggr_ds.num_episodes
    actual_frames = aggr_ds.num_frames
    assert actual_episodes == expected_episodes, (
        f"Expected {expected_episodes} episodes, got {actual_episodes}"
    )
    assert actual_frames == expected_frames, f"Expected {expected_frames} frames, got {actual_frames}"
def assert_dataset_content_integrity(aggr_ds, ds_0, ds_1):
    """Test that the content of both datasets is preserved correctly in the aggregated dataset.

    Spot-checks the boundary frames of each source section: the first and last
    item of the ds_0 part, and the first and last item of the ds_1 part.
    """
    # These keys are legitimately rewritten by aggregation, so they are excluded.
    keys_to_ignore = ["episode_index", "index", "timestamp"]

    def compare_items(aggr_item, src_item, label, src_name):
        # Compare every feature of a source frame against the corresponding
        # aggregated frame, handling both tensor and non-tensor data.
        for key in src_item:
            if key in keys_to_ignore:
                continue
            if torch.is_tensor(aggr_item[key]) and torch.is_tensor(src_item[key]):
                assert torch.allclose(aggr_item[key], src_item[key], atol=1e-6), (
                    f"{label} key '{key}' doesn't match between aggregated and {src_name}"
                )
            else:
                assert aggr_item[key] == src_item[key], (
                    f"{label} key '{key}' doesn't match between aggregated and {src_name}"
                )

    # First part of the aggregated dataset corresponds to ds_0:
    # positions 0 and len(ds_0) - 1 must match ds_0[0] and ds_0[-1].
    compare_items(aggr_ds[0], ds_0[0], "First item", "ds_0")
    compare_items(aggr_ds[len(ds_0) - 1], ds_0[-1], "Last ds_0 item", "ds_0")
    # Second part corresponds to ds_1:
    # positions len(ds_0) and -1 must match ds_1[0] and ds_1[-1].
    compare_items(aggr_ds[len(ds_0)], ds_1[0], "First ds_1 item", "ds_1")
    compare_items(aggr_ds[-1], ds_1[-1], "Last item", "ds_1")
def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
    """Test that metadata is correctly aggregated."""
    all_datasets = (aggr_ds, ds_0, ds_1)

    # Basic info: FPS and robot type must be identical everywhere.
    fps_values = [ds.fps for ds in all_datasets]
    assert fps_values[0] == fps_values[1] == fps_values[2], (
        "FPS should be the same across all datasets"
    )
    robot_types = [ds.meta.info["robot_type"] for ds in all_datasets]
    assert robot_types[0] == robot_types[1] == robot_types[2], "Robot type should be the same"

    # Feature definitions must be carried over unchanged.
    feature_dicts = [ds.features for ds in all_datasets]
    assert feature_dicts[0] == feature_dicts[1] == feature_dicts[2], "Features should be the same"

    # The aggregated task set is the union of the two source task sets.
    expected_tasks = set(ds_0.meta.tasks.index) | set(ds_1.meta.tasks.index)
    actual_tasks = set(aggr_ds.meta.tasks.index)
    assert actual_tasks == expected_tasks, f"Expected tasks {expected_tasks}, got {actual_tasks}"
def assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1):
    """Test that episode indices are correctly updated after aggregation."""
    n_ds0_frames = len(ds_0)
    n_ds0_episodes = ds_0.num_episodes
    total_episodes = n_ds0_episodes + ds_1.num_episodes

    # Frames coming from ds_0 keep episode indices in [0, ds_0.num_episodes).
    for pos in range(n_ds0_frames):
        ep_idx = aggr_ds[pos]["episode_index"]
        assert ep_idx < n_ds0_episodes, (
            f"Episode index {ep_idx} at position {pos} should be < {n_ds0_episodes}"
        )

    # Frames coming from ds_1 are shifted into [ds_0.num_episodes, total_episodes).
    for pos in range(n_ds0_frames, n_ds0_frames + len(ds_1)):
        ep_idx = aggr_ds[pos]["episode_index"]
        assert n_ds0_episodes <= ep_idx < total_episodes, (
            f"Episode index {ep_idx} at position {pos} should be >= {n_ds0_episodes}"
        )
def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
    """Test that video frames are correctly preserved and frame indices are updated."""
    features = aggr_ds.meta.info["features"]
    video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]

    n_ds0 = len(ds_0)
    # Section corresponding to ds_0: the global frame index equals the position,
    # and pixels match the source frame in ds_0.
    for pos in range(n_ds0):
        item = aggr_ds[pos]
        assert item["index"] == pos, (
            f"Frame index at position {pos} should be {pos}, but got {item['index']}"
        )
        for key in video_keys:
            assert torch.allclose(item[key], ds_0[pos][key]), (
                f"Visual frames at position {pos} should be equal between aggregated and ds_0"
            )

    # Section corresponding to ds_1: the index still equals the position, and
    # pixels match the source frame at the shifted position in ds_1.
    for pos in range(n_ds0, n_ds0 + len(ds_1)):
        item = aggr_ds[pos]
        assert item["index"] == pos, (
            f"Frame index at position {pos} should be {pos}, but got {item['index']}"
        )
        for key in video_keys:
            assert torch.allclose(item[key], ds_1[pos - n_ds0][key]), (
                f"Visual frames at position {pos} should be equal between aggregated and ds_1"
            )
def assert_dataset_iteration_works(aggr_ds):
    """Test that we can iterate through the entire dataset without errors."""
    # Exhausting the iterator is the whole test: any decoding or indexing
    # problem raises while iterating.
    for _item in aggr_ds:
        pass
def test_aggregate_datasets(tmp_path, lerobot_dataset_factory):
    """Test basic aggregation functionality with standard parameters."""
    # Two source datasets with deliberately different episode/frame counts.
    specs = [
        {"suffix": "0", "episodes": 10, "frames": 400},
        {"suffix": "1", "episodes": 25, "frames": 800},
    ]
    sources = [
        lerobot_dataset_factory(
            root=tmp_path / f"test_{spec['suffix']}",
            repo_id=f"{DUMMY_REPO_ID}_{spec['suffix']}",
            total_episodes=spec["episodes"],
            total_frames=spec["frames"],
        )
        for spec in specs
    ]
    ds_0, ds_1 = sources

    aggr_repo_id = f"{DUMMY_REPO_ID}_aggr"
    aggr_root = tmp_path / "test_aggr"
    aggregate_datasets(
        repo_ids=[ds.repo_id for ds in sources],
        roots=[ds.root for ds in sources],
        aggr_repo_id=aggr_repo_id,
        aggr_root=aggr_root,
    )

    # Mock the revision to prevent Hub calls during dataset loading.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(aggr_repo_id, root=aggr_root)

    # Run every assertion helper against the aggregated dataset.
    assert_episode_and_frame_counts(
        aggr_ds, ds_0.num_episodes + ds_1.num_episodes, ds_0.num_frames + ds_1.num_frames
    )
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)
def test_aggregate_with_low_threshold(tmp_path, lerobot_dataset_factory):
    """Test aggregation with small file size limits to force file rotation/sharding."""
    episodes_per_ds = 10
    frames_per_ds = 400
    ds_0 = lerobot_dataset_factory(
        root=tmp_path / "small_0",
        repo_id=f"{DUMMY_REPO_ID}_small_0",
        total_episodes=episodes_per_ds,
        total_frames=frames_per_ds,
    )
    ds_1 = lerobot_dataset_factory(
        root=tmp_path / "small_1",
        repo_id=f"{DUMMY_REPO_ID}_small_1",
        total_episodes=episodes_per_ds,
        total_frames=frames_per_ds,
    )

    aggr_root = tmp_path / "small_aggr"
    # Use the configurable size limits to force rotation into new files.
    aggregate_datasets(
        repo_ids=[ds_0.repo_id, ds_1.repo_id],
        roots=[ds_0.root, ds_1.root],
        aggr_repo_id=f"{DUMMY_REPO_ID}_small_aggr",
        aggr_root=aggr_root,
        # Tiny file size to trigger new file instantiation
        data_files_size_in_mb=0.01,
        video_files_size_in_mb=0.1,
    )

    # Mock the revision to prevent Hub calls during dataset loading.
    with (
        patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
        patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
    ):
        mock_get_safe_version.return_value = "v3.0"
        mock_snapshot_download.return_value = str(aggr_root)
        aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_small_aggr", root=aggr_root)

    # Aggregation must remain correct despite the file size constraints.
    assert_episode_and_frame_counts(aggr_ds, 2 * episodes_per_ds, 2 * frames_per_ds)
    assert_dataset_content_integrity(aggr_ds, ds_0, ds_1)
    assert_metadata_consistency(aggr_ds, ds_0, ds_1)
    assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1)
    assert_video_frames_integrity(aggr_ds, ds_0, ds_1)
    assert_dataset_iteration_works(aggr_ds)

    # The tiny limits should have forced multiple data and video files.
    data_dir = aggr_root / "data"
    if data_dir.exists():
        parquet_files = list(data_dir.rglob("*.parquet"))
        assert len(parquet_files) > 1, "Small file size limits should create multiple parquet files"
    video_dir = aggr_root / "videos"
    if video_dir.exists():
        video_files = list(video_dir.rglob("*.mp4"))
        assert len(video_files) > 1, "Small file size limits should create multiple video files"

View File

@@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import re
from copy import deepcopy
from itertools import chain
from pathlib import Path
@@ -37,13 +35,19 @@ from lerobot.datasets.lerobot_dataset import (
MultiLeRobotDataset,
)
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
create_branch,
flatten_dict,
unflatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
hw_to_dataset_features,
)
from lerobot.envs.factory import make_env_config
from lerobot.policies.factory import make_policy_config
from lerobot.robots import make_robot_from_config
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
from tests.mocks.mock_robot import MockRobotConfig
from tests.utils import require_x86_64_kernel
@@ -69,12 +73,17 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
objects have the same sets of attributes defined.
"""
# Instantiate both ways
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
robot = make_robot_from_config(MockRobotConfig())
action_features = hw_to_dataset_features(robot.action_features, "action", True)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
dataset_features = {**action_features, **obs_features}
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=features, root=root_create)
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create
)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
@@ -99,13 +108,41 @@ def test_dataset_initialization(tmp_path, lerobot_dataset_factory):
assert dataset.num_frames == len(dataset)
# TODO(rcadene, aliberts): do not run LeRobotDataset.create, instead refactor LeRobotDatasetMetadata.create
# and test the small resulting function that validates the features
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.constants import HF_LEROBOT_HOME
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure it does not exist
if dataset_dir.exists():
dataset_dir.rmdir()
with pytest.raises(ValueError):
LeRobotDataset.create(
repo_id="lerobot/test/with/slash",
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
)
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
):
dataset.add_frame({"state": torch.randn(1)})
def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
):
dataset.add_frame({"wrong_feature": torch.randn(1)}, task="Dummy task")
dataset.add_frame({"task": "Dummy task"})
def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
@@ -114,7 +151,7 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
):
dataset.add_frame({"state": torch.randn(1), "extra": "dummy_extra"}, task="Dummy task")
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
@@ -123,7 +160,7 @@ def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -133,7 +170,7 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
):
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_factory):
@@ -145,7 +182,7 @@ def test_add_frame_wrong_shape_python_float(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'float'>' provided instead.\n"
),
):
dataset.add_frame({"state": 1.0}, task="Dummy task")
dataset.add_frame({"state": 1.0, "task": "Dummy task"})
def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -155,7 +192,7 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
):
dataset.add_frame({"state": torch.tensor(1.0)}, task="Dummy task")
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_factory):
@@ -167,13 +204,13 @@ def test_add_frame_wrong_shape_numpy_ndim_0(tmp_path, empty_lerobot_dataset_fact
"The feature 'state' is not a 'np.ndarray'. Expected type is 'float32', but type '<class 'numpy.float32'>' provided instead.\n"
),
):
dataset.add_frame({"state": np.float32(1.0)}, task="Dummy task")
dataset.add_frame({"state": np.float32(1.0), "task": "Dummy task"})
def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(1)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
dataset.save_episode()
assert len(dataset) == 1
@@ -185,7 +222,7 @@ def test_add_frame(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(2), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2])
@@ -194,7 +231,7 @@ def test_add_frame_state_1d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(2, 4), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4])
@@ -203,7 +240,7 @@ def test_add_frame_state_2d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(2, 4, 3), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3])
@@ -212,7 +249,7 @@ def test_add_frame_state_3d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(2, 4, 3, 5), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5])
@@ -221,7 +258,7 @@ def test_add_frame_state_4d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (2, 4, 3, 5, 1), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1)}, task="Dummy task")
dataset.add_frame({"state": torch.randn(2, 4, 3, 5, 1), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].shape == torch.Size([2, 4, 3, 5, 1])
@@ -230,7 +267,7 @@ def test_add_frame_state_5d(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"state": np.array([1], dtype=np.float32)}, task="Dummy task")
dataset.add_frame({"state": np.array([1], dtype=np.float32), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["state"].ndim == 0
@@ -239,7 +276,7 @@ def test_add_frame_state_numpy(tmp_path, empty_lerobot_dataset_factory):
def test_add_frame_string(tmp_path, empty_lerobot_dataset_factory):
features = {"caption": {"dtype": "string", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
dataset.add_frame({"caption": "Dummy caption"}, task="Dummy task")
dataset.add_frame({"caption": "Dummy caption", "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["caption"] == "Dummy caption"
@@ -254,7 +291,7 @@ def test_add_frame_image_wrong_shape(image_dataset):
),
):
c, h, w = DUMMY_CHW
dataset.add_frame({"image": torch.randn(c, w, h)}, task="Dummy task")
dataset.add_frame({"image": torch.randn(c, w, h), "task": "Dummy task"})
def test_add_frame_image_wrong_range(image_dataset):
@@ -267,14 +304,14 @@ def test_add_frame_image_wrong_range(image_dataset):
Hence the image won't be saved on disk and save_episode will raise `FileNotFoundError`.
"""
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255}, task="Dummy task")
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW) * 255, "task": "Dummy task"})
with pytest.raises(FileNotFoundError):
dataset.save_episode()
def test_add_frame_image(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW)}, task="Dummy task")
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -282,7 +319,7 @@ def test_add_frame_image(image_dataset):
def test_add_frame_image_h_w_c(image_dataset):
dataset = image_dataset
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC)}, task="Dummy task")
dataset.add_frame({"image": np.random.rand(*DUMMY_HWC), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -291,7 +328,7 @@ def test_add_frame_image_h_w_c(image_dataset):
def test_add_frame_image_uint8(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": image}, task="Dummy task")
dataset.add_frame({"image": image, "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -300,7 +337,7 @@ def test_add_frame_image_uint8(image_dataset):
def test_add_frame_image_pil(image_dataset):
dataset = image_dataset
image = np.random.randint(0, 256, DUMMY_HWC, dtype=np.uint8)
dataset.add_frame({"image": Image.fromarray(image)}, task="Dummy task")
dataset.add_frame({"image": Image.fromarray(image), "task": "Dummy task"})
dataset.save_episode()
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
@@ -319,6 +356,13 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
# - [ ] test push_to_hub
# - [ ] test smaller methods
# TODO(rcadene):
# - [ ] fix code so that old test_factory + backward pass
# - [ ] write new unit tests to test save_episode + getitem
# - [ ] save_episode : case where new dataset, concatenate same file, write new file (meta/episodes, data, videos)
# - [ ]
# - [ ] remove old tests
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
@@ -338,9 +382,8 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=repo_id, episodes=[0]),
env=make_env_config(env_name),
policy=make_policy_config(policy_name, push_to_hub=False),
policy=make_policy_config(policy_name),
)
cfg.validate()
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
@@ -427,30 +470,6 @@ def test_multidataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
@pytest.mark.parametrize(
"repo_id",
[
@@ -497,38 +516,22 @@ def test_backward_compatibility(repo_id):
)
# test2 first frames of first episode
i = dataset.episode_data_index["from"][0].item()
i = dataset.meta.episodes[0]["dataset_from_index"]
load_and_compare(i)
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(dataset.meta.episodes[0]["dataset_to_index"] - dataset.meta.episodes[0]["dataset_from_index"]) / 2
)
load_and_compare(i)
load_and_compare(i + 1)
# test 2 last frames of first episode
i = dataset.episode_data_index["to"][0].item()
i = dataset.meta.episodes[0]["dataset_to_index"]
load_and_compare(i - 2)
load_and_compare(i - 1)
# TODO(rcadene): Enable testing on second and last episode
# We currently can't because our test dataset only contains the first episode
# # test 2 first frames of second episode
# i = dataset.episode_data_index["from"][1].item()
# load_and_compare(i)
# load_and_compare(i + 1)
# # test 2 last frames of second episode
# i = dataset.episode_data_index["to"][1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
# # test 2 last frames of last episode
# i = dataset.episode_data_index["to"][-1].item()
# load_and_compare(i - 2)
# load_and_compare(i - 1)
@pytest.mark.skip("Requires internet access")
def test_create_branch():
@@ -556,18 +559,499 @@ def test_create_branch():
api.delete_repo(repo_id, repo_type=repo_type)
def test_dataset_feature_with_forward_slash_raises_error():
# make sure dir does not exist
from lerobot.constants import HF_LEROBOT_HOME
def test_check_cached_episodes_sufficient(tmp_path, lerobot_dataset_factory):
"""Test the _check_cached_episodes_sufficient method of LeRobotDataset."""
# Create a dataset with 5 episodes (0-4)
dataset = lerobot_dataset_factory(
root=tmp_path / "test",
total_episodes=5,
total_frames=200,
use_videos=False,
)
dataset_dir = HF_LEROBOT_HOME / "lerobot/test/with/slash"
# make sure does not exist
if dataset_dir.exists():
dataset_dir.rmdir()
# Test hf_dataset is None
dataset.hf_dataset = None
assert dataset._check_cached_episodes_sufficient() is False
with pytest.raises(ValueError):
LeRobotDataset.create(
repo_id="lerobot/test/with/slash",
fps=30,
features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
# Test hf_dataset is empty
import datasets
empty_features = get_hf_features_from_features(dataset.features)
dataset.hf_dataset = datasets.Dataset.from_dict(
{key: [] for key in empty_features}, features=empty_features
)
dataset.hf_dataset.set_transform(hf_transform_to_torch)
assert dataset._check_cached_episodes_sufficient() is False
# Restore the original dataset for remaining tests
dataset.hf_dataset = dataset.load_hf_dataset()
# Test all episodes requested (self.episodes = None) and all are available
dataset.episodes = None
assert dataset._check_cached_episodes_sufficient() is True
# Test specific episodes requested that are all available
dataset.episodes = [0, 2, 4]
assert dataset._check_cached_episodes_sufficient() is True
# Test request episodes that don't exist in the cached dataset
# Create a dataset with only episodes 0, 1, 2
limited_dataset = lerobot_dataset_factory(
root=tmp_path / "limited",
total_episodes=3,
total_frames=120,
use_videos=False,
)
# Request episodes that include non-existent ones
limited_dataset.episodes = [0, 1, 2, 3, 4]
assert limited_dataset._check_cached_episodes_sufficient() is False
# Test create a dataset with sparse episodes (e.g., only episodes 0, 2, 4)
# First create the full dataset structure
sparse_dataset = lerobot_dataset_factory(
root=tmp_path / "sparse",
total_episodes=5,
total_frames=200,
use_videos=False,
)
# Manually filter hf_dataset to only include episodes 0, 2, 4
episode_indices = sparse_dataset.hf_dataset["episode_index"]
mask = torch.zeros(len(episode_indices), dtype=torch.bool)
for ep in [0, 2, 4]:
mask |= torch.tensor(episode_indices) == ep
# Create a filtered dataset
filtered_data = {}
# Find image keys by checking features
image_keys = [key for key, ft in sparse_dataset.features.items() if ft.get("dtype") == "image"]
for key in sparse_dataset.hf_dataset.column_names:
values = sparse_dataset.hf_dataset[key]
# Filter values based on mask
filtered_values = [val for i, val in enumerate(values) if mask[i]]
# Convert float32 image tensors back to uint8 numpy arrays for HuggingFace dataset
if key in image_keys and len(filtered_values) > 0:
# Convert torch tensors (float32, [0, 1], CHW) back to numpy arrays (uint8, [0, 255], HWC)
filtered_values = [
(val.permute(1, 2, 0).numpy() * 255).astype(np.uint8) for val in filtered_values
]
filtered_data[key] = filtered_values
sparse_dataset.hf_dataset = datasets.Dataset.from_dict(
filtered_data, features=get_hf_features_from_features(sparse_dataset.features)
)
sparse_dataset.hf_dataset.set_transform(hf_transform_to_torch)
# Test requesting all episodes when only some are cached
sparse_dataset.episodes = None
assert sparse_dataset._check_cached_episodes_sufficient() is False
# Test requesting only the available episodes
sparse_dataset.episodes = [0, 2, 4]
assert sparse_dataset._check_cached_episodes_sufficient() is True
# Test requesting a mix of available and unavailable episodes
sparse_dataset.episodes = [0, 1, 2]
assert sparse_dataset._check_cached_episodes_sufficient() is False
def test_update_chunk_settings(tmp_path, empty_lerobot_dataset_factory):
    """Test the update_chunk_settings functionality for both LeRobotDataset and LeRobotDatasetMetadata.

    Exercises, in order: the default settings, a combined update of all three
    knobs, partial updates of each knob individually, validation errors for
    non-positive values, and no-op behavior when every argument is None.
    The assertions are order-dependent: each step builds on the state left by
    the previous update.
    """
    features = {
        "observation.state": {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
        },
        "action": {
            "dtype": "float32",
            "shape": (6,),
            "names": ["shoulder_pan", "shoulder_lift", "elbow", "wrist_1", "wrist_2", "wrist_3"],
        },
    }
    # Create dataset with default chunk settings
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
    # Test initial default values
    initial_settings = dataset.meta.get_chunk_settings()
    assert initial_settings["chunks_size"] == DEFAULT_CHUNK_SIZE
    assert initial_settings["data_files_size_in_mb"] == DEFAULT_DATA_FILE_SIZE_IN_MB
    assert initial_settings["video_files_size_in_mb"] == DEFAULT_VIDEO_FILE_SIZE_IN_MB
    # Test updating all settings at once
    new_chunks_size = 2000
    new_data_size = 200
    new_video_size = 1000
    dataset.meta.update_chunk_settings(
        chunks_size=new_chunks_size,
        data_files_size_in_mb=new_data_size,
        video_files_size_in_mb=new_video_size,
    )
    # Verify settings were updated
    updated_settings = dataset.meta.get_chunk_settings()
    assert updated_settings["chunks_size"] == new_chunks_size
    assert updated_settings["data_files_size_in_mb"] == new_data_size
    assert updated_settings["video_files_size_in_mb"] == new_video_size
    # Test updating individual settings (the other two must keep their values)
    dataset.meta.update_chunk_settings(chunks_size=1500)
    settings_after_partial = dataset.meta.get_chunk_settings()
    assert settings_after_partial["chunks_size"] == 1500
    assert settings_after_partial["data_files_size_in_mb"] == new_data_size
    assert settings_after_partial["video_files_size_in_mb"] == new_video_size
    # Test updating only data file size
    dataset.meta.update_chunk_settings(data_files_size_in_mb=150)
    settings_after_data = dataset.meta.get_chunk_settings()
    assert settings_after_data["chunks_size"] == 1500
    assert settings_after_data["data_files_size_in_mb"] == 150
    assert settings_after_data["video_files_size_in_mb"] == new_video_size
    # Test updating only video file size
    dataset.meta.update_chunk_settings(video_files_size_in_mb=800)
    settings_after_video = dataset.meta.get_chunk_settings()
    assert settings_after_video["chunks_size"] == 1500
    assert settings_after_video["data_files_size_in_mb"] == 150
    assert settings_after_video["video_files_size_in_mb"] == 800
    # Test that settings persist in the info file
    # NOTE(review): this only checks the file exists — the persisted values are
    # never read back and compared; consider loading the JSON and asserting.
    info_path = dataset.root / "meta" / "info.json"
    assert info_path.exists()
    # Verify the underlying metadata properties
    assert dataset.meta.chunks_size == 1500
    assert dataset.meta.data_files_size_in_mb == 150
    assert dataset.meta.video_files_size_in_mb == 800
    # Test error handling for invalid values (zero and negative are rejected)
    with pytest.raises(ValueError, match="chunks_size must be positive"):
        dataset.meta.update_chunk_settings(chunks_size=0)
    with pytest.raises(ValueError, match="chunks_size must be positive"):
        dataset.meta.update_chunk_settings(chunks_size=-100)
    with pytest.raises(ValueError, match="data_files_size_in_mb must be positive"):
        dataset.meta.update_chunk_settings(data_files_size_in_mb=0)
    with pytest.raises(ValueError, match="data_files_size_in_mb must be positive"):
        dataset.meta.update_chunk_settings(data_files_size_in_mb=-50)
    with pytest.raises(ValueError, match="video_files_size_in_mb must be positive"):
        dataset.meta.update_chunk_settings(video_files_size_in_mb=0)
    with pytest.raises(ValueError, match="video_files_size_in_mb must be positive"):
        dataset.meta.update_chunk_settings(video_files_size_in_mb=-200)
    # Test calling with None values (should not change anything)
    settings_before_none = dataset.meta.get_chunk_settings()
    dataset.meta.update_chunk_settings(
        chunks_size=None, data_files_size_in_mb=None, video_files_size_in_mb=None
    )
    settings_after_none = dataset.meta.get_chunk_settings()
    assert settings_before_none == settings_after_none
    # Test metadata direct access
    # NOTE(review): this assertion compares two calls of the same method and is
    # therefore vacuous; it can never fail regardless of the implementation.
    meta_settings = dataset.meta.get_chunk_settings()
    assert meta_settings == dataset.meta.get_chunk_settings()
    # Test updating via metadata directly
    dataset.meta.update_chunk_settings(chunks_size=3000)
    assert dataset.meta.get_chunk_settings()["chunks_size"] == 3000
def test_update_chunk_settings_video_dataset(tmp_path):
    """Check that update_chunk_settings handles the video-file size on a video dataset."""
    features = {
        "observation.images.cam": {
            "dtype": "video",
            "shape": (480, 640, 3),
            "names": ["height", "width", "channels"],
        },
        "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
    }
    # Build a dataset that actually contains a video feature.
    dataset = LeRobotDataset.create(
        repo_id=DUMMY_REPO_ID, fps=30, features=features, root=tmp_path / "video_test", use_videos=True
    )
    # Double the current video-file size limit and make sure it sticks, both
    # through the settings dict and through the metadata property.
    doubled_size = dataset.meta.get_chunk_settings()["video_files_size_in_mb"] * 2
    dataset.meta.update_chunk_settings(video_files_size_in_mb=doubled_size)
    assert dataset.meta.get_chunk_settings()["video_files_size_in_mb"] == doubled_size
    assert dataset.meta.video_files_size_in_mb == doubled_size
def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
    """Test that all frames have correct episode indices across multiple episodes."""
    features = {"state": {"dtype": "float32", "shape": (2,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    # Record 3 episodes of different lengths.
    episode_lengths = [10, 15, 8]
    for ep_idx, length in enumerate(episode_lengths):
        for _ in range(length):
            dataset.add_frame({"state": torch.randn(2), "task": f"task_{ep_idx}"})
        dataset.save_episode()
    # Reload from disk and verify per-frame episode indices.
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    offset = 0
    for ep_idx, length in enumerate(episode_lengths):
        # Probe the first, middle, and last frame of each episode.
        probes = [offset, offset + length // 2, offset + length - 1]
        for frame_idx in probes:
            actual_ep_idx = loaded_dataset[frame_idx]["episode_index"].item()
            assert actual_ep_idx == ep_idx, (
                f"Frame {frame_idx} has episode_index {actual_ep_idx}, should be {ep_idx}"
            )
        offset += length
    # The frame count per episode index must match the recorded lengths exactly.
    from collections import Counter

    distribution = Counter(
        loaded_dataset[i]["episode_index"].item() for i in range(len(loaded_dataset))
    )
    expected_dist = dict(enumerate(episode_lengths))
    assert dict(distribution) == expected_dist, (
        f"Episode distribution {dict(distribution)} != expected {expected_dist}"
    )
def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_factory):
    """Episode metadata (index, length, tasks, frame ranges) stays consistent over episodes."""
    features = {
        "state": {"dtype": "float32", "shape": (3,), "names": ["x", "y", "z"]},
        "action": {"dtype": "float32", "shape": (2,), "names": ["v", "w"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    lengths = [20, 35, 10, 25]
    tasks = ["pick", "place", "pick", "place"]
    for length, task in zip(lengths, tasks):
        for _ in range(length):
            dataset.add_frame({"state": torch.randn(3), "action": torch.randn(2), "task": task})
        dataset.save_episode()
    # Reload and validate the aggregate counters first.
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    assert loaded_dataset.meta.total_episodes == len(lengths)
    assert loaded_dataset.meta.total_frames == sum(lengths)
    start = 0
    for ep_idx, (length, task) in enumerate(zip(lengths, tasks)):
        meta = loaded_dataset.meta.episodes[ep_idx]
        # Per-episode bookkeeping.
        assert meta["episode_index"] == ep_idx
        assert meta["length"] == length
        assert meta["tasks"] == [task]
        # Each episode's frame range begins exactly where the previous one ended.
        assert meta["dataset_from_index"] == start
        assert meta["dataset_to_index"] == start + length
        start += length
def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factory):
    """Episodes must tile the dataset index range with no gaps or overlaps."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    lengths = [12, 8, 20, 15, 5]
    for length in lengths:
        for _ in range(length):
            dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
        dataset.save_episode()
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    # Walk the episodes in order, asserting each one starts where the previous
    # stopped and spans exactly the number of frames that were recorded.
    expected_start = 0
    for episode_idx, expected_length in enumerate(lengths):
        meta = loaded_dataset.meta.episodes[episode_idx]
        from_idx = meta["dataset_from_index"]
        to_idx = meta["dataset_to_index"]
        assert from_idx == expected_start, (
            f"Episode {episode_idx} starts at {from_idx}, expected {expected_start}"
        )
        actual_length = to_idx - from_idx
        assert actual_length == expected_length, (
            f"Episode {episode_idx} length {actual_length} != expected {expected_length}"
        )
        expected_start = to_idx
    # After the last episode, every recorded frame must be accounted for.
    expected_total_frames = sum(lengths)
    assert expected_start == expected_total_frames, (
        f"Final frame count {expected_start} != expected {expected_total_frames}"
    )
def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory):
    """Test that statistics are properly computed and stored for all features.

    Records two 10-frame episodes of deterministic ramp data, then verifies
    that min/max/mean/std/count statistics exist for every feature, that the
    count matches the total number of frames, that vector stats have the
    feature's shape, and that no statistic contains NaN.
    """
    features = {
        "state": {"dtype": "float32", "shape": (2,), "names": ["pos", "vel"]},
        "action": {"dtype": "float32", "shape": (1,), "names": ["force"]},
    }
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    # Create controlled data to verify statistics
    num_episodes = 2
    frames_per_episode = [10, 10]
    # Use deterministic data for predictable statistics
    # (the frame values below are pure functions of frame_idx; the seed guards
    # against any randomness inside the dataset machinery itself)
    torch.manual_seed(42)
    for episode_idx in range(num_episodes):
        for frame_idx in range(frames_per_episode[episode_idx]):
            state_data = torch.tensor([frame_idx * 0.1, frame_idx * 0.2], dtype=torch.float32)
            action_data = torch.tensor([frame_idx * 0.05], dtype=torch.float32)
            dataset.add_frame({"state": state_data, "action": action_data, "task": "stats_test"})
        dataset.save_episode()
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    # Check that statistics exist for all features
    assert loaded_dataset.meta.stats is not None, "No statistics found"
    for feature_name in features.keys():
        assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
        feature_stats = loaded_dataset.meta.stats[feature_name]
        expected_stats = ["min", "max", "mean", "std", "count"]
        for stat_key in expected_stats:
            assert stat_key in feature_stats, f"Missing '{stat_key}' statistic for '{feature_name}'"
            stat_value = feature_stats[stat_key]
            # Basic sanity checks
            if stat_key == "count":
                assert stat_value == sum(frames_per_episode), f"Wrong count for '{feature_name}'"
            elif stat_key in ["min", "max", "mean", "std"]:
                # Check that statistics are reasonable (not NaN, proper shapes)
                # assumes array-like stats expose .shape (torch/numpy) — the
                # len() fallback covers 1-D containers without a shape match
                if hasattr(stat_value, "shape"):
                    expected_shape = features[feature_name]["shape"]
                    assert stat_value.shape == expected_shape or len(stat_value) == expected_shape[0], (
                        f"Wrong shape for {stat_key} of '{feature_name}'"
                    )
                # Check no NaN values (iterable stats element-wise, scalars directly)
                if hasattr(stat_value, "__iter__"):
                    assert not any(np.isnan(v) for v in stat_value), f"NaN in {stat_key} for '{feature_name}'"
                else:
                    assert not np.isnan(stat_value), f"NaN in {stat_key} for '{feature_name}'"
def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
    """Frame and episode indices must transition cleanly at every episode boundary."""
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    lengths = [7, 12, 5]
    for ep_idx, length in enumerate(lengths):
        for frame_idx in range(length):
            dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{ep_idx}"})
        dataset.save_episode()
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    offset = 0
    for ep_idx, length in enumerate(lengths):
        if ep_idx > 0:
            # The frame just before this episode belongs to the previous one.
            assert loaded_dataset[offset - 1]["episode_index"].item() == ep_idx - 1
        if offset < len(loaded_dataset):
            # The first frame of this episode carries the new episode index.
            assert loaded_dataset[offset]["episode_index"].item() == ep_idx
        # Within the episode, frame_index restarts at 0 and counts up.
        for i in range(length):
            if offset + i < len(loaded_dataset):
                frame = loaded_dataset[offset + i]
                assert frame["frame_index"].item() == i, f"Frame {offset + i} has wrong frame_index"
                assert frame["episode_index"].item() == ep_idx, (
                    f"Frame {offset + i} has wrong episode_index"
                )
        offset += length
def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
    """Test that tasks are properly indexed and retrievable.

    Records one episode per task (with repeats), then checks: the tasks
    metadata contains each unique task exactly once, every frame carries its
    episode's task string, per-frame task_index values agree with the
    metadata lookup, and total_tasks matches the number of unique tasks.
    """
    features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
    # Use multiple tasks, including repeated ones
    tasks = ["pick", "place", "pick", "navigate", "place"]
    # Deduplicate preserving first-seen order -> ["pick", "place", "navigate"].
    # (list(set(...)) produced a nondeterministic order, which contradicted the
    # inline comment; only set comparisons are made below, so behavior is kept.)
    unique_tasks = list(dict.fromkeys(tasks))
    frames_per_episode = [5, 8, 3, 10, 6]
    for episode_idx, task in enumerate(tasks):
        for _ in range(frames_per_episode[episode_idx]):
            dataset.add_frame({"state": torch.randn(1), "task": task})
        dataset.save_episode()
    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
    # Check that all unique tasks are in the tasks metadata
    stored_tasks = set(loaded_dataset.meta.tasks.index)
    assert stored_tasks == set(unique_tasks), f"Stored tasks {stored_tasks} != expected {set(unique_tasks)}"
    # Check that task indices are consistent
    cumulative = 0
    for episode_idx, expected_task in enumerate(tasks):
        episode_metadata = loaded_dataset.meta.episodes[episode_idx]
        assert episode_metadata["tasks"] == [expected_task]
        # Check frames in this episode have correct task
        for i in range(frames_per_episode[episode_idx]):
            frame = loaded_dataset[cumulative + i]
            assert frame["task"] == expected_task, f"Frame {cumulative + i} has wrong task"
            # Check task_index consistency
            expected_task_index = loaded_dataset.meta.get_task_index(expected_task)
            assert frame["task_index"].item() == expected_task_index
        cumulative += frames_per_episode[episode_idx]
    # Check total number of tasks
    assert loaded_dataset.meta.total_tasks == len(unique_tasks)

View File

@@ -11,83 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate
import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch
from lerobot.datasets.utils import (
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
def calculate_total_episode(hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True) -> int:
    """Return the number of distinct episodes in ``hf_dataset``.

    Args:
        hf_dataset: Dataset with an ``episode_index`` column.
        raise_if_not_contiguous: When True, require the episode indices to be
            exactly ``0..N-1`` with no gaps.

    Returns:
        The number of unique episode indices. (The original annotation claimed
        ``dict[str, torch.Tensor]``, but an int has always been returned.)

    Raises:
        ValueError: If ``raise_if_not_contiguous`` and the indices are not
            sorted and contiguous from 0.
    """
    episode_indices = sorted(hf_dataset.unique("episode_index"))
    total_episodes = len(episode_indices)
    if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
        raise ValueError("episode_index values are not sorted and contiguous.")
    return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
    """Compute the [from, to) row range of each episode in ``hf_dataset``.

    Returns:
        Dict with "from" and "to" int64 arrays, one entry per episode, such
        that episode ``i`` spans rows ``from[i]:to[i]`` of the dataset.

    Raises:
        ValueError: Propagated from calculate_total_episode when the episode
            indices are not contiguous from 0.
    """
    table = hf_dataset.data.table
    total_episodes = calculate_total_episode(hf_dataset)
    # Count rows per episode; episodes are contiguous 0..N-1 at this point, so
    # appending in loop order keeps lengths aligned with episode indices
    # (the original used insert(ep_idx, ...), which is equivalent but obscure).
    episode_lengths = []
    for ep_idx in range(total_episodes):
        ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
        episode_lengths.append(len(ep_table))
    cumulative_lengths = list(accumulate(episode_lengths))
    return {
        "from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
        "to": np.array(cumulative_lengths, dtype=np.int64),
    }
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
    # Factory fixture: builds a dataset whose timestamps follow the requested
    # fps exactly, and returns the three inputs check_timestamps_sync needs.
    def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        # NOTE: the third element is the from/to dict, not an ndarray — the
        # return annotation has been corrected accordingly.
        hf_dataset = hf_dataset_factory(fps=fps)
        timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
        episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
        episode_data_index = calculate_episode_data_index(hf_dataset)
        return timestamps, episode_indices, episode_data_index
    return _create_synced_timestamps
@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
    # Factory fixture: starts from perfectly synced timestamps and pushes one
    # of them just OUTSIDE the tolerance, so sync checks should fail.
    def _create_unsynced_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        # NOTE: third element is the from/to dict — annotation corrected.
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 1.1  # Modify a single timestamp just outside tolerance
        return timestamps, episode_indices, episode_data_index
    return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
    # Factory fixture: perturbs one timestamp but stays WITHIN the tolerance,
    # so sync checks should still pass.
    def _create_slightly_off_timestamps(
        fps: int = 30, tolerance_s: float = 1e-4
    ) -> tuple[np.ndarray, np.ndarray, dict[str, np.ndarray]]:
        # NOTE: third element is the from/to dict — annotation corrected.
        timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
        timestamps[30] += tolerance_s * 0.9  # Modify a single timestamp just inside tolerance
        return timestamps, episode_indices, episode_data_index
    return _create_slightly_off_timestamps
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
@@ -136,78 +68,6 @@ def delta_indices_factory():
return _delta_indices
def test_check_timestamps_sync_synced(synced_timestamps_factory):
    """Perfectly spaced timestamps must pass the sync check."""
    fps = 30
    tolerance_s = 1e-4
    stamps, episode_ids, ep_ranges = synced_timestamps_factory(fps)
    outcome = check_timestamps_sync(
        timestamps=stamps,
        episode_indices=episode_ids,
        episode_data_index=ep_ranges,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert outcome is True
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
    """A timestamp outside the tolerance must raise ValueError by default."""
    fps = 30
    tolerance_s = 1e-4
    stamps, episode_ids, ep_ranges = unsynced_timestamps_factory(fps, tolerance_s)
    with pytest.raises(ValueError):
        check_timestamps_sync(
            timestamps=stamps,
            episode_indices=episode_ids,
            episode_data_index=ep_ranges,
            fps=fps,
            tolerance_s=tolerance_s,
        )
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
    """With raise_value_error=False, an out-of-tolerance timestamp yields False instead of raising."""
    fps = 30
    tolerance_s = 1e-4
    stamps, episode_ids, ep_ranges = unsynced_timestamps_factory(fps, tolerance_s)
    outcome = check_timestamps_sync(
        timestamps=stamps,
        episode_indices=episode_ids,
        episode_data_index=ep_ranges,
        fps=fps,
        tolerance_s=tolerance_s,
        raise_value_error=False,
    )
    assert outcome is False
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
    """A perturbation that stays inside the tolerance must still pass the check."""
    fps = 30
    tolerance_s = 1e-4
    stamps, episode_ids, ep_ranges = slightly_off_timestamps_factory(fps, tolerance_s)
    outcome = check_timestamps_sync(
        timestamps=stamps,
        episode_indices=episode_ids,
        episode_data_index=ep_ranges,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert outcome is True
def test_check_timestamps_sync_single_timestamp():
    """A one-frame dataset is trivially in sync."""
    fps = 30
    tolerance_s = 1e-4
    stamps = np.array([0.0])
    episode_ids = np.array([0])
    ep_ranges = {"to": np.array([1]), "from": np.array([0])}
    outcome = check_timestamps_sync(
        timestamps=stamps,
        episode_indices=episode_ids,
        episode_data_index=ep_ranges,
        fps=fps,
        tolerance_s=tolerance_s,
    )
    assert outcome is True
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4

View File

@@ -32,7 +32,7 @@ def test_drop_n_first_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_first_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_first_frames=1)
assert sampler.indices == [1, 4, 5]
assert len(sampler) == 3
assert list(sampler) == [1, 4, 5]
@@ -48,7 +48,7 @@ def test_drop_n_last_frames():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, drop_n_last_frames=1)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], drop_n_last_frames=1)
assert sampler.indices == [0, 3, 4]
assert len(sampler) == 3
assert list(sampler) == [0, 3, 4]
@@ -64,7 +64,9 @@ def test_episode_indices_to_use():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, episode_indices_to_use=[0, 2])
sampler = EpisodeAwareSampler(
episode_data_index["from"], episode_data_index["to"], episode_indices_to_use=[0, 2]
)
assert sampler.indices == [0, 1, 3, 4, 5]
assert len(sampler) == 5
assert list(sampler) == [0, 1, 3, 4, 5]
@@ -80,11 +82,11 @@ def test_shuffle():
)
dataset.set_transform(hf_transform_to_torch)
episode_data_index = calculate_episode_data_index(dataset)
sampler = EpisodeAwareSampler(episode_data_index, shuffle=False)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=False)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert list(sampler) == [0, 1, 2, 3, 4, 5]
sampler = EpisodeAwareSampler(episode_data_index, shuffle=True)
sampler = EpisodeAwareSampler(episode_data_index["from"], episode_data_index["to"], shuffle=True)
assert sampler.indices == [0, 1, 2, 3, 4, 5]
assert len(sampler) == 6
assert set(sampler) == {0, 1, 2, 3, 4, 5}

View File

@@ -14,12 +14,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from copy import deepcopy
import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.datasets.utils import (
create_lerobot_dataset_card,
flatten_dict,
hf_transform_to_torch,
unflatten_dict,
)
def test_default_parameters():
@@ -53,3 +61,26 @@ def test_calculate_episode_data_index():
episode_data_index = calculate_episode_data_index(dataset)
assert torch.equal(episode_data_index["from"], torch.tensor([0, 2, 3]))
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
def test_flatten_unflatten_dict():
d = {
"obs": {
"min": 0,
"max": 1,
"mean": 2,
"std": 3,
},
"action": {
"min": 4,
"max": 5,
"mean": 6,
"std": 7,
},
}
original_d = deepcopy(d)
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"

View File

@@ -29,8 +29,8 @@ DUMMY_MOTOR_FEATURES = {
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"laptop": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (64, 96, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import shutil
from functools import partial
from pathlib import Path
from typing import Protocol
@@ -19,19 +20,25 @@ from unittest.mock import patch
import datasets
import numpy as np
import pandas as pd
import PIL.Image
import pytest
import torch
from datasets import Dataset
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
flatten_dict,
get_hf_features_from_features,
hf_transform_to_torch,
)
from lerobot.datasets.video_utils import encode_video_frames
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
@@ -46,10 +53,9 @@ class LeRobotDatasetFactory(Protocol):
def __call__(self, *args, **kwargs) -> LeRobotDataset: ...
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts.values()}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
def get_task_index(tasks: datasets.Dataset, task: str) -> int:
task_idx = tasks.loc[task].task_index.item()
return task_idx
@pytest.fixture(scope="session")
@@ -62,15 +68,49 @@ def img_tensor_factory():
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray:
if content is None:
# Original random noise behavior
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
else:
raise ValueError(dtype)
# Create image with text content using OpenCV
import cv2
# Create white background
img_array = np.ones((height, width, channels), dtype=np.uint8) * 255
# Font settings
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = max(0.5, height / 200) # Scale font with image size
font_color = (0, 0, 0) # Black text
thickness = max(1, int(height / 100))
# Get text size to center it
text_size = cv2.getTextSize(content, font, font_scale, thickness)[0]
text_x = (width - text_size[0]) // 2
text_y = (height + text_size[1]) // 2
# Put text on image
cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness)
# Handle single channel case
if channels == 1:
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
img_array = img_array[:, :, np.newaxis]
# Convert to target dtype
if np.issubdtype(dtype, np.floating):
img_array = img_array.astype(dtype) / 255.0
else:
img_array = img_array.astype(dtype)
return img_array
return _create_img_array
@@ -117,9 +157,10 @@ def info_factory(features_factory):
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
video_files_size_in_mb: float = DEFAULT_VIDEO_FILE_SIZE_IN_MB,
data_path: str = DEFAULT_DATA_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
@@ -133,8 +174,9 @@ def info_factory(features_factory):
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"data_files_size_in_mb": data_files_size_in_mb,
"video_files_size_in_mb": video_files_size_in_mb,
"fps": fps,
"splits": {},
"data_path": data_path,
@@ -175,41 +217,26 @@ def stats_factory():
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_factory(stats_factory):
def _create_episodes_stats(
features: dict[str],
total_episodes: int = 3,
) -> dict:
episodes_stats = {}
for episode_index in range(total_episodes):
episodes_stats[episode_index] = {
"episode_index": episode_index,
"stats": stats_factory(features),
}
return episodes_stats
return _create_episodes_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
tasks[task_index] = task_dict
return tasks
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
ids = list(range(total_tasks))
tasks = [f"Perform action {i}." for i in ids]
df = pd.DataFrame({"task_index": ids}, index=tasks)
return df
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def episodes_factory(tasks_factory, stats_factory):
def _create_episodes(
features: dict[str],
fps: int = DEFAULT_FPS,
total_episodes: int = 3,
total_frames: int = 400,
tasks: dict | None = None,
video_keys: list[str] | None = None,
tasks: pd.DataFrame | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
@@ -217,66 +244,142 @@ def episodes_factory(tasks_factory):
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
if tasks is None:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
num_tasks_available = len(tasks)
if total_episodes < num_tasks_available and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks.values()]
num_tasks_available = len(tasks_list)
# Create empty dictionaries with all keys
d = {
"episode_index": [],
"meta/episodes/chunk_index": [],
"meta/episodes/file_index": [],
"data/chunk_index": [],
"data/file_index": [],
"dataset_from_index": [],
"dataset_to_index": [],
"tasks": [],
"length": [],
}
if video_keys is not None:
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"] = []
d[f"videos/{video_key}/file_index"] = []
d[f"videos/{video_key}/from_timestamp"] = []
d[f"videos/{video_key}/to_timestamp"] = []
episodes = {}
remaining_tasks = tasks_list.copy()
for stats_key in flatten_dict({"stats": stats_factory(features)}):
d[stats_key] = []
num_frames = 0
remaining_tasks = list(tasks.index)
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
tasks_to_sample = remaining_tasks if len(remaining_tasks) > 0 else list(tasks.index)
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes[ep_idx] = {
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
d["episode_index"].append(ep_idx)
# TODO(rcadene): remove heuristic of only one file
d["meta/episodes/chunk_index"].append(0)
d["meta/episodes/file_index"].append(0)
d["data/chunk_index"].append(0)
d["data/file_index"].append(0)
d["dataset_from_index"].append(num_frames)
d["dataset_to_index"].append(num_frames + lengths[ep_idx])
d["tasks"].append(episode_tasks)
d["length"].append(lengths[ep_idx])
return episodes
if video_keys is not None:
for video_key in video_keys:
d[f"videos/{video_key}/chunk_index"].append(0)
d[f"videos/{video_key}/file_index"].append(0)
d[f"videos/{video_key}/from_timestamp"].append(num_frames / fps)
d[f"videos/{video_key}/to_timestamp"].append((num_frames + lengths[ep_idx]) / fps)
# Add stats columns like "stats/action/max"
for stats_key, stats in flatten_dict({"stats": stats_factory(features)}).items():
d[stats_key].append(stats)
num_frames += lengths[ep_idx]
return Dataset.from_dict(d)
return _create_episodes
@pytest.fixture(scope="session")
def create_videos(info_factory, img_array_factory):
    """Session fixture returning a helper that encodes dummy video files under a dataset root."""

    def _create_video_directory(
        root: Path,
        info: dict | None = None,
        total_episodes: int = 3,
        total_frames: int = 150,
        total_tasks: int = 1,
    ):
        """Encode one video per 'video'-dtype feature in `info`, built from identifiable PNG frames."""
        if info is None:
            info = info_factory(
                total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
            )
        for video_key, feature in info["features"].items():
            if feature["dtype"] != "video":
                continue
            # Render each frame as a PNG whose pixel content encodes (video key, frame index),
            # so decoded videos can be traced back to their source frame in assertions.
            frames_dir = root / "tmp_images"
            frames_dir.mkdir(parents=True, exist_ok=True)
            for frame_index in range(info["total_frames"]):
                frame = img_array_factory(
                    height=feature["shape"][0],
                    width=feature["shape"][1],
                    content=f"{video_key}-{frame_index}",
                )
                PIL.Image.fromarray(frame).save(frames_dir / f"frame-{frame_index:06d}.png")
            video_path = root / DEFAULT_VIDEO_PATH.format(video_key=video_key, chunk_index=0, file_index=0)
            video_path.parent.mkdir(parents=True, exist_ok=True)
            # Use the global fps from info, not video-specific fps which might not exist
            encode_video_frames(frames_dir, video_path, fps=info["fps"])
            shutil.rmtree(frames_dir)

    return _create_video_directory
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
if tasks is None:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
if features is None:
features = features_factory()
if episodes is None:
episodes = episodes_factory(features, fps)
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
# Slightly incorrect, but for simplicity, we assign to all frames the first task defined in the episode metadata.
# TODO(rcadene): assign the tasks of the episode per chunks of frames
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
@@ -286,8 +389,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
for _ in range(len(index_col))
img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}")
for i in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
@@ -314,7 +417,6 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
@@ -324,29 +426,29 @@ def lerobot_dataset_metadata_factory(
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
) -> LeRobotDatasetMetadata:
if not info:
if info is None:
info = info_factory()
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episodes,
)
@@ -366,7 +468,6 @@ def lerobot_dataset_metadata_factory(
def lerobot_dataset_factory(
info_factory,
stats_factory,
episodes_stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
@@ -380,50 +481,63 @@ def lerobot_dataset_factory(
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
use_videos: bool = True,
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes_metadata: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
chunks_size: int = DEFAULT_CHUNK_SIZE,
**kwargs,
) -> LeRobotDataset:
if not info:
# Instantiate objects
if info is None:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
total_episodes=total_episodes,
total_frames=total_frames,
total_tasks=total_tasks,
use_videos=use_videos,
data_files_size_in_mb=data_files_size_in_mb,
chunks_size=chunks_size,
)
if not stats:
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
if episodes_metadata is None:
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
episodes_metadata = episodes_factory(
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
video_keys=video_keys,
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
if hf_dataset is None:
hf_dataset = hf_dataset_factory(
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
)
# Write data on disk
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
hf_dataset=hf_dataset,
data_files_size_in_mb=data_files_size_in_mb,
chunks_size=chunks_size,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info=info,
stats=stats,
episodes_stats=episodes_stats,
tasks=tasks,
episodes=episode_dicts,
episodes=episodes_metadata,
)
with (
patch("lerobot.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,

View File

@@ -11,137 +11,166 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import numpy as np
import pandas as pd
import pytest
from datasets import Dataset
from lerobot.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
get_hf_dataset_size_in_mb,
update_chunk_file_indices,
write_episodes,
write_info,
write_stats,
write_tasks,
)
def write_hf_dataset(
    hf_dataset: Dataset,
    local_dir: Path,
    data_file_size_mb: float | None = None,
    chunk_size: int | None = None,
):
    """
    Writes a Hugging Face Dataset to one or more Parquet files in a structured directory format.
    If the dataset size is within `DEFAULT_DATA_FILE_SIZE_IN_MB`, it's saved as a single file.
    Otherwise, the dataset is split into multiple smaller Parquet files, each not exceeding the size limit.
    The file and chunk indices are managed to organize the output files in a hierarchical structure,
    e.g., `data/chunk-000/file-000.parquet`, `data/chunk-000/file-001.parquet`, etc.
    This function ensures that episodes are not split across multiple files.
    Args:
        hf_dataset (Dataset): The Hugging Face Dataset to be written to disk.
        local_dir (Path): The root directory where the dataset files will be stored.
        data_file_size_mb (float, optional): Maximal size for the parquet data file, in MB. Defaults to DEFAULT_DATA_FILE_SIZE_IN_MB.
        chunk_size (int, optional): Maximal number of files within a chunk folder before creating another one. Defaults to DEFAULT_CHUNK_SIZE.
    """
    if data_file_size_mb is None:
        data_file_size_mb = DEFAULT_DATA_FILE_SIZE_IN_MB
    if chunk_size is None:
        chunk_size = DEFAULT_CHUNK_SIZE
    dataset_size_in_mb = get_hf_dataset_size_in_mb(hf_dataset)
    if dataset_size_in_mb <= data_file_size_mb:
        # If the dataset is small enough, write it to a single file.
        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
        path.parent.mkdir(parents=True, exist_ok=True)
        hf_dataset.to_parquet(path)
        return
    # If the dataset is too large, split it into smaller chunks, keeping episodes whole.
    # Assumes rows are grouped by episode, i.e. `episode_index` changes at most once per
    # episode boundary — TODO confirm against hf_dataset_factory's construction.
    episode_indices = np.array(hf_dataset["episode_index"])
    # Row indices where `episode_index` changes value, i.e. the first row of every episode but the first.
    episode_boundaries = np.where(np.diff(episode_indices) != 0)[0] + 1
    # Per-episode half-open row ranges: episode i spans rows [episode_starts[i], episode_ends[i]).
    episode_starts = np.concatenate(([0], episode_boundaries))
    episode_ends = np.concatenate((episode_boundaries, [len(hf_dataset)]))
    num_episodes = len(episode_starts)
    current_episode_idx = 0
    chunk_idx, file_idx = 0, 0
    while current_episode_idx < num_episodes:
        # Start the shard with a single episode, then greedily grow it below.
        shard_start_row = episode_starts[current_episode_idx]
        shard_end_row = episode_ends[current_episode_idx]
        next_episode_to_try_idx = current_episode_idx + 1
        # Greedily extend the shard one whole episode at a time, stopping just before the
        # candidate shard would exceed `data_file_size_mb`.
        while next_episode_to_try_idx < num_episodes:
            potential_shard_end_row = episode_ends[next_episode_to_try_idx]
            dataset_shard_candidate = hf_dataset.select(range(shard_start_row, potential_shard_end_row))
            shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard_candidate)
            if shard_size_mb > data_file_size_mb:
                break
            else:
                # Candidate fits: commit it and try to absorb the next episode too.
                shard_end_row = potential_shard_end_row
                next_episode_to_try_idx += 1
        dataset_shard = hf_dataset.select(range(shard_start_row, shard_end_row))
        # Shard holds exactly one episode (the greedy loop could not absorb a second one):
        # warn if that lone episode by itself is over the limit, but still write it whole.
        if (
            shard_start_row == episode_starts[current_episode_idx]
            and shard_end_row == episode_ends[current_episode_idx]
        ):
            shard_size_mb = get_hf_dataset_size_in_mb(dataset_shard)
            if shard_size_mb > data_file_size_mb:
                logging.warning(
                    f"Episode with index {hf_dataset[shard_start_row.item()]['episode_index']} has size {shard_size_mb:.2f}MB, "
                    f"which is larger than data_file_size_mb ({data_file_size_mb}MB). "
                    "Writing it to a separate shard anyway to preserve episode integrity."
                )
        # Define the path for the current shard and ensure the directory exists.
        path = local_dir / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Write the shard to a Parquet file.
        dataset_shard.to_parquet(path)
        # Update chunk and file indices for the next iteration.
        chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
        current_episode_idx = next_episode_to_try_idx
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
def create_info(info_factory):
def _create_info(dir: Path, info: dict | None = None):
if info is None:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
write_info(info, dir)
return _create_info_json_file
return _create_info
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
def create_stats(stats_factory):
def _create_stats(dir: Path, stats: dict | None = None):
if stats is None:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
write_stats(stats, dir)
return _create_stats_json_file
return _create_stats
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes_stats.values())
return fpath
return _create_episodes_stats_jsonl_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
def create_tasks(tasks_factory):
def _create_tasks(dir: Path, tasks: pd.DataFrame | None = None):
if tasks is None:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks.values())
return fpath
write_tasks(tasks, dir)
return _create_tasks_jsonl_file
return _create_tasks
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
def create_episodes(episodes_factory):
def _create_episodes(dir: Path, episodes: datasets.Dataset | None = None):
if episodes is None:
# TODO(rcadene): add features, fps as arguments
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes.values())
return fpath
write_episodes(episodes, dir)
return _create_episodes_jsonl_file
return _create_episodes
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
def create_hf_dataset(hf_dataset_factory):
def _create_hf_dataset(
dir: Path,
hf_dataset: datasets.Dataset | None = None,
data_file_size_in_mb: float | None = None,
chunk_size: int | None = None,
):
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
write_hf_dataset(hf_dataset, dir, data_file_size_in_mb, chunk_size)
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"
return _create_multi_episode_parquet
return _create_hf_dataset

134
tests/fixtures/hub.py vendored
View File

@@ -14,15 +14,19 @@
from pathlib import Path
import datasets
import pandas as pd
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.datasets.utils import (
EPISODES_PATH,
EPISODES_STATS_PATH,
DEFAULT_CHUNK_SIZE,
DEFAULT_DATA_FILE_SIZE_IN_MB,
DEFAULT_DATA_PATH,
DEFAULT_EPISODES_PATH,
DEFAULT_TASKS_PATH,
DEFAULT_VIDEO_PATH,
INFO_PATH,
STATS_PATH,
TASKS_PATH,
)
from tests.fixtures.constants import LEROBOT_TEST_DIR
@@ -30,17 +34,16 @@ from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info_factory,
info_path,
create_info,
stats_factory,
stats_path,
episodes_stats_factory,
episodes_stats_path,
create_stats,
tasks_factory,
tasks_path,
create_tasks,
episodes_factory,
episode_path,
single_episode_parquet_path,
create_episodes,
hf_dataset_factory,
create_hf_dataset,
create_videos,
):
"""
This factory allows to patch snapshot_download such that when called, it will create expected files rather
@@ -50,82 +53,93 @@ def mock_snapshot_download_factory(
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
episodes_stats: list[dict] | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
tasks: pd.DataFrame | None = None,
episodes: datasets.Dataset | None = None,
hf_dataset: datasets.Dataset | None = None,
data_files_size_in_mb: float = DEFAULT_DATA_FILE_SIZE_IN_MB,
chunks_size: int = DEFAULT_CHUNK_SIZE,
):
if not info:
info = info_factory()
if not stats:
if info is None:
info = info_factory(data_files_size_in_mb=data_files_size_in_mb, chunks_size=chunks_size)
if stats is None:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=info["total_episodes"]
)
if not tasks:
if tasks is None:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
if episodes is None:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
features=info["features"],
fps=info["fps"],
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
)
if not hf_dataset:
if hf_dataset is None:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
repo_id: str, # TODO(rcadene): repo_id should be used no?
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
if local_dir is None:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
all_files = [
INFO_PATH,
STATS_PATH,
# TODO(rcadene): remove naive chunk 0 file 0 ?
DEFAULT_TASKS_PATH.format(chunk_index=0, file_index=0),
DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0),
DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0),
]
data_files = []
for episode_dict in episodes.values():
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
video_keys = [key for key, feats in info["features"].items() if feats["dtype"] == "video"]
for key in video_keys:
all_files.append(DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0))
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
request_info = False
request_tasks = False
request_episodes = False
request_stats = False
request_data = False
request_videos = False
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == EPISODES_STATS_PATH:
_ = episodes_stats_path(local_dir, episodes_stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes)
if rel_path.startswith("meta/info.json"):
request_info = True
elif rel_path.startswith("meta/stats"):
request_stats = True
elif rel_path.startswith("meta/tasks"):
request_tasks = True
elif rel_path.startswith("meta/episodes"):
request_episodes = True
elif rel_path.startswith("data/"):
request_data = True
elif rel_path.startswith("videos/"):
request_videos = True
else:
pass
raise ValueError(f"{rel_path} not supported.")
if request_info:
create_info(local_dir, info)
if request_stats:
create_stats(local_dir, stats)
if request_tasks:
create_tasks(local_dir, tasks)
if request_episodes:
create_episodes(local_dir, episodes)
if request_data:
create_hf_dataset(local_dir, hf_dataset, data_files_size_in_mb, chunks_size)
if request_videos:
create_videos(root=local_dir, info=info)
return str(local_dir)
return _mock_snapshot_download

View File

@@ -71,7 +71,11 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
total_tasks=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@@ -140,7 +144,6 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
Note: We test various combinations of policy and dataset. The combinations are by no means exhaustive,
and for now we add tests as we see fit.
"""
train_cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]),

View File

@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from lerobot.calibrate import CalibrateConfig, calibrate
from lerobot.record import DatasetRecordConfig, RecordConfig, record
from lerobot.replay import DatasetReplayConfig, ReplayConfig, replay
@@ -67,7 +69,14 @@ def test_record_and_resume(tmp_path):
assert dataset.meta.total_tasks == 1
cfg.resume = True
dataset = record(cfg)
# Mock the revision to prevent Hub calls during resume
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record")
dataset = record(cfg)
assert dataset.meta.total_episodes == dataset.num_episodes == 2
assert dataset.meta.total_frames == dataset.num_frames == 6
@@ -103,4 +112,12 @@ def test_record_and_replay(tmp_path):
)
record(record_cfg)
replay(replay_cfg)
# Mock the revision to prevent Hub calls during replay
with (
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
):
mock_get_safe_version.return_value = "v3.0"
mock_snapshot_download.return_value = str(tmp_path / "record_and_replay")
replay(replay_cfg)

View File

@@ -384,7 +384,7 @@ def test_to_lerobot_dataset(tmp_path):
elif feature == "next.done":
assert torch.equal(value, buffer.dones[i])
elif feature == "observation.image":
# Tenssor -> numpy is not precise, so we have some diff there
# Tensor -> numpy is not precise, so we have some diff there
# TODO: Check and fix it
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
elif feature == "observation.state":