improvements from JClinton to speed up loading offline data
This commit is contained in:
@@ -611,11 +611,24 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
return query_timestamps
|
return query_timestamps
|
||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
# Step 1: Combine all unique indices
|
||||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
all_indices = sorted({idx for indices in query_indices.values() for idx in indices})
|
||||||
for key, q_idx in query_indices.items()
|
|
||||||
if key not in self.meta.video_keys
|
# Step 2: Select all required data at once
|
||||||
}
|
selected_dataset = self.hf_dataset.select(all_indices).to_dict()
|
||||||
|
selected_dataset = {key: torch.tensor(values) for key, values in selected_dataset.items()}
|
||||||
|
|
||||||
|
# Step 3: Map original indices to their positions in the selected dataset
|
||||||
|
index_map = {original_idx: i for i, original_idx in enumerate(all_indices)}
|
||||||
|
|
||||||
|
# Step 4: Build the result for each key
|
||||||
|
results = {}
|
||||||
|
for key, q_indices in query_indices.items():
|
||||||
|
if key not in self.meta.video_keys:
|
||||||
|
mapped_indices = [index_map[idx] for idx in q_indices]
|
||||||
|
results[key] = torch.stack([selected_dataset[key][i] for i in mapped_indices])
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||||
|
|||||||
Reference in New Issue
Block a user