forked from tangger/lerobot
fix: optimize delta-indices construction. No point in fetching after iteration failure in same direction
This commit is contained in:
@@ -392,55 +392,75 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
||||
target_frames = []
|
||||
is_pad = []
|
||||
|
||||
# NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
|
||||
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||
delta_results = {}
|
||||
|
||||
# Separate and sort deltas by difficulty (easier operations first)
|
||||
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||
zero_deltas = [d for d in delta_indices if d == 0]
|
||||
|
||||
# Process zero deltas (current frame)
|
||||
for delta in zero_deltas:
|
||||
delta_results[delta] = (
|
||||
current_item[key],
|
||||
False,
|
||||
) # unsqueeze to add batch dimension for stacking
|
||||
|
||||
# Process negative deltas in order of increasing difficulty
|
||||
lookback_failed = False
|
||||
for delta in negative_deltas:
|
||||
if lookback_failed:
|
||||
delta_results[delta] = self._make_padding_frame(key)
|
||||
continue
|
||||
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (past_item[key], False)
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
delta_results[delta] = self._make_padding_frame(key)
|
||||
lookback_failed = True # All subsequent negative deltas will also fail
|
||||
|
||||
# Process positive deltas in order of increasing difficulty
|
||||
lookahead_failed = False
|
||||
for delta in positive_deltas:
|
||||
if lookahead_failed:
|
||||
delta_results[delta] = self._make_padding_frame(key)
|
||||
continue
|
||||
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
delta_results[delta] = (future_item[key], False)
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
delta_results[delta] = self._make_padding_frame(key)
|
||||
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||
|
||||
# Reconstruct original order for stacking
|
||||
for delta in delta_indices:
|
||||
if delta == 0:
|
||||
# Current frame
|
||||
target_frames.append(current_item[key])
|
||||
is_pad.append(False)
|
||||
frame, is_padded = delta_results[delta]
|
||||
|
||||
elif delta < 0:
|
||||
# Past frame. Use backtrackable iterator, looking back delta steps
|
||||
try:
|
||||
steps_back = abs(delta)
|
||||
if dataset_iterator.can_peek_back(steps_back):
|
||||
past_item = dataset_iterator.peek_back(steps_back)
|
||||
past_item = item_to_torch(past_item)
|
||||
|
||||
# Check if it's from the same episode
|
||||
if past_item["episode_index"] == current_episode_idx:
|
||||
target_frames.append(past_item[key])
|
||||
is_pad.append(False)
|
||||
|
||||
else:
|
||||
raise LookBackError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookBackError("Cannot go back further than the history buffer!")
|
||||
|
||||
except LookBackError:
|
||||
target_frames.append(torch.zeros_like(current_item[key]))
|
||||
is_pad.append(True)
|
||||
|
||||
elif delta > 0:
|
||||
# Future frame - read ahead from the iterator
|
||||
try:
|
||||
if dataset_iterator.can_peek_ahead(delta):
|
||||
future_item = dataset_iterator.peek_ahead(delta)
|
||||
future_item = item_to_torch(future_item)
|
||||
|
||||
# Check if it's from the same episode
|
||||
if future_item["episode_index"] == current_episode_idx:
|
||||
target_frames.append(future_item[key])
|
||||
is_pad.append(False)
|
||||
|
||||
else:
|
||||
raise LookAheadError("Retrieved frame is from different episode!")
|
||||
else:
|
||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||
|
||||
except LookAheadError:
|
||||
target_frames.append(torch.zeros_like(current_item[key]))
|
||||
is_pad.append(True)
|
||||
# add batch dimension for stacking
|
||||
target_frames.append(frame) # frame.unsqueeze(0))
|
||||
is_pad.append(is_padded)
|
||||
|
||||
# Stack frames and add to results
|
||||
if target_frames:
|
||||
|
||||
Reference in New Issue
Block a user