Incremental parquet writing (#1903)
* incremental parquet writing * add .finalise() and a backup __del__ for stopping writers * fix missing import * precommit fixes added back the use of embed images * added lazy loading for hf_Dataset to avoid frequently reloading the dataset during recording * fix bug in video timestamps * Added proper closing of parquet file before reading * Added rigorous testing to validate the consistency of the meta data after creation of a new dataset * fix bug in episode index during clear_episode_buffer * fix(empty concat): check for empty paths list before data files concatenation * fix(v3.0 message): updating v3.0 backward compatibility message. * added fixes for the resume logic * answering co-pilot review * reverting some changes and style nits * removed unused functions * fix chunk_id and file_id when resuming * - fix parquet loading when resuming - add test to verify the parquet file integrity when resuming so that data files are now overwritten * added general function get_file_size_in_mb and removed the one for video * fix table size value when resuming * Remove unnecessary reloading of the parquet file when resuming record. Write to a new parquet file when resuming record * added back reading parquet file for image datasets only * - respond to Qlhoest comments - Use pyarrows `from_pydict` function - Add buffer for episode metadata to write to the parquet file in batches to improve efficiency - Remove the use of `to_parquet_with_hf_images` * fix(dataset_tools) with the new logic using proper finalize bug in finding the latest path of the metdata that was pointing to the data files added check for the metadata size in the case the metadatabuffer was not written yet * nit in flush_metadata_buffer * fix(lerobot_dataset) return the right dataset len when a subset of the dataset is requested --------- Co-authored-by: Harsimrat Sandhawalia <hs.sandhawalia@gmail.com>
This commit is contained in:
@@ -806,6 +806,8 @@ def test_episode_index_distribution(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset.add_frame({"state": torch.randn(2), "task": f"task_{episode_idx}"})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
# Load the dataset and check episode indices
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
@@ -855,6 +857,8 @@ def test_multi_episode_metadata_consistency(tmp_path, empty_lerobot_dataset_fact
|
||||
dataset.add_frame({"state": torch.randn(3), ACTION: torch.randn(2), "task": tasks[episode_idx]})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
# Load and validate episode metadata
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
@@ -893,6 +897,8 @@ def test_data_consistency_across_episodes(tmp_path, empty_lerobot_dataset_factor
|
||||
dataset.add_frame({"state": torch.randn(1), "task": "consistency_test"})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check data consistency - no gaps or overlaps
|
||||
@@ -944,6 +950,8 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
|
||||
dataset.add_frame({"state": state_data, ACTION: action_data, "task": "stats_test"})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check that statistics exist for all features
|
||||
@@ -989,6 +997,8 @@ def test_episode_boundary_integrity(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset.add_frame({"state": torch.tensor([float(frame_idx)]), "task": f"episode_{episode_idx}"})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Test episode boundaries
|
||||
@@ -1031,6 +1041,8 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
|
||||
dataset.add_frame({"state": torch.randn(1), "task": task})
|
||||
dataset.save_episode()
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
loaded_dataset = LeRobotDataset(dataset.repo_id, root=dataset.root)
|
||||
|
||||
# Check that all unique tasks are in the tasks metadata
|
||||
@@ -1056,3 +1068,134 @@ def test_task_indexing_and_validation(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
# Check total number of tasks
|
||||
assert loaded_dataset.meta.total_tasks == len(unique_tasks)
|
||||
|
||||
|
||||
def test_dataset_resume_recording(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that resuming dataset recording preserves previously recorded episodes.
|
||||
|
||||
This test validates the critical resume functionality by:
|
||||
1. Recording initial episodes and finalizing
|
||||
2. Reopening the dataset
|
||||
3. Recording additional episodes
|
||||
4. Verifying all data (old + new) is intact
|
||||
|
||||
This specifically tests the bug fix where parquet files were being overwritten
|
||||
instead of appended to during resume.
|
||||
"""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
initial_episodes = 2
|
||||
frames_per_episode = 3
|
||||
|
||||
for ep_idx in range(initial_episodes):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
|
||||
"action": torch.tensor([0.5, 0.5]),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == initial_episodes
|
||||
assert dataset.meta.total_frames == initial_episodes * frames_per_episode
|
||||
|
||||
dataset.finalize()
|
||||
initial_root = dataset.root
|
||||
initial_repo_id = dataset.repo_id
|
||||
del dataset
|
||||
|
||||
dataset_verify = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
assert dataset_verify.meta.total_episodes == initial_episodes
|
||||
assert dataset_verify.meta.total_frames == initial_episodes * frames_per_episode
|
||||
assert len(dataset_verify.hf_dataset) == initial_episodes * frames_per_episode
|
||||
|
||||
for idx in range(len(dataset_verify.hf_dataset)):
|
||||
item = dataset_verify[idx]
|
||||
expected_ep = idx // frames_per_episode
|
||||
expected_frame = idx % frames_per_episode
|
||||
assert item["episode_index"].item() == expected_ep
|
||||
assert item["frame_index"].item() == expected_frame
|
||||
assert item["index"].item() == idx
|
||||
assert item["observation.state"][0].item() == float(expected_ep)
|
||||
assert item["observation.state"][1].item() == float(expected_frame)
|
||||
|
||||
del dataset_verify
|
||||
|
||||
# Phase 3: Resume recording - add more episodes
|
||||
dataset_resumed = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
|
||||
assert dataset_resumed.meta.total_episodes == initial_episodes
|
||||
assert dataset_resumed.meta.total_frames == initial_episodes * frames_per_episode
|
||||
assert dataset_resumed.latest_episode is None # Not recording yet
|
||||
assert dataset_resumed.writer is None
|
||||
assert dataset_resumed.meta.writer is None
|
||||
|
||||
additional_episodes = 2
|
||||
for ep_idx in range(initial_episodes, initial_episodes + additional_episodes):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset_resumed.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([float(ep_idx), float(frame_idx)]),
|
||||
"action": torch.tensor([0.5, 0.5]),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset_resumed.save_episode()
|
||||
|
||||
total_episodes = initial_episodes + additional_episodes
|
||||
total_frames = total_episodes * frames_per_episode
|
||||
assert dataset_resumed.meta.total_episodes == total_episodes
|
||||
assert dataset_resumed.meta.total_frames == total_frames
|
||||
|
||||
dataset_resumed.finalize()
|
||||
del dataset_resumed
|
||||
|
||||
dataset_final = LeRobotDataset(initial_repo_id, root=initial_root, revision="v3.0")
|
||||
|
||||
assert dataset_final.meta.total_episodes == total_episodes
|
||||
assert dataset_final.meta.total_frames == total_frames
|
||||
assert len(dataset_final.hf_dataset) == total_frames
|
||||
|
||||
for idx in range(total_frames):
|
||||
item = dataset_final[idx]
|
||||
expected_ep = idx // frames_per_episode
|
||||
expected_frame = idx % frames_per_episode
|
||||
|
||||
assert item["episode_index"].item() == expected_ep, (
|
||||
f"Frame {idx}: wrong episode_index. Expected {expected_ep}, got {item['episode_index'].item()}"
|
||||
)
|
||||
assert item["frame_index"].item() == expected_frame, (
|
||||
f"Frame {idx}: wrong frame_index. Expected {expected_frame}, got {item['frame_index'].item()}"
|
||||
)
|
||||
assert item["index"].item() == idx, (
|
||||
f"Frame {idx}: wrong index. Expected {idx}, got {item['index'].item()}"
|
||||
)
|
||||
|
||||
# Verify data integrity
|
||||
assert item["observation.state"][0].item() == float(expected_ep), (
|
||||
f"Frame {idx}: wrong observation.state[0]. Expected {float(expected_ep)}, "
|
||||
f"got {item['observation.state'][0].item()}"
|
||||
)
|
||||
assert item["observation.state"][1].item() == float(expected_frame), (
|
||||
f"Frame {idx}: wrong observation.state[1]. Expected {float(expected_frame)}, "
|
||||
f"got {item['observation.state'][1].item()}"
|
||||
)
|
||||
|
||||
assert len(dataset_final.meta.episodes) == total_episodes
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_metadata = dataset_final.meta.episodes[ep_idx]
|
||||
assert ep_metadata["episode_index"] == ep_idx
|
||||
assert ep_metadata["length"] == frames_per_episode
|
||||
assert ep_metadata["tasks"] == [f"task_{ep_idx}"]
|
||||
|
||||
expected_from = ep_idx * frames_per_episode
|
||||
expected_to = (ep_idx + 1) * frames_per_episode
|
||||
assert ep_metadata["dataset_from_index"] == expected_from
|
||||
assert ep_metadata["dataset_to_index"] == expected_to
|
||||
|
||||
Reference in New Issue
Block a user