LeRobotDataset v2.1 (#711)
Co-authored-by: Remi <remi.cadene@huggingface.co> Co-authored-by: Remi Cadene <re.cadene@gmail.com>
This commit is contained in:
@@ -1,55 +1,78 @@
|
||||
from itertools import accumulate
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_hf_dataset_factory(hf_dataset_factory):
|
||||
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
|
||||
return hf_dataset_factory(fps=fps)
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
return _create_synced_hf_dataset
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": np.array([0] + cumulative_lenghts[:-1], dtype=np.int64),
|
||||
"to": np.array(cumulative_lenghts, dtype=np.int64),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just outside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
def synced_timestamps_factory(hf_dataset_factory):
|
||||
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
hf_dataset = hf_dataset_factory(fps=fps)
|
||||
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_unsynced_hf_dataset
|
||||
return _create_synced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just inside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
def unsynced_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_unsynced_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_slightly_off_hf_dataset
|
||||
return _create_unsynced_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_timestamps_factory(synced_timestamps_factory):
|
||||
def _create_slightly_off_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
|
||||
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
|
||||
return timestamps, episode_indices, episode_data_index
|
||||
|
||||
return _create_slightly_off_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -100,42 +123,42 @@ def delta_indices_factory():
|
||||
return _delta_indices
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
def test_check_timestamps_sync_synced(synced_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
synced_hf_dataset = synced_hf_dataset_factory(fps)
|
||||
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
|
||||
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=synced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
episode_data_index=ep_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
|
||||
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
||||
with pytest.raises(ValueError):
|
||||
check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
episode_data_index=ep_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
episode_data_index=ep_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
raise_value_error=False,
|
||||
@@ -143,14 +166,14 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
|
||||
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=slightly_off_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
episode_data_index=ep_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
@@ -158,33 +181,13 @@ def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
|
||||
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
|
||||
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
timestamps, ep_idx = np.array([0.0]), np.array([0])
|
||||
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=single_timestamp_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
|
||||
@pytest.mark.skip("TODO: fix")
|
||||
def test_check_timestamps_sync_empty_dataset():
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
|
||||
empty_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"to": torch.tensor([], dtype=torch.int64),
|
||||
"from": torch.tensor([], dtype=torch.int64),
|
||||
}
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=empty_hf_dataset,
|
||||
timestamps=timestamps,
|
||||
episode_indices=ep_idx,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
|
||||
Reference in New Issue
Block a user