forked from tangger/lerobot
improvements from JClinton to speed up loading offline data
This commit is contained in:
@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
# Step 1: Combine all unique indices
|
||||
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
|
||||
|
||||
# Step 2: Select all required data at once
|
||||
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
|
||||
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
|
||||
|
||||
# Step 3: Map original indices to their positions in the selected dataset
|
||||
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
|
||||
|
||||
# Step 4: Build the result for each key
|
||||
results = {}
|
||||
for key, q_indices in query_indices.items():
|
||||
if key not in self.meta.video_keys:
|
||||
mapped_indices = [index_map[idx] for idx in q_indices]
|
||||
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
|
||||
|
||||
return results
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
|
||||
Reference in New Issue
Block a user