- Introduce _current_file_start_frame for better tracking of the number of frames in each parquet file (#2280)

- Added testing for that section in `test_datasets.py`
This commit is contained in:
Michel Aractingi
2025-10-21 16:17:12 +02:00
committed by GitHub
parent a024d33750
commit 12f2f35760
2 changed files with 99 additions and 1 deletions

View File

@@ -1199,3 +1199,96 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
expected_to = (ep_idx + 1) * frames_per_episode
assert ep_metadata["dataset_from_index"] == expected_from
assert ep_metadata["dataset_to_index"] == expected_to
def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_factory):
"""Regression test for bug where frames_in_current_file only counted frames from last episode instead of all frames in current file."""
features = {
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
assert dataset._current_file_start_frame is None
frames_per_episode = 10
for _ in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.randn(2),
"action": torch.randn(2),
"task": "task_0",
}
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.meta.total_episodes == 1
assert dataset.meta.total_frames == frames_per_episode
for _ in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.randn(2),
"action": torch.randn(2),
"task": "task_1",
}
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.meta.total_episodes == 2
assert dataset.meta.total_frames == 2 * frames_per_episode
ep1_chunk = dataset.latest_episode["data/chunk_index"]
ep1_file = dataset.latest_episode["data/file_index"]
assert ep1_chunk == 0
assert ep1_file == 0
for _ in range(frames_per_episode):
dataset.add_frame(
{
"observation.state": torch.randn(2),
"action": torch.randn(2),
"task": "task_2",
}
)
dataset.save_episode()
assert dataset._current_file_start_frame == 0
assert dataset.meta.total_episodes == 3
assert dataset.meta.total_frames == 3 * frames_per_episode
ep2_chunk = dataset.latest_episode["data/chunk_index"]
ep2_file = dataset.latest_episode["data/file_index"]
assert ep2_chunk == 0
assert ep2_file == 0
dataset.finalize()
from lerobot.datasets.utils import load_episodes
dataset.meta.episodes = load_episodes(dataset.root)
assert dataset.meta.episodes is not None
for ep_idx in range(3):
ep_metadata = dataset.meta.episodes[ep_idx]
assert ep_metadata["data/chunk_index"] == 0
assert ep_metadata["data/file_index"] == 0
expected_from = ep_idx * frames_per_episode
expected_to = (ep_idx + 1) * frames_per_episode
assert ep_metadata["dataset_from_index"] == expected_from
assert ep_metadata["dataset_to_index"] == expected_to
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
assert len(loaded_dataset) == 3 * frames_per_episode
assert loaded_dataset.meta.total_episodes == 3
assert loaded_dataset.meta.total_frames == 3 * frames_per_episode
for idx in range(len(loaded_dataset)):
frame = loaded_dataset[idx]
expected_ep = idx // frames_per_episode
assert frame["episode_index"].item() == expected_ep