Fix: check_cached_episodes doesn't check if the requested episode videos were downloaded (#2296)
* In `check_cached_episodes_sufficient`, check whether all the requested video files are downloaded
* optimize loop over the video paths
* revert example num_workers
* Apply suggestion from @Copilot
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
  Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
* set num_workers to zero in example
* style nit
* reintroduce copilot optim
---------
Signed-off-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -132,17 +132,15 @@ print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
|||||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||||
|
|
||||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
if __name__ == "__main__":
|
||||||
# PyTorch datasets.
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataset,
|
||||||
dataset,
|
num_workers=4,
|
||||||
num_workers=4,
|
batch_size=32,
|
||||||
batch_size=32,
|
shuffle=True,
|
||||||
shuffle=True,
|
)
|
||||||
)
|
for batch in dataloader:
|
||||||
|
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||||
for batch in dataloader:
|
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
break
|
||||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
|
||||||
break
|
|
||||||
|
|||||||
@@ -837,7 +837,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
def _check_cached_episodes_sufficient(self) -> bool:
|
def _check_cached_episodes_sufficient(self) -> bool:
|
||||||
"""Check if the cached dataset contains all requested episodes."""
|
"""Check if the cached dataset contains all requested episodes and their video files."""
|
||||||
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
if self.hf_dataset is None or len(self.hf_dataset) == 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -856,7 +856,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
requested_episodes = set(self.episodes)
|
requested_episodes = set(self.episodes)
|
||||||
|
|
||||||
# Check if all requested episodes are available in cached data
|
# Check if all requested episodes are available in cached data
|
||||||
return requested_episodes.issubset(available_episodes)
|
if not requested_episodes.issubset(available_episodes):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if all required video files exist
|
||||||
|
if len(self.meta.video_keys) > 0:
|
||||||
|
for ep_idx in requested_episodes:
|
||||||
|
for vid_key in self.meta.video_keys:
|
||||||
|
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||||
|
if not video_path.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def create_hf_dataset(self) -> datasets.Dataset:
|
def create_hf_dataset(self) -> datasets.Dataset:
|
||||||
features = get_hf_features_from_features(self.features)
|
features = get_hf_features_from_features(self.features)
|
||||||
|
|||||||
Reference in New Issue
Block a user