From 12f2f3576096cf5b3b0b8fe5d76906d7a1d7c926 Mon Sep 17 00:00:00 2001
From: Michel Aractingi
Date: Tue, 21 Oct 2025 16:17:12 +0200
Subject: [PATCH] Introduce _current_file_start_frame for better tracking of
 the number of frames in each parquet file (#2280)

Previously, frames_in_current_file was computed as
global_frame_index - latest_ep["dataset_from_index"], which only counts
frames from the most recent episode. When several episodes share one
parquet file, this undercounts the frames in the file and inflates the
av_size_per_frame estimate used to decide when to roll over to a new
file. Track the starting frame index of the current parquet file
explicitly and subtract that instead.

A regression test for this is added in `test_datasets.py`.

---
 src/lerobot/datasets/lerobot_dataset.py |  7 +-
 tests/datasets/test_datasets.py         | 93 +++++++++++++++++++++++++
 2 files changed, 99 insertions(+), 1 deletion(-)

diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 9bbe07a5..6258312a 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -686,6 +686,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_buffer = None
         self.writer = None
         self.latest_episode = None
+        self._current_file_start_frame = None  # Track the starting frame index of the current parquet file
 
         self.root.mkdir(exist_ok=True, parents=True)
 
@@ -1232,6 +1233,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
             # Initialize indices and frame count for a new dataset made of the first episode data
             chunk_idx, file_idx = 0, 0
             global_frame_index = 0
+            self._current_file_start_frame = 0
         # However, if the episodes already exists
         # It means we are resuming recording, so we need to load the latest episode
         # Update the indices to avoid overwriting the latest episode
@@ -1243,6 +1245,7 @@
 
                 # When resuming, move to the next file
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
+                self._current_file_start_frame = global_frame_index
             else:
                 # Retrieve information from the latest parquet file
                 latest_ep = self.latest_episode
@@ -1253,7 +1256,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
                 latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
                 latest_size_in_mb = get_file_size_in_mb(latest_path)
 
-                frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
+                frames_in_current_file = global_frame_index - self._current_file_start_frame
                 av_size_per_frame = (
                     latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
                 )
@@ -1267,6 +1270,7 @@
                 chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
                 self._close_writer()
                 self._writer_closed_for_reading = False
+                self._current_file_start_frame = global_frame_index
 
         ep_dict["data/chunk_index"] = chunk_idx
         ep_dict["data/file_index"] = file_idx
@@ -1473,6 +1477,7 @@
         obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
         obj.writer = None
         obj.latest_episode = None
+        obj._current_file_start_frame = None  # Initialize tracking for incremental recording
         obj._lazy_loading = False
         obj._recorded_frames = 0
 
diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py
index e174c578..38fdc358 100644
--- a/tests/datasets/test_datasets.py
+++ b/tests/datasets/test_datasets.py
@@ -1199,3 +1199,96 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
         expected_to = (ep_idx + 1) * frames_per_episode
         assert ep_metadata["dataset_from_index"] == expected_from
         assert ep_metadata["dataset_to_index"] == expected_to
+
+
+def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_factory):
+    """Regression test for a bug where frames_in_current_file only counted frames from the last episode instead of all frames in the current file."""
+    features = {
+        "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
+        "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
+    }
+
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
+    dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
+
+    assert dataset._current_file_start_frame is None
+
+    frames_per_episode = 10
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_0",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 1
+    assert dataset.meta.total_frames == frames_per_episode
+
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_1",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 2
+    assert dataset.meta.total_frames == 2 * frames_per_episode
+
+    ep1_chunk = dataset.latest_episode["data/chunk_index"]
+    ep1_file = dataset.latest_episode["data/file_index"]
+    assert ep1_chunk == 0
+    assert ep1_file == 0
+
+    for _ in range(frames_per_episode):
+        dataset.add_frame(
+            {
+                "observation.state": torch.randn(2),
+                "action": torch.randn(2),
+                "task": "task_2",
+            }
+        )
+    dataset.save_episode()
+
+    assert dataset._current_file_start_frame == 0
+    assert dataset.meta.total_episodes == 3
+    assert dataset.meta.total_frames == 3 * frames_per_episode
+
+    ep2_chunk = dataset.latest_episode["data/chunk_index"]
+    ep2_file = dataset.latest_episode["data/file_index"]
+    assert ep2_chunk == 0
+    assert ep2_file == 0
+
+    dataset.finalize()
+
+    from lerobot.datasets.utils import load_episodes
+
+    dataset.meta.episodes = load_episodes(dataset.root)
+    assert dataset.meta.episodes is not None
+
+    for ep_idx in range(3):
+        ep_metadata = dataset.meta.episodes[ep_idx]
+        assert ep_metadata["data/chunk_index"] == 0
+        assert ep_metadata["data/file_index"] == 0
+
+        expected_from = ep_idx * frames_per_episode
+        expected_to = (ep_idx + 1) * frames_per_episode
+        assert ep_metadata["dataset_from_index"] == expected_from
+        assert ep_metadata["dataset_to_index"] == expected_to
+
+    loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
+    assert len(loaded_dataset) == 3 * frames_per_episode
+    assert loaded_dataset.meta.total_episodes == 3
+    assert loaded_dataset.meta.total_frames == 3 * frames_per_episode
+
+    for idx in range(len(loaded_dataset)):
+        frame = loaded_dataset[idx]
+        expected_ep = idx // frames_per_episode
+        assert frame["episode_index"].item() == expected_ep
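
Reviewer note (not part of the commit): a minimal, self-contained sketch of
why the old computation undercounts. All numbers here are hypothetical, and
only frames_in_current_file, av_size_per_frame, latest_size_in_mb, and
global_frame_index mirror names used in the patch; the other variables are
illustrative stand-ins.

    # Three episodes of 10 frames each have been written to one parquet file.
    latest_size_in_mb = 30.0        # current size of the file on disk
    global_frame_index = 130        # next global frame index in the dataset
    current_file_start_frame = 100  # first global frame stored in this file
    latest_episode_start = 120      # dataset_from_index of the latest episode

    # Old (buggy): only the latest episode's 10 frames are counted, so the
    # per-frame size estimate is inflated 3x.
    buggy_frames = global_frame_index - latest_episode_start          # 10
    buggy_av_size = latest_size_in_mb / buggy_frames                  # 3.0 MB/frame

    # New: all 30 frames in the file are counted.
    frames_in_current_file = global_frame_index - current_file_start_frame  # 30
    av_size_per_frame = latest_size_in_mb / frames_in_current_file          # 1.0 MB/frame

With the inflated estimate, the size check that decides when to move to a new
parquet file presumably fires too early, producing more and smaller files than
the data_files_size_in_mb setting intends; the regression test pins the
corrected bookkeeping across three episodes sharing one file.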