From 3bf63e55184d46c3198c2bb0be8cd61e0f9f6ae2 Mon Sep 17 00:00:00 2001 From: fracapuano Date: Sat, 31 May 2025 20:07:16 +0200 Subject: [PATCH] fix: making sure to stream within-episode frames --- .../datasets/profile_streaming_dataset.py | 1 + lerobot/common/datasets/streaming_dataset.py | 130 ++++++++++++++---- 2 files changed, 103 insertions(+), 28 deletions(-) diff --git a/lerobot/common/datasets/profile_streaming_dataset.py b/lerobot/common/datasets/profile_streaming_dataset.py index a9d02c4c9..7ac6e826b 100644 --- a/lerobot/common/datasets/profile_streaming_dataset.py +++ b/lerobot/common/datasets/profile_streaming_dataset.py @@ -164,6 +164,7 @@ def _profile_iteration(dataset, num_samples, stats_file_path): profiler.add_function(dataset.__iter__) profiler.add_function(dataset.make_frame) profiler.add_function(dataset._make_backtrackable_dataset) + profiler.add_function(dataset._get_delta_frames) # Profile the iteration diff --git a/lerobot/common/datasets/streaming_dataset.py b/lerobot/common/datasets/streaming_dataset.py index 743579130..2de02eebd 100644 --- a/lerobot/common/datasets/streaming_dataset.py +++ b/lerobot/common/datasets/streaming_dataset.py @@ -222,43 +222,126 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): history, lookahead = self._get_window_steps() return Backtrackable(dataset, history=history, lookahead=lookahead) + def _make_timestamps_from_indices( + self, start_ts: float, indices: dict[str, list[int]] | None = None + ) -> dict[str, list[float]]: + if indices is not None: + return { + key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist() + for key in self.delta_timestamps + } + else: + return dict.fromkeys(self.delta_timestamps, start_ts) + + def _make_padding_camera_frame(self, camera_key: str): + """Variable-shape padding frame for given camera keys, given in (C, H, W)""" + return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1) + + def _pad_retrieved_video_frames( + self, + 
video_frames: dict[str, torch.Tensor], + query_timestamps: dict[str, list[float]], + original_timestamps: dict[str, list[float]], + ) -> tuple[dict[str, torch.Tensor], dict[str, torch.BoolTensor]]: + padded_video_frames = {} + padding_mask = {} + + for video_key, timestamps in original_timestamps.items(): + if video_key not in video_frames: + continue # only padding on video keys that are available + frames = [] + mask = [] + padding_frame = self._make_padding_camera_frame(video_key) + for ts in timestamps: + if ts in query_timestamps[video_key]: + idx = query_timestamps[video_key].index(ts) + frames.append(video_frames[video_key][idx, :]) + mask.append(False) + else: + frames.append(padding_frame) + mask.append(True) + + padded_video_frames[video_key] = torch.stack(frames) + padding_mask[f"{video_key}.pad_masking"] = torch.BoolTensor(mask) + + return padded_video_frames, padding_mask + @profile def make_frame(self, dataset_iterator: Backtrackable) -> Generator: """Makes a frame starting from a dataset iterator""" item = next(dataset_iterator) item = item_to_torch(item) + updates = [] # list of updates to apply to the item + # Get episode index from the item ep_idx = item["episode_index"] + # "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps) + current_ts = item["index"] / self.fps + + episode_boundaries_ts = { + key: ( + self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"], + self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"], + ) + for key in self.meta.video_keys + } + # Apply delta querying logic if necessary if self.delta_indices is not None: query_result, padding = self._get_delta_frames(dataset_iterator, item) - item = {**item, **query_result, **padding} + updates.append(query_result) + updates.append(padding) # Load video frames, when needed if len(self.meta.video_keys) > 0: - current_ts = item["timestamp"] - query_indices = self.delta_indices if self.delta_indices is 
not None else None - query_timestamps = self._get_query_timestamps(current_ts, query_indices) + original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices) + + # Some timestamps might not be available given the episode's boundaries + query_timestamps = self._get_query_timestamps( + current_ts, self.delta_indices, episode_boundaries_ts + ) video_frames = self._query_videos(query_timestamps, ep_idx) - item = {**video_frames, **item} - item["task"] = self.meta.tasks.iloc[item["task_index"]].name + # We always return the same number of frames. Unavailable frames are padded. + padded_video_frames, padding_mask = self._pad_retrieved_video_frames( + video_frames, query_timestamps, original_timestamps + ) - yield item + updates.append(video_frames) + updates.append(padded_video_frames) + updates.append(padding_mask) + + result = item.copy() + for update in updates: + result.update(update) + + result["task"] = self.meta.tasks.iloc[item["task_index"]].name + + yield result def _get_query_timestamps( self, current_ts: float, query_indices: dict[str, list[int]] | None = None, + episode_boundaries_ts: dict[str, tuple[float, float]] | None = None, ) -> dict[str, list[float]]: query_timestamps = {} + keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices) for key in self.meta.video_keys: if query_indices is not None and key in query_indices: - timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps - # never query for negative timestamps!
- query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist())) + timestamps = keys_to_timestamps[key] + # Filter out timesteps outside of episode boundaries + query_timestamps[key] = [ + ts + for ts in timestamps + if episode_boundaries_ts[key][0] <= ts <= episode_boundaries_ts[key][1] + ] + + if len(query_timestamps[key]) == 0: + raise ValueError(f"No valid timestamps found for key {key} with {query_indices[key]}") + else: query_timestamps[key] = [current_ts] @@ -271,29 +354,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset): the main process and a subprocess fails to access it. """ item = {} - for vid_key, query_ts in query_timestamps.items(): + for video_key, query_ts in query_timestamps.items(): root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root - video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}" + video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}" frames = decode_video_frames_torchcodec( video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache ) - item[vid_key] = frames.squeeze(0) + + item[video_key] = frames return item - def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int): - if dataset_iterator.can_go_ahead(n_steps): - return dataset_iterator.peek_ahead(n_steps) - else: - pass - - def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int): - if dataset_iterator.can_go_back(n_steps): - return dataset_iterator.peek_back(n_steps) - else: - pass + def _make_padding_frame(self, key: str) -> tuple[torch.Tensor, bool]: + return torch.zeros(self.meta.info["features"][key]["shape"]), True def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict): + # TODO(fracapuano): Modularize this function, refactor the code """Get frames with delta offsets using the backtrackable iterator. 
Args: @@ -419,7 +495,5 @@ if __name__ == "__main__": dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps) - for i, frame in tqdm(enumerate(dataset)): - print(frame) - if i > 1000: # only stream first 10 frames - break + for _i, _frame in tqdm(enumerate(dataset)): + pass