diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 232558056..15cc5ead2 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset): return query_timestamps def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: - return { - key: torch.stack(self.hf_dataset.select(q_idx)[key]) - for key, q_idx in query_indices.items() - if key not in self.meta.video_keys - } + # Step 1: Combine all unique indices + all_indices = sorted({idx for indices in query_indices.values() for idx in indices}) + + # Step 2: Select all required data at once + selected_dataset = self.hf_dataset.select(all_indices).to_dict() + selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()} + + # Step 3: Map original indices to their positions in the selected dataset + index_map = {original_idx: i for i, original_idx in enumerate(all_indices)} + + # Step 4: Build the result for each key + results = {} + for key, q_indices in query_indices.items(): + if key not in self.meta.video_keys: + mapped_indices = [index_map[idx] for idx in q_indices] + results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices]) + + return results def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function