fix: making sure to stream within-episode frames
This commit is contained in:
@@ -164,6 +164,7 @@ def _profile_iteration(dataset, num_samples, stats_file_path):
|
|||||||
profiler.add_function(dataset.__iter__)
|
profiler.add_function(dataset.__iter__)
|
||||||
profiler.add_function(dataset.make_frame)
|
profiler.add_function(dataset.make_frame)
|
||||||
profiler.add_function(dataset._make_backtrackable_dataset)
|
profiler.add_function(dataset._make_backtrackable_dataset)
|
||||||
|
profiler.add_function(dataset._get_delta_frames)
|
||||||
|
|
||||||
# Profile the iteration
|
# Profile the iteration
|
||||||
|
|
||||||
|
|||||||
@@ -222,43 +222,126 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
history, lookahead = self._get_window_steps()
|
history, lookahead = self._get_window_steps()
|
||||||
return Backtrackable(dataset, history=history, lookahead=lookahead)
|
return Backtrackable(dataset, history=history, lookahead=lookahead)
|
||||||
|
|
||||||
|
def _make_timestamps_from_indices(
|
||||||
|
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
||||||
|
) -> dict[str, list[float]]:
|
||||||
|
if indices is not None:
|
||||||
|
return {
|
||||||
|
key: (start_ts + torch.tensor(indices[key]) / self.fps).tolist()
|
||||||
|
for key in self.delta_timestamps
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return dict.fromkeys(self.delta_timestamps, start_ts)
|
||||||
|
|
||||||
|
def _make_padding_camera_frame(self, camera_key: str):
|
||||||
|
"""Variable-shape padding frame for given camera keys, given in (C, H, W)"""
|
||||||
|
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
||||||
|
|
||||||
|
def _pad_retrieved_video_frames(
|
||||||
|
self,
|
||||||
|
video_frames: dict[str, torch.Tensor],
|
||||||
|
query_timestamps: dict[str, list[float]],
|
||||||
|
original_timestamps: dict[str, list[float]],
|
||||||
|
) -> tuple[dict[str, torch.Tensor], dict[str, torch.BoolTensor]]:
|
||||||
|
padded_video_frames = {}
|
||||||
|
padding_mask = {}
|
||||||
|
|
||||||
|
for video_key, timestamps in original_timestamps.items():
|
||||||
|
if video_key not in video_frames:
|
||||||
|
continue # only padding on video keys that are available
|
||||||
|
frames = []
|
||||||
|
mask = []
|
||||||
|
padding_frame = self._make_padding_camera_frame(video_key)
|
||||||
|
for ts in timestamps:
|
||||||
|
if ts in query_timestamps[video_key]:
|
||||||
|
idx = query_timestamps[video_key].index(ts)
|
||||||
|
frames.append(video_frames[video_key][idx, :])
|
||||||
|
mask.append(False)
|
||||||
|
else:
|
||||||
|
frames.append(padding_frame)
|
||||||
|
mask.append(True)
|
||||||
|
|
||||||
|
padded_video_frames[video_key] = torch.stack(frames)
|
||||||
|
padding_mask[f"{video_key}.pad_masking"] = torch.BoolTensor(mask)
|
||||||
|
|
||||||
|
return padded_video_frames, padding_mask
|
||||||
|
|
||||||
@profile
|
@profile
|
||||||
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
||||||
"""Makes a frame starting from a dataset iterator"""
|
"""Makes a frame starting from a dataset iterator"""
|
||||||
item = next(dataset_iterator)
|
item = next(dataset_iterator)
|
||||||
item = item_to_torch(item)
|
item = item_to_torch(item)
|
||||||
|
|
||||||
|
updates = [] # list of updates to apply to the item
|
||||||
|
|
||||||
# Get episode index from the item
|
# Get episode index from the item
|
||||||
ep_idx = item["episode_index"]
|
ep_idx = item["episode_index"]
|
||||||
|
|
||||||
|
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
||||||
|
current_ts = item["index"] / self.fps
|
||||||
|
|
||||||
|
episode_boundaries_ts = {
|
||||||
|
key: (
|
||||||
|
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
||||||
|
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
||||||
|
)
|
||||||
|
for key in self.meta.video_keys
|
||||||
|
}
|
||||||
|
|
||||||
# Apply delta querying logic if necessary
|
# Apply delta querying logic if necessary
|
||||||
if self.delta_indices is not None:
|
if self.delta_indices is not None:
|
||||||
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
||||||
item = {**item, **query_result, **padding}
|
updates.append(query_result)
|
||||||
|
updates.append(padding)
|
||||||
|
|
||||||
# Load video frames, when needed
|
# Load video frames, when needed
|
||||||
if len(self.meta.video_keys) > 0:
|
if len(self.meta.video_keys) > 0:
|
||||||
current_ts = item["timestamp"]
|
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
||||||
query_indices = self.delta_indices if self.delta_indices is not None else None
|
|
||||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
# Some timestamps might not be available, given the episode's boundaries
|
||||||
|
query_timestamps = self._get_query_timestamps(
|
||||||
|
current_ts, self.delta_indices, episode_boundaries_ts
|
||||||
|
)
|
||||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||||
item = {**video_frames, **item}
|
|
||||||
|
|
||||||
item["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
# We always return the same number of frames. Unavailable frames are padded.
|
||||||
|
padded_video_frames, padding_mask = self._pad_retrieved_video_frames(
|
||||||
|
video_frames, query_timestamps, original_timestamps
|
||||||
|
)
|
||||||
|
|
||||||
yield item
|
updates.append(video_frames)
|
||||||
|
updates.append(padded_video_frames)
|
||||||
|
updates.append(padding_mask)
|
||||||
|
|
||||||
|
result = item.copy()
|
||||||
|
for update in updates:
|
||||||
|
result.update(update)
|
||||||
|
|
||||||
|
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
||||||
|
|
||||||
|
yield result
|
||||||
|
|
||||||
def _get_query_timestamps(
|
def _get_query_timestamps(
|
||||||
self,
|
self,
|
||||||
current_ts: float,
|
current_ts: float,
|
||||||
query_indices: dict[str, list[int]] | None = None,
|
query_indices: dict[str, list[int]] | None = None,
|
||||||
|
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
||||||
) -> dict[str, list[float]]:
|
) -> dict[str, list[float]]:
|
||||||
query_timestamps = {}
|
query_timestamps = {}
|
||||||
|
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
||||||
for key in self.meta.video_keys:
|
for key in self.meta.video_keys:
|
||||||
if query_indices is not None and key in query_indices:
|
if query_indices is not None and key in query_indices:
|
||||||
timestamps = current_ts + torch.tensor(query_indices[key]) / self.fps
|
timestamps = keys_to_timestamps[key]
|
||||||
# never query for negative timestamps!
|
# Filter out timesteps outside of episode boundaries
|
||||||
query_timestamps[key] = list(filter(lambda x: x >= 0, timestamps.tolist()))
|
query_timestamps[key] = [
|
||||||
|
ts
|
||||||
|
for ts in timestamps
|
||||||
|
if episode_boundaries_ts[key][0] <= ts <= episode_boundaries_ts[key][1]
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(query_timestamps[key]) == 0:
|
||||||
|
raise ValueError(f"No valid timestamps found for key {key} with {query_indices[key]}")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query_timestamps[key] = [current_ts]
|
query_timestamps[key] = [current_ts]
|
||||||
|
|
||||||
@@ -271,29 +354,22 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
the main process and a subprocess fails to access it.
|
the main process and a subprocess fails to access it.
|
||||||
"""
|
"""
|
||||||
item = {}
|
item = {}
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
for video_key, query_ts in query_timestamps.items():
|
||||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, vid_key)}"
|
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||||
frames = decode_video_frames_torchcodec(
|
frames = decode_video_frames_torchcodec(
|
||||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
||||||
)
|
)
|
||||||
item[vid_key] = frames.squeeze(0)
|
|
||||||
|
item[video_key] = frames
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
def _get_future_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
def _make_padding_frame(self, key: str) -> tuple[torch.Tensor, bool]:
|
||||||
if dataset_iterator.can_go_ahead(n_steps):
|
return torch.zeros(self.meta.info["features"][key]["shape"]), True
|
||||||
return dataset_iterator.peek_ahead(n_steps)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_previous_frame(self, dataset_iterator: Backtrackable, n_steps: int):
|
|
||||||
if dataset_iterator.can_go_back(n_steps):
|
|
||||||
return dataset_iterator.peek_back(n_steps)
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
||||||
|
# TODO(fracapuano): Modularize this function, refactor the code
|
||||||
"""Get frames with delta offsets using the backtrackable iterator.
|
"""Get frames with delta offsets using the backtrackable iterator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -419,7 +495,5 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
dataset = StreamingLeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||||
|
|
||||||
for i, frame in tqdm(enumerate(dataset)):
|
for _i, _frame in tqdm(enumerate(dataset)):
|
||||||
print(frame)
|
pass
|
||||||
if i > 1000: # only stream first 10 frames
|
|
||||||
break
|
|
||||||
|
|||||||
Reference in New Issue
Block a user