fix: optimize delta-indices construction. Once a lookup fails in one direction (back or ahead), pad the remaining deltas in that direction instead of fetching them

This commit is contained in:
fracapuano
2025-05-31 20:08:16 +02:00
parent 3bf63e5518
commit b54be4c23a

View File

@@ -392,55 +392,75 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
target_frames = []
is_pad = []
# NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
delta_results = {}
# Split deltas by sign and order each group by distance from the current frame (nearest first)
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
zero_deltas = [d for d in delta_indices if d == 0]
# Process zero deltas (current frame)
for delta in zero_deltas:
delta_results[delta] = (
current_item[key],
False,
) # unsqueeze to add batch dimension for stacking
# Process negative deltas in order of increasing difficulty
lookback_failed = False
for delta in negative_deltas:
if lookback_failed:
delta_results[delta] = self._make_padding_frame(key)
continue
try:
steps_back = abs(delta)
if dataset_iterator.can_peek_back(steps_back):
past_item = dataset_iterator.peek_back(steps_back)
past_item = item_to_torch(past_item)
if past_item["episode_index"] == current_episode_idx:
delta_results[delta] = (past_item[key], False)
else:
raise LookBackError("Retrieved frame is from different episode!")
else:
raise LookBackError("Cannot go back further than the history buffer!")
except LookBackError:
delta_results[delta] = self._make_padding_frame(key)
lookback_failed = True # All subsequent negative deltas will also fail
# Process positive deltas in order of increasing difficulty
lookahead_failed = False
for delta in positive_deltas:
if lookahead_failed:
delta_results[delta] = self._make_padding_frame(key)
continue
try:
if dataset_iterator.can_peek_ahead(delta):
future_item = dataset_iterator.peek_ahead(delta)
future_item = item_to_torch(future_item)
if future_item["episode_index"] == current_episode_idx:
delta_results[delta] = (future_item[key], False)
else:
raise LookAheadError("Retrieved frame is from different episode!")
else:
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
except LookAheadError:
delta_results[delta] = self._make_padding_frame(key)
lookahead_failed = True # All subsequent positive deltas will also fail
# Reconstruct original order for stacking
for delta in delta_indices:
if delta == 0:
# Current frame
target_frames.append(current_item[key])
is_pad.append(False)
frame, is_padded = delta_results[delta]
elif delta < 0:
# Past frame. Use backtrackable iterator, looking back delta steps
try:
steps_back = abs(delta)
if dataset_iterator.can_peek_back(steps_back):
past_item = dataset_iterator.peek_back(steps_back)
past_item = item_to_torch(past_item)
# Check if it's from the same episode
if past_item["episode_index"] == current_episode_idx:
target_frames.append(past_item[key])
is_pad.append(False)
else:
raise LookBackError("Retrieved frame is from different episode!")
else:
raise LookBackError("Cannot go back further than the history buffer!")
except LookBackError:
target_frames.append(torch.zeros_like(current_item[key]))
is_pad.append(True)
elif delta > 0:
# Future frame - read ahead from the iterator
try:
if dataset_iterator.can_peek_ahead(delta):
future_item = dataset_iterator.peek_ahead(delta)
future_item = item_to_torch(future_item)
# Check if it's from the same episode
if future_item["episode_index"] == current_episode_idx:
target_frames.append(future_item[key])
is_pad.append(False)
else:
raise LookAheadError("Retrieved frame is from different episode!")
else:
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
except LookAheadError:
target_frames.append(torch.zeros_like(current_item[key]))
is_pad.append(True)
# Append the frame as-is; no batch dimension is added here (the unsqueeze below is disabled)
target_frames.append(frame) # frame.unsqueeze(0))
is_pad.append(is_padded)
# Stack frames and add to results
if target_frames: