- Introduce _current_file_start_frame for better tracking of the number of frames in each parquet file (#2280)
- Add a regression test for the per-file frame count calculation in `test_datasets.py`
This commit is contained in:
@@ -686,6 +686,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.episode_buffer = None
|
self.episode_buffer = None
|
||||||
self.writer = None
|
self.writer = None
|
||||||
self.latest_episode = None
|
self.latest_episode = None
|
||||||
|
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||||
|
|
||||||
self.root.mkdir(exist_ok=True, parents=True)
|
self.root.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
@@ -1232,6 +1233,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
# Initialize indices and frame count for a new dataset made of the first episode data
|
# Initialize indices and frame count for a new dataset made of the first episode data
|
||||||
chunk_idx, file_idx = 0, 0
|
chunk_idx, file_idx = 0, 0
|
||||||
global_frame_index = 0
|
global_frame_index = 0
|
||||||
|
self._current_file_start_frame = 0
|
||||||
# However, if the episodes already exists
|
# However, if the episodes already exists
|
||||||
# It means we are resuming recording, so we need to load the latest episode
|
# It means we are resuming recording, so we need to load the latest episode
|
||||||
# Update the indices to avoid overwriting the latest episode
|
# Update the indices to avoid overwriting the latest episode
|
||||||
@@ -1243,6 +1245,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
# When resuming, move to the next file
|
# When resuming, move to the next file
|
||||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||||
|
self._current_file_start_frame = global_frame_index
|
||||||
else:
|
else:
|
||||||
# Retrieve information from the latest parquet file
|
# Retrieve information from the latest parquet file
|
||||||
latest_ep = self.latest_episode
|
latest_ep = self.latest_episode
|
||||||
@@ -1253,7 +1256,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
latest_path = self.root / self.meta.data_path.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||||
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
latest_size_in_mb = get_file_size_in_mb(latest_path)
|
||||||
|
|
||||||
frames_in_current_file = global_frame_index - latest_ep["dataset_from_index"]
|
frames_in_current_file = global_frame_index - self._current_file_start_frame
|
||||||
av_size_per_frame = (
|
av_size_per_frame = (
|
||||||
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
latest_size_in_mb / frames_in_current_file if frames_in_current_file > 0 else 0
|
||||||
)
|
)
|
||||||
@@ -1267,6 +1270,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, self.meta.chunks_size)
|
||||||
self._close_writer()
|
self._close_writer()
|
||||||
self._writer_closed_for_reading = False
|
self._writer_closed_for_reading = False
|
||||||
|
self._current_file_start_frame = global_frame_index
|
||||||
|
|
||||||
ep_dict["data/chunk_index"] = chunk_idx
|
ep_dict["data/chunk_index"] = chunk_idx
|
||||||
ep_dict["data/file_index"] = file_idx
|
ep_dict["data/file_index"] = file_idx
|
||||||
@@ -1473,6 +1477,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
obj.writer = None
|
obj.writer = None
|
||||||
obj.latest_episode = None
|
obj.latest_episode = None
|
||||||
|
obj._current_file_start_frame = None
|
||||||
# Initialize tracking for incremental recording
|
# Initialize tracking for incremental recording
|
||||||
obj._lazy_loading = False
|
obj._lazy_loading = False
|
||||||
obj._recorded_frames = 0
|
obj._recorded_frames = 0
|
||||||
|
|||||||
@@ -1199,3 +1199,96 @@ def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
|
|||||||
expected_to = (ep_idx + 1) * frames_per_episode
|
expected_to = (ep_idx + 1) * frames_per_episode
|
||||||
assert ep_metadata["dataset_from_index"] == expected_from
|
assert ep_metadata["dataset_from_index"] == expected_from
|
||||||
assert ep_metadata["dataset_to_index"] == expected_to
|
assert ep_metadata["dataset_to_index"] == expected_to
|
||||||
|
|
||||||
|
|
||||||
|
def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_factory):
|
||||||
|
"""Regression test for bug where frames_in_current_file only counted frames from last episode instead of all frames in current file."""
|
||||||
|
features = {
|
||||||
|
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||||
|
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
|
||||||
|
}
|
||||||
|
|
||||||
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||||
|
dataset.meta.update_chunk_settings(data_files_size_in_mb=100)
|
||||||
|
|
||||||
|
assert dataset._current_file_start_frame is None
|
||||||
|
|
||||||
|
frames_per_episode = 10
|
||||||
|
for _ in range(frames_per_episode):
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"observation.state": torch.randn(2),
|
||||||
|
"action": torch.randn(2),
|
||||||
|
"task": "task_0",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset._current_file_start_frame == 0
|
||||||
|
assert dataset.meta.total_episodes == 1
|
||||||
|
assert dataset.meta.total_frames == frames_per_episode
|
||||||
|
|
||||||
|
for _ in range(frames_per_episode):
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"observation.state": torch.randn(2),
|
||||||
|
"action": torch.randn(2),
|
||||||
|
"task": "task_1",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset._current_file_start_frame == 0
|
||||||
|
assert dataset.meta.total_episodes == 2
|
||||||
|
assert dataset.meta.total_frames == 2 * frames_per_episode
|
||||||
|
|
||||||
|
ep1_chunk = dataset.latest_episode["data/chunk_index"]
|
||||||
|
ep1_file = dataset.latest_episode["data/file_index"]
|
||||||
|
assert ep1_chunk == 0
|
||||||
|
assert ep1_file == 0
|
||||||
|
|
||||||
|
for _ in range(frames_per_episode):
|
||||||
|
dataset.add_frame(
|
||||||
|
{
|
||||||
|
"observation.state": torch.randn(2),
|
||||||
|
"action": torch.randn(2),
|
||||||
|
"task": "task_2",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
assert dataset._current_file_start_frame == 0
|
||||||
|
assert dataset.meta.total_episodes == 3
|
||||||
|
assert dataset.meta.total_frames == 3 * frames_per_episode
|
||||||
|
|
||||||
|
ep2_chunk = dataset.latest_episode["data/chunk_index"]
|
||||||
|
ep2_file = dataset.latest_episode["data/file_index"]
|
||||||
|
assert ep2_chunk == 0
|
||||||
|
assert ep2_file == 0
|
||||||
|
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
from lerobot.datasets.utils import load_episodes
|
||||||
|
|
||||||
|
dataset.meta.episodes = load_episodes(dataset.root)
|
||||||
|
assert dataset.meta.episodes is not None
|
||||||
|
|
||||||
|
for ep_idx in range(3):
|
||||||
|
ep_metadata = dataset.meta.episodes[ep_idx]
|
||||||
|
assert ep_metadata["data/chunk_index"] == 0
|
||||||
|
assert ep_metadata["data/file_index"] == 0
|
||||||
|
|
||||||
|
expected_from = ep_idx * frames_per_episode
|
||||||
|
expected_to = (ep_idx + 1) * frames_per_episode
|
||||||
|
assert ep_metadata["dataset_from_index"] == expected_from
|
||||||
|
assert ep_metadata["dataset_to_index"] == expected_to
|
||||||
|
|
||||||
|
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||||
|
assert len(loaded_dataset) == 3 * frames_per_episode
|
||||||
|
assert loaded_dataset.meta.total_episodes == 3
|
||||||
|
assert loaded_dataset.meta.total_frames == 3 * frames_per_episode
|
||||||
|
|
||||||
|
for idx in range(len(loaded_dataset)):
|
||||||
|
frame = loaded_dataset[idx]
|
||||||
|
expected_ep = idx // frames_per_episode
|
||||||
|
assert frame["episode_index"].item() == expected_ep
|
||||||
|
|||||||
Reference in New Issue
Block a user