forked from tangger/lerobot
fix: making sure to stream within-episode frames
This commit is contained in:
@@ -164,6 +164,7 @@ def _profile_iteration(dataset, num_samples, stats_file_path):
|
||||
profiler.add_function(dataset.__iter__)
|
||||
profiler.add_function(dataset.make_frame)
|
||||
profiler.add_function(dataset._make_backtrackable_dataset)
|
||||
profiler.add_function(dataset._get_delta_frames)
|
||||
|
||||
# Profile the iteration
|
||||
|
||||
|
||||
@@ -222,43 +222,126 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
history, lookahead = self._get_window_steps()
|
||||
return Backtrackable(dataset, history=history, lookahead=lookahead)
|
||||
|
||||
def _make_timestamps_from_indices(
|
||||
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||
) -> dict[str, list[float]]:
|
||||
if indices is not None:
|
||||
return {
|
||||
key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist()
|
||||
for key in self.delta_timestamps
|
||||
}
|
||||
else:
|
||||
return dict.fromkeys(self.delta_timestamps, start_ts)
|
||||
|
||||
def _make_padding_camera_frame(self, camera_key: str):
|
||||
"""Variable-shape padding frame for given camera keys, given in (C, H, W)"""
|
||||
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||
|
||||
def _pad_retrieved_video_frames(
|
||||
self,
|
||||
video_frames: dict[str, torch.Tensor],
|
||||
query_timestamps: dict[str, list[float]],
|
||||
original_timestamps: dict[str, list[float]],
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, torch.BoolTensor]]:
|
||||
padded_video_frames = {}
|
||||
padding_mask = {}
|
||||
|
||||
for video_key, timestamps in original_timestamps.items():
|
||||
if video_key not in video_frames:
|
||||
continue # only padding on video keys that are available
|
||||
frames = []
|
||||
mask = []
|
||||
padding_frame = self._make_padding_camera_frame(video_key)
|
||||
for ts in timestamps:
|
||||
if ts in query_timestamps[video_key]:
|
||||
idx = query_timestamps[video_key].index(ts)
|
||||
frames.append(video_frames[video_key][idx, :])
|
||||
mask.append(False)
|
||||
else:
|
||||
frames.append(padding_frame)
|
||||
mask.append(True)
|
||||
|
||||
padded_video_frames[video_key] = torch.stack(frames)
|
||||
padding_mask[f"{video_key}.pad_masking"] = torch.BoolTensor(mask)
|
||||
|
||||
return padded_video_frames, padding_mask
|
||||
|
||||
@profile
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
    """Makes a frame starting from a dataset iterator.

    Merges the raw item with delta-frame query results, decoded video
    frames, padded frames plus their padding masks, and the task string,
    then yields exactly ONE finished frame dict.

    NOTE(review): the original span contained interleaved pre- and
    post-refactor diff lines (two yields, duplicate current_ts /
    query_timestamps / task assignments); this body keeps only the
    coherent post-image.
    """
    item = next(dataset_iterator)
    item = item_to_torch(item)

    updates = []  # list of dict updates to merge into the yielded frame

    # Get episode index from the item
    ep_idx = item["episode_index"]

    # "timestamp" restarts from 0 for each episode, whereas we need a global
    # timestep within the single .mp4 file (given by index/fps)
    current_ts = item["index"] / self.fps

    # Per-camera (from, to) timestamp boundaries of the current episode, used
    # to avoid querying frames that belong to a neighbouring episode.
    episode_boundaries_ts = {
        key: (
            self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
            self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
        )
        for key in self.meta.video_keys
    }

    # Apply delta querying logic if necessary
    if self.delta_indices is not None:
        query_result, padding = self._get_delta_frames(dataset_iterator, item)
        updates.append(query_result)
        updates.append(padding)

    # Load video frames, when needed
    if len(self.meta.video_keys) > 0:
        # Full set of requested timestamps (before boundary filtering)...
        original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
        # ...and the subset actually available within the episode's boundaries.
        query_timestamps = self._get_query_timestamps(
            current_ts, self.delta_indices, episode_boundaries_ts
        )
        video_frames = self._query_videos(query_timestamps, ep_idx)

        # We always return the same number of frames. Unavailable frames are padded.
        padded_video_frames, padding_mask = self._pad_retrieved_video_frames(
            video_frames, query_timestamps, original_timestamps
        )
        updates.append(video_frames)
        updates.append(padded_video_frames)
        updates.append(padding_mask)

    # Merge all updates into a copy so the raw item is left untouched.
    result = item.copy()
    for update in updates:
        result.update(update)

    result["task"] = self.meta.tasks.iloc[item["task_index"]].name

    yield result
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
|
||||
# never query for negative timestamps!
|
||||
query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
|
||||
timestamps = keys_to_timestamps[key]
|
||||
# Filter out timesteps outside of episode boundaries
|
||||
query_timestamps[key] = [
|
||||
ts
|
||||
for ts in timestamps
|
||||
if episode_boundaries_ts[key][0] <= ts <= episode_boundaries_ts[key][1]
|
||||
]
|
||||
|
||||
if len(query_timestamps[key]) == 0:
|
||||
raise ValueError(f"No valid timestamps found for key {key} with {query_indices[key]}")
|
||||
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
@@ -271,29 +354,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
for video_key, query_ts in query_timestamps.items():
|
||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
|
||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||
frames = decode_video_frames_torchcodec(
|
||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
item[video_key] = frames
|
||||
|
||||
return item
|
||||
|
||||
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
||||
if dataset_iterator.can_go_ahead(n_steps):
|
||||
return dataset_iterator.peek_ahead(n_steps)
|
||||
else:
|
||||
pass
|
||||
|
||||
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
||||
if dataset_iterator.can_go_back(n_steps):
|
||||
return dataset_iterator.peek_back(n_steps)
|
||||
else:
|
||||
pass
|
||||
def _make_padding_frame(self, key: str) -> tuple[torch.Tensor, bool]:
|
||||
return torch.zeros(self.meta.info["features"][key]["shape"]), True
|
||||
|
||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||
# TODO(fracapuano): Modularize this function, refactor the code
|
||||
"""Get frames with delta offsets using the backtrackable iterator.
|
||||
|
||||
Args:
|
||||
@@ -419,7 +495,5 @@ if __name__ == "__main__":
|
||||
|
||||
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
for i, frame in tqdm(enumerate(dataset)):
|
||||
print(frame)
|
||||
if i > 1000: # only stream first 10 frames
|
||||
break
|
||||
for _i, _frame in tqdm(enumerate(dataset)):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user