forked from tangger/lerobot
fix: optimize delta-indices construction. No point in fetching after iteration failure in same direction
This commit is contained in:
@@ -392,55 +392,75 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
target_frames = []
|
target_frames = []
|
||||||
is_pad = []
|
is_pad = []
|
||||||
|
|
||||||
# NOTE(fracapuano): Optimize this. What's the point in checking all deltas after first error?
|
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
||||||
|
delta_results = {}
|
||||||
|
|
||||||
|
# Separate and sort deltas by difficulty (easier operations first)
|
||||||
|
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
||||||
|
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
||||||
|
zero_deltas = [d for d in delta_indices if d == 0]
|
||||||
|
|
||||||
|
# Process zero deltas (current frame)
|
||||||
|
for delta in zero_deltas:
|
||||||
|
delta_results[delta] = (
|
||||||
|
current_item[key],
|
||||||
|
False,
|
||||||
|
) # unsqueeze to add batch dimension for stacking
|
||||||
|
|
||||||
|
# Process negative deltas in order of increasing difficulty
|
||||||
|
lookback_failed = False
|
||||||
|
for delta in negative_deltas:
|
||||||
|
if lookback_failed:
|
||||||
|
delta_results[delta] = self._make_padding_frame(key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
steps_back = abs(delta)
|
||||||
|
if dataset_iterator.can_peek_back(steps_back):
|
||||||
|
past_item = dataset_iterator.peek_back(steps_back)
|
||||||
|
past_item = item_to_torch(past_item)
|
||||||
|
|
||||||
|
if past_item["episode_index"] == current_episode_idx:
|
||||||
|
delta_results[delta] = (past_item[key], False)
|
||||||
|
else:
|
||||||
|
raise LookBackError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookBackError("Cannot go back further than the history buffer!")
|
||||||
|
|
||||||
|
except LookBackError:
|
||||||
|
delta_results[delta] = self._make_padding_frame(key)
|
||||||
|
lookback_failed = True # All subsequent negative deltas will also fail
|
||||||
|
|
||||||
|
# Process positive deltas in order of increasing difficulty
|
||||||
|
lookahead_failed = False
|
||||||
|
for delta in positive_deltas:
|
||||||
|
if lookahead_failed:
|
||||||
|
delta_results[delta] = self._make_padding_frame(key)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if dataset_iterator.can_peek_ahead(delta):
|
||||||
|
future_item = dataset_iterator.peek_ahead(delta)
|
||||||
|
future_item = item_to_torch(future_item)
|
||||||
|
|
||||||
|
if future_item["episode_index"] == current_episode_idx:
|
||||||
|
delta_results[delta] = (future_item[key], False)
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Retrieved frame is from different episode!")
|
||||||
|
else:
|
||||||
|
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
||||||
|
|
||||||
|
except LookAheadError:
|
||||||
|
delta_results[delta] = self._make_padding_frame(key)
|
||||||
|
lookahead_failed = True # All subsequent positive deltas will also fail
|
||||||
|
|
||||||
|
# Reconstruct original order for stacking
|
||||||
for delta in delta_indices:
|
for delta in delta_indices:
|
||||||
if delta == 0:
|
frame, is_padded = delta_results[delta]
|
||||||
# Current frame
|
|
||||||
target_frames.append(current_item[key])
|
|
||||||
is_pad.append(False)
|
|
||||||
|
|
||||||
elif delta < 0:
|
# add batch dimension for stacking
|
||||||
# Past frame. Use backtrackable iterator, looking back delta steps
|
target_frames.append(frame) # frame.unsqueeze(0))
|
||||||
try:
|
is_pad.append(is_padded)
|
||||||
steps_back = abs(delta)
|
|
||||||
if dataset_iterator.can_peek_back(steps_back):
|
|
||||||
past_item = dataset_iterator.peek_back(steps_back)
|
|
||||||
past_item = item_to_torch(past_item)
|
|
||||||
|
|
||||||
# Check if it's from the same episode
|
|
||||||
if past_item["episode_index"] == current_episode_idx:
|
|
||||||
target_frames.append(past_item[key])
|
|
||||||
is_pad.append(False)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise LookBackError("Retrieved frame is from different episode!")
|
|
||||||
else:
|
|
||||||
raise LookBackError("Cannot go back further than the history buffer!")
|
|
||||||
|
|
||||||
except LookBackError:
|
|
||||||
target_frames.append(torch.zeros_like(current_item[key]))
|
|
||||||
is_pad.append(True)
|
|
||||||
|
|
||||||
elif delta > 0:
|
|
||||||
# Future frame - read ahead from the iterator
|
|
||||||
try:
|
|
||||||
if dataset_iterator.can_peek_ahead(delta):
|
|
||||||
future_item = dataset_iterator.peek_ahead(delta)
|
|
||||||
future_item = item_to_torch(future_item)
|
|
||||||
|
|
||||||
# Check if it's from the same episode
|
|
||||||
if future_item["episode_index"] == current_episode_idx:
|
|
||||||
target_frames.append(future_item[key])
|
|
||||||
is_pad.append(False)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise LookAheadError("Retrieved frame is from different episode!")
|
|
||||||
else:
|
|
||||||
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
|
||||||
|
|
||||||
except LookAheadError:
|
|
||||||
target_frames.append(torch.zeros_like(current_item[key]))
|
|
||||||
is_pad.append(True)
|
|
||||||
|
|
||||||
# Stack frames and add to results
|
# Stack frames and add to results
|
||||||
if target_frames:
|
if target_frames:
|
||||||
|
|||||||
Reference in New Issue
Block a user