diff --git a/lerobot/common/datasets/streaming_dataset.py b/lerobot/common/datasets/streaming_dataset.py
index 2de02eebd..df3a2c4ac 100644
--- a/lerobot/common/datasets/streaming_dataset.py
+++ b/lerobot/common/datasets/streaming_dataset.py
@@ -392,55 +392,75 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
             target_frames = []
             is_pad = []

-            # NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
+            # Collect frames keyed by delta, then reassemble them in the original delta_indices order for stacking
+            delta_results = {}
+
+            # Split deltas and sort each group by distance from the current frame (closest first)
+            negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True)  # [-1, -2, -3, ...]
+            positive_deltas = sorted([d for d in delta_indices if d > 0])  # [1, 2, 3, ...]
+            zero_deltas = [d for d in delta_indices if d == 0]
+
+            # Process zero deltas (current frame)
+            for delta in zero_deltas:
+                delta_results[delta] = (
+                    current_item[key],
+                    False,
+                )  # (frame, is_pad): the current frame is never padded
+
+            # Process negative deltas in order of increasing distance
+            lookback_failed = False
+            for delta in negative_deltas:
+                if lookback_failed:
+                    delta_results[delta] = self._make_padding_frame(key)
+                    continue
+
+                try:
+                    steps_back = abs(delta)
+                    if dataset_iterator.can_peek_back(steps_back):
+                        past_item = dataset_iterator.peek_back(steps_back)
+                        past_item = item_to_torch(past_item)
+
+                        if past_item["episode_index"] == current_episode_idx:
+                            delta_results[delta] = (past_item[key], False)
+                        else:
+                            raise LookBackError("Retrieved frame is from different episode!")
+                    else:
+                        raise LookBackError("Cannot go back further than the history buffer!")
+
+                except LookBackError:
+                    delta_results[delta] = self._make_padding_frame(key)
+                    lookback_failed = True  # All subsequent negative deltas will also fail
+
+            # Process positive deltas in order of increasing distance
+            lookahead_failed = False
+            for delta in positive_deltas:
+                if lookahead_failed:
+                    delta_results[delta] = self._make_padding_frame(key)
+                    continue
+
+                try:
+                    if dataset_iterator.can_peek_ahead(delta):
+                        future_item = dataset_iterator.peek_ahead(delta)
+                        future_item = item_to_torch(future_item)
+
+                        if future_item["episode_index"] == current_episode_idx:
+                            delta_results[delta] = (future_item[key], False)
+                        else:
+                            raise LookAheadError("Retrieved frame is from different episode!")
+                    else:
+                        raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
+
+                except LookAheadError:
+                    delta_results[delta] = self._make_padding_frame(key)
+                    lookahead_failed = True  # All subsequent positive deltas will also fail
+
+            # Reconstruct the original delta order for stacking
             for delta in delta_indices:
-                if delta == 0:
-                    # Current frame
-                    target_frames.append(current_item[key])
-                    is_pad.append(False)
+                frame, is_padded = delta_results[delta]

-                elif delta < 0:
-                    # Past frame. Use backtrackable iterator, looking back delta steps
-                    try:
-                        steps_back = abs(delta)
-                        if dataset_iterator.can_peek_back(steps_back):
-                            past_item = dataset_iterator.peek_back(steps_back)
-                            past_item = item_to_torch(past_item)
-
-                            # Check if it's from the same episode
-                            if past_item["episode_index"] == current_episode_idx:
-                                target_frames.append(past_item[key])
-                                is_pad.append(False)
-
-                            else:
-                                raise LookBackError("Retrieved frame is from different episode!")
-                        else:
-                            raise LookBackError("Cannot go back further than the history buffer!")
-
-                    except LookBackError:
-                        target_frames.append(torch.zeros_like(current_item[key]))
-                        is_pad.append(True)
-
-                elif delta > 0:
-                    # Future frame - read ahead from the iterator
-                    try:
-                        if dataset_iterator.can_peek_ahead(delta):
-                            future_item = dataset_iterator.peek_ahead(delta)
-                            future_item = item_to_torch(future_item)
-
-                            # Check if it's from the same episode
-                            if future_item["episode_index"] == current_episode_idx:
-                                target_frames.append(future_item[key])
-                                is_pad.append(False)
-
-                            else:
-                                raise LookAheadError("Retrieved frame is from different episode!")
-                        else:
-                            raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
-
-                    except LookAheadError:
-                        target_frames.append(torch.zeros_like(current_item[key]))
-                        is_pad.append(True)
+                # Collect the frame and its padding flag in the original order
+                target_frames.append(frame)
+                is_pad.append(is_padded)

             # Stack frames and add to results
             if target_frames:
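Note: the new code path delegates padding to a `self._make_padding_frame(key)` helper that lives outside this hunk. Below is a minimal sketch of the behaviour it is assumed to have, inferred from the fallback branches it replaces; the standalone function, its explicit `current_item` argument, and the way it recovers the frame shape are illustrative assumptions, not the actual helper's signature.

```python
import torch


def make_padding_frame(current_item: dict[str, torch.Tensor], key: str) -> tuple[torch.Tensor, bool]:
    """Zero-filled placeholder for a frame outside the episode or peek buffer.

    Mirrors the removed fallback branches, which appended
    torch.zeros_like(current_item[key]) together with is_pad=True.
    The (frame, is_pad) tuple matches the unpacking in the final loop.
    """
    return torch.zeros_like(current_item[key]), True
```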