improvements from JClinton to speed up loading offline data

This commit is contained in:
Ke-Wang1017
2025-01-06 09:50:08 +00:00
parent db3925df28
commit 8b70b129dc

View File

@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
return query_timestamps
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
return {
key: torch.stack(self.hf_dataset.select(q_idx)[key])
for key, q_idx in query_indices.items()
if key not in self.meta.video_keys
}
# Step 1: Combine all unique indices
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
# Step 2: Select all required data at once
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
# Step 3: Map original indices to their positions in the selected dataset
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
# Step 4: Build the result for each key
results = {}
for key, q_indices in query_indices.items():
if key not in self.meta.video_keys:
mapped_indices = [index_map[idx] for idx in q_indices]
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
return results
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function