Speedup data loading
This commit is contained in:
@@ -736,7 +736,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||||
else:
|
else:
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
@@ -745,7 +745,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||||
return {
|
return {
|
||||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||||
for key, q_idx in query_indices.items()
|
for key, q_idx in query_indices.items()
|
||||||
if key not in self.meta.video_keys
|
if key not in self.meta.video_keys
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user