Dataset v3 (#1412)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Remi Cadene <re.cadene@gmail.com>
Co-authored-by: Tavish <tavish9.chen@gmail.com>
Co-authored-by: fracapuano <francesco.capuano@huggingface.co>
Co-authored-by: CarolinePascal <caroline8.pascal@gmail.com>
This commit is contained in:
Michel Aractingi
2025-09-15 09:53:30 +02:00
committed by GitHub
parent d602e8169c
commit f55c6e89f0
50 changed files with 4642 additions and 4092 deletions

View File

@@ -11,83 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import accumulate
import datasets
import numpy as np
import pyarrow.compute as pc
import pytest
import torch
from lerobot.datasets.utils import (
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
def calculate_total_episode(
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
) -> dict[str, torch.Tensor]:
episode_indices = sorted(hf_dataset.unique("episode_index"))
total_episodes = len(episode_indices)
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
raise ValueError("episode_index values are not sorted and contiguous.")
return total_episodes
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.ndarray]:
episode_lengths = []
table = hf_dataset.data.table
total_episodes = calculate_total_episode(hf_dataset)
for ep_idx in range(total_episodes):
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
episode_lengths.insert(ep_idx, len(ep_table))
cumulative_lengths = list(accumulate(episode_lengths))
return {
"from": np.array([0] + cumulative_lengths[:-1], dtype=np.int64),
"to": np.array(cumulative_lengths, dtype=np.int64),
}
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
episode_data_index = calculate_episode_data_index(hf_dataset)
return timestamps, episode_indices, episode_data_index
return _create_synced_timestamps
@pytest.fixture(scope="module")
def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_unsynced_timestamps
@pytest.fixture(scope="module")
def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
@@ -136,78 +68,6 @@ def delta_indices_factory():
return _delta_indices
def test_check_timestamps_sync_synced(synced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = synced_timestamps_factory(fps)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = unsynced_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=ep_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_single_timestamp():
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx = np.array([0.0]), np.array([0])
episode_data_index = {"to": np.array([1]), "from": np.array([0])}
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4