From 3af6768f37d8f33ce72c7717efb1c6a953efaaa9 Mon Sep 17 00:00:00 2001 From: fracapuano Date: Thu, 29 May 2025 23:14:18 +0200 Subject: [PATCH] add: next() & prev() iterable for streaming dataset with two-sided memory buffers --- lerobot/common/datasets/utils.py | 181 ++++++++++++++++++++++++++++++- 1 file changed, 179 insertions(+), 2 deletions(-) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 96eeebc56..7c419893d 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -20,11 +20,12 @@ import logging import shutil import subprocess import tempfile -from collections.abc import Iterator +from collections import deque +from collections.abc import Iterable, Iterator from pathlib import Path from pprint import pformat from types import SimpleNamespace -from typing import Any, TypeVar +from typing import Any, Deque, Generic, List, TypeVar import datasets import numpy as np @@ -921,3 +922,179 @@ def item_to_torch(item: dict) -> dict: # Convert numpy arrays and lists to torch tensors item[key] = torch.tensor(val) return item + + +class LookBackError(Exception): + """ + Exception raised when trying to look back in the history of a Backtrackable object. + """ + + pass + + +class LookAheadError(Exception): + """ + Exception raised when trying to look ahead in the future of a Backtrackable object. + """ + + pass + + +class Backtrackable(Generic[T]): + """ + Wrap any iterator/iterable so you can step back up to `history` items + and look ahead up to `lookahead` items. + + This is useful for streaming datasets where you need to access previous and future items + but can't load the entire dataset into memory. + + Example: + ------- + ```python + ds = load_dataset("c4", "en", streaming=True, split="train") + rev = Backtrackable(ds, history=3, lookahead=2) + + x0 = next(rev) # forward + x1 = next(rev) + x2 = next(rev) + + # Look ahead + x3_peek = rev.peek_ahead(1) # next item without moving cursor + x4_peek = rev.peek_ahead(2) # two items ahead + + # Look back + x1_again = rev.peek_back(1) # previous item without moving cursor + x0_again = rev.peek_back(2) # two items back + + # Move backward + x1_back = rev.prev() # back one step + next(rev) # returns x2, continues forward from where we were + ``` + """ + + __slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_lookahead") + + def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0): + if history < 1: + raise ValueError("history must be ≥ 1") + if lookahead < 0: + raise ValueError("lookahead must be ≥ 0") + + self._source: Iterator[T] = iter(iterable) + self._back_buf: Deque[T] = deque(maxlen=history) + self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque() + self._cursor: int = 0 # 0 == just after the newest item in back_buf + self._lookahead = lookahead + + def __iter__(self) -> "Backtrackable[T]": + return self + + def __next__(self) -> T: + # If we've stepped back, consume from back buffer first + if self._cursor < 0: # -1 means "last item", etc. + self._cursor += 1 + return self._back_buf[self._cursor] + + # If we have items in the ahead buffer, use them first + item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source) + + # Add current item to back buffer and reset cursor + self._back_buf.append(item) + self._cursor = 0 + return item + + def prev(self) -> T: + """ + Step one item back in history and return it. + Raises IndexError if already at the oldest buffered item. + """ + if len(self._back_buf) + self._cursor <= 1: + raise LookBackError("At start of history") + + self._cursor -= 1 + return self._back_buf[self._cursor] + + def peek_back(self, n: int = 1) -> T: + """ + Look `n` items back (n=1 == previous item) without moving the cursor. + """ + if n < 1 or n > len(self._back_buf) + self._cursor: + raise LookBackError("peek_back distance out of range") + + return self._back_buf[self._cursor - n] + + def peek_ahead(self, n: int = 1) -> T: + """ + Look `n` items ahead (n=1 == next item) without moving the cursor. + Fills the ahead buffer if necessary. + """ + if n < 1: + raise LookAheadError("peek_ahead distance must be 1 or more") + elif n > self._lookahead: + raise LookAheadError("peek_ahead distance exceeds lookahead limit") + + # Fill ahead buffer if we don't have enough items + while len(self._ahead_buf) < n: + try: + item = next(self._source) + self._ahead_buf.append(item) + + except StopIteration as err: + raise LookAheadError("peek_ahead: not enough items in source") from err + + return self._ahead_buf[n - 1] + + def history(self) -> List[T]: + """ + Return a copy of the buffered history (most recent last). + The list length ≤ `history` argument passed at construction. + """ + if self._cursor == 0: + return list(self._back_buf) + + # When cursor<0, slice so the order remains chronological + return list(self._back_buf)[: self._cursor or None] + + def lookahead_buffer(self) -> List[T]: + """ + Return a copy of the current lookahead buffer. + """ + return list(self._ahead_buf) + + def can_peek_back(self, steps: int = 1) -> bool: + """ + Check if we can go back `steps` items without raising an IndexError. + """ + return steps <= len(self._back_buf) + self._cursor + + def can_peek_ahead(self, steps: int = 1) -> bool: + """ + Check if we can peek ahead `steps` items. + This may involve trying to fill the ahead buffer. + """ + if self._lookahead > 0 and steps > self._lookahead: + return False + + # Try to fill ahead buffer to check if we can peek that far + try: + while len(self._ahead_buf) < steps: + if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead: + return False + item = next(self._source) + self._ahead_buf.append(item) + return True + except StopIteration: + return False + + def reset_cursor(self) -> None: + """ + Reset cursor to the most recent position (equivalent to calling next() + until you're back to the latest item). + """ + self._cursor = 0 + + def clear_ahead_buffer(self) -> None: + """ + Clear the ahead buffer, discarding any pre-fetched items. + """ + self._ahead_buf.clear()