diff --git a/examples/dataset/load_lerobot_dataset.py b/examples/dataset/load_lerobot_dataset.py
index a96c170c..a6916981 100644
--- a/examples/dataset/load_lerobot_dataset.py
+++ b/examples/dataset/load_lerobot_dataset.py
@@ -132,17 +132,15 @@
 print(f"\n{dataset[0][camera_key].shape=}")  # (4, c, h, w)
 print(f"{dataset[0]['observation.state'].shape=}")  # (6, c)
 print(f"{dataset[0]['action'].shape=}\n")  # (64, c)
-# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
-# PyTorch datasets.
-dataloader = torch.utils.data.DataLoader(
-    dataset,
-    num_workers=4,
-    batch_size=32,
-    shuffle=True,
-)
-
-for batch in dataloader:
-    print(f"{batch[camera_key].shape=}")  # (32, 4, c, h, w)
-    print(f"{batch['observation.state'].shape=}")  # (32, 6, c)
-    print(f"{batch['action'].shape=}")  # (32, 64, c)
-    break
+if __name__ == "__main__":
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        num_workers=4,
+        batch_size=32,
+        shuffle=True,
+    )
+    for batch in dataloader:
+        print(f"{batch[camera_key].shape=}")  # (32, 4, c, h, w)
+        print(f"{batch['observation.state'].shape=}")  # (32, 6, c)
+        print(f"{batch['action'].shape=}")  # (32, 64, c)
+        break
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py
index 6258312a..a6840891 100644
--- a/src/lerobot/datasets/lerobot_dataset.py
+++ b/src/lerobot/datasets/lerobot_dataset.py
@@ -837,7 +837,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         return hf_dataset
 
     def _check_cached_episodes_sufficient(self) -> bool:
-        """Check if the cached dataset contains all requested episodes."""
+        """Check if the cached dataset contains all requested episodes and their video files."""
         if self.hf_dataset is None or len(self.hf_dataset) == 0:
             return False
 
@@ -856,7 +856,18 @@
         requested_episodes = set(self.episodes)
 
         # Check if all requested episodes are available in cached data
-        if not requested_episodes.issubset(available_episodes):
-        return requested_episodes.issubset(available_episodes)
+        if not requested_episodes.issubset(available_episodes):
+            return False
+
+        # Check if all required video files exist
+        if len(self.meta.video_keys) > 0:
+            for ep_idx in requested_episodes:
+                for vid_key in self.meta.video_keys:
+                    video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
+                    if not video_path.exists():
+                        return False
+
+        return True
 
     def create_hf_dataset(self) -> datasets.Dataset:
         features = get_hf_features_from_features(self.features)
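
Note on the first change: with num_workers=4 the DataLoader spawns worker processes, and on platforms that use the spawn start method (macOS, Windows) each worker re-imports the script, so DataLoader construction and iteration at module level can recurse or fail. Below is a minimal sketch of the guarded pattern, using a placeholder TensorDataset instead of the real LeRobotDataset; it is an illustration under that assumption, not part of the patch.

# Minimal sketch of the guarded multi-worker DataLoader pattern from the
# example script; the dataset here is a stand-in, not a LeRobotDataset.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 6))  # placeholder dataset

if __name__ == "__main__":
    # Worker processes re-import this module; the guard keeps them from
    # re-running the DataLoader construction and the loop themselves.
    dataloader = DataLoader(dataset, num_workers=4, batch_size=32, shuffle=True)
    for batch in dataloader:
        print(batch[0].shape)  # torch.Size([32, 6])
        break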
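
Note on the second change: the cache-sufficiency check now also verifies that every expected video file exists on disk before the cached data is treated as sufficient. The sketch below reproduces that check as a standalone helper; the function name, parameters, and the "videos/<key>/episode_XXXXXX.mp4" layout are illustrative stand-ins, not the library's actual API or storage layout.

# Standalone sketch of the video-file existence check added to
# _check_cached_episodes_sufficient. All names and the path layout below
# are hypothetical, chosen only to make the example self-contained.
from pathlib import Path


def cached_videos_exist(root: Path, episodes: set[int], video_keys: list[str]) -> bool:
    """Return False as soon as any expected video file is missing."""
    for ep_idx in episodes:
        for vid_key in video_keys:
            # Plays the role of meta.get_video_file_path(ep_idx, vid_key)
            # in the diff, with a made-up layout purely for illustration.
            video_path = root / "videos" / vid_key / f"episode_{ep_idx:06d}.mp4"
            if not video_path.exists():
                return False
    return True


print(cached_videos_exist(Path("/tmp/some_dataset"), {0, 1}, ["observation.images.top"]))
# -> False unless those files actually exist under that root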