fix: making sure to stream within-episode frames

This commit is contained in:
fracapuano
2025-05-31 20:07:16 +02:00
parent 617bebb617
commit 3bf63e5518
2 changed files with 103 additions and 28 deletions

View File

@@ -164,6 +164,7 @@ def _profile_iteration(dataset, num_samples, stats_file_path):
profiler.add_function(dataset.__iter__)
profiler.add_function(dataset.make_frame)
profiler.add_function(dataset._make_backtrackable_dataset)
profiler.add_function(dataset._get_delta_frames)
# Profile the iteration

View File

@@ -222,43 +222,126 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
history, lookahead = self._get_window_steps()
return Backtrackable(dataset, history=history, lookahead=lookahead)
def _make_timestamps_from_indices(
self, start_ts: float, indices: dict[str, list[int]] | None = None
) -> dict[str, list[float]]:
if indices is not None:
return {
key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist()
for key in self.delta_timestamps
}
else:
return dict.fromkeys(self.delta_timestamps, start_ts)
def _make_padding_camera_frame(self, camera_key: str):
"""Variable-shape padding frame for given camera keys, given in (C, H, W)"""
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
def _pad_retrieved_video_frames(
self,
video_frames: dict[str, torch.Tensor],
query_timestamps: dict[str, list[float]],
original_timestamps: dict[str, list[float]],
) -> tuple[dict[str, torch.Tensor], dict[str, torch.BoolTensor]]:
padded_video_frames = {}
padding_mask = {}
for video_key, timestamps in original_timestamps.items():
if video_key not in video_frames:
continue # only padding on video keys that are available
frames = []
mask = []
padding_frame = self._make_padding_camera_frame(video_key)
for ts in timestamps:
if ts in query_timestamps[video_key]:
idx = query_timestamps[video_key].index(ts)
frames.append(video_frames[video_key][idx, :])
mask.append(False)
else:
frames.append(padding_frame)
mask.append(True)
padded_video_frames[video_key] = torch.stack(frames)
padding_mask[f"{video_key}.pad_masking"] = torch.BoolTensor(mask)
return padded_video_frames, padding_mask
@profile
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator)
item = item_to_torch(item)
updates = [] # list of updates to apply to the item
# Get episode index from the item
ep_idx = item["episode_index"]
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
current_ts = item["index"] / self.fps
episode_boundaries_ts = {
key: (
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
)
for key in self.meta.video_keys
}
# Apply delta querying logic if necessary
if self.delta_indices is not None:
query_result, padding = self._get_delta_frames(dataset_iterator, item)
item = {**item, **query_result, **padding}
updates.append(query_result)
updates.append(padding)
# Load video frames, when needed
if len(self.meta.video_keys) > 0:
current_ts = item["timestamp"]
query_indices = self.delta_indices if self.delta_indices is not None else None
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
# Some timestamps might not result available considering the episode's boundaries
query_timestamps = self._get_query_timestamps(
current_ts, self.delta_indices, episode_boundaries_ts
)
video_frames = self._query_videos(query_timestamps, ep_idx)
item = {**video_frames, **item}
item["task"] = self.meta.tasks.iloc[item["task_index"]].name
# We always return the same number of frames. Unavailable frames are padded.
padded_video_frames, padding_mask = self._pad_retrieved_video_frames(
video_frames, query_timestamps, original_timestamps
)
yield item
updates.append(video_frames)
updates.append(padded_video_frames)
updates.append(padding_mask)
result = item.copy()
for update in updates:
result.update(update)
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
yield result
def _get_query_timestamps(
self,
current_ts: float,
query_indices: dict[str, list[int]] | None = None,
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
) -> dict[str, list[float]]:
query_timestamps = {}
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
for key in self.meta.video_keys:
if query_indices is not None and key in query_indices:
timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
# never query for negative timestamps!
query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
timestamps = keys_to_timestamps[key]
# Filter out timesteps outside of episode boundaries
query_timestamps[key] = [
ts
for ts in timestamps
if episode_boundaries_ts[key][0] <= ts <= episode_boundaries_ts[key][1]
]
if len(query_timestamps[key]) == 0:
raise ValueError(f"No valid timestamps found for key {key} with {query_indices[key]}")
else:
query_timestamps[key] = [current_ts]
@@ -271,29 +354,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
the main process and a subprocess fails to access it.
"""
item = {}
for vid_key, query_ts in query_timestamps.items():
for video_key, query_ts in query_timestamps.items():
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
frames = decode_video_frames_torchcodec(
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
)
item[vid_key] = frames.squeeze(0)
item[video_key] = frames
return item
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
if dataset_iterator.can_go_ahead(n_steps):
return dataset_iterator.peek_ahead(n_steps)
else:
pass
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
if dataset_iterator.can_go_back(n_steps):
return dataset_iterator.peek_back(n_steps)
else:
pass
def _make_padding_frame(self, key: str) -> tuple[torch.Tensor, bool]:
return torch.zeros(self.meta.info["features"][key]["shape"]), True
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
# TODO(fracapuano): Modularize this function, refactor the code
"""Get frames with delta offsets using the backtrackable iterator.
Args:
@@ -419,7 +495,5 @@ if __name__ == "__main__":
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
for i, frame in tqdm(enumerate(dataset)):
print(frame)
if i > 1000: # only stream first 10 frames
break
for _i, _frame in tqdm(enumerate(dataset)):
pass