Fix: check_cached_episodes doesn't check if the requested episode video were downloaded (#2296)

* In `check_cached_episodes_sufficient` check whether all the requested video files are downloaded

* optimize loop over the video paths

* revert example num_workers

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>

* set num_workers to zero in example

* style nit

* reintroduce copilot optim

---------

Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Michel Aractingi
2025-10-23 17:34:03 +02:00
committed by GitHub
parent df71f3ce24
commit 76a425c600
2 changed files with 25 additions and 16 deletions

View File

@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
print(f"{dataset[0]['observation.state'].shape=}") # (6, c) print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
print(f"{dataset[0]['action'].shape=}\n") # (64, c) print(f"{dataset[0]['action'].shape=}\n") # (64, c)
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just if __name__ == "__main__":
# PyTorch datasets. dataloader = torch.utils.data.DataLoader(
dataloader = torch.utils.data.DataLoader( dataset,
dataset, num_workers=4,
num_workers=4, batch_size=32,
batch_size=32, shuffle=True,
shuffle=True, )
) for batch in dataloader:
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
for batch in dataloader: print(f"{batch['observation.state'].shape=}") # (32, 6, c)
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w) print(f"{batch['action'].shape=}") # (32, 64, c)
print(f"{batch['observation.state'].shape=}") # (32, 6, c) break
print(f"{batch['action'].shape=}") # (32, 64, c)
break

View File

@@ -837,7 +837,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
return hf_dataset return hf_dataset
def _check_cached_episodes_sufficient(self) -> bool: def _check_cached_episodes_sufficient(self) -> bool:
"""Check if the cached dataset contains all requested episodes.""" """Check if the cached dataset contains all requested episodes and their video files."""
if self.hf_dataset is None or len(self.hf_dataset) == 0: if self.hf_dataset is None or len(self.hf_dataset) == 0:
return False return False
@@ -856,7 +856,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
requested_episodes = set(self.episodes) requested_episodes = set(self.episodes)
# Check if all requested episodes are available in cached data # Check if all requested episodes are available in cached data
return requested_episodes.issubset(available_episodes) if not requested_episodes.issubset(available_episodes):
return False
# Check if all required video files exist
if len(self.meta.video_keys) > 0:
for ep_idx in requested_episodes:
for vid_key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
if not video_path.exists():
return False
return True
def create_hf_dataset(self) -> datasets.Dataset: def create_hf_dataset(self) -> datasets.Dataset:
features = get_hf_features_from_features(self.features) features = get_hf_features_from_features(self.features)