forked from tangger/lerobot
Add manage_dataset
This commit is contained in:
@@ -873,7 +873,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return video_paths
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
def consolidate(
|
||||
self,
|
||||
run_compute_stats: bool = True,
|
||||
keep_image_files: bool = False,
|
||||
batch_size: int = 8,
|
||||
num_workers: int = 8,
|
||||
) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
@@ -896,7 +902,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if run_compute_stats:
|
||||
self.stop_image_writer()
|
||||
# TODO(aliberts): refactor stats in save_episodes
|
||||
self.meta.stats = compute_stats(self)
|
||||
self.meta.stats = compute_stats(self, batch_size=batch_size, num_workers=num_workers)
|
||||
serialized_stats = serialize_dict(self.meta.stats)
|
||||
write_json(serialized_stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
|
||||
Reference in New Issue
Block a user