@@ -11,29 +11,39 @@ def load_previous_and_future_frames(
     item: dict[str, torch.Tensor],
     hf_dataset: datasets.Dataset,
     delta_timestamps: dict[str, list[float]],
-    tol: float = 0.04,
+    tol: float,
 ) -> dict[torch.Tensor]:
     """
-    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
-    this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
+    Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
+    some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
+    given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
 
-    Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
-    When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
-    the violation of the tolerance doesn't raise an AssertionError, and the function populates a boolean array indicating which frames are outside of the episode range.
-    For instance, this boolean array is useful during batched training to not supervise actions associated to timestamps coming after the end of the episode,
-    or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
+    Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
+    raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
+    the last available timestamp, the violation of the tolerance doesn't raise an AssertionError, and the function
+    populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
+    is useful during batched training to not supervise actions associated to timestamps coming after the end of the
+    episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
+    of the episode are the same as the first frame of the episode.
 
     Parameters:
-    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
-    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
-    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
+    - item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
+      corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
+    - hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
+      modality (e.g., "timestamp", "observation.image", "action").
+    - delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
+      retrieved. These deltas are added to the item timestamp to form the query timestamps.
+    - tol (float, optional): The tolerance level used to determine if a data point is close enough to the query
+      timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
+      smallest expected inter-frame period, but large enough to account for jitter.
 
     Returns:
-    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
+    - The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
+      each modality (e.g. "observation.image_is_pad").
 
     Raises:
-    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
+    - AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
+      issues with timestamps during data collection.
     """
     # get indices of the frames associated to the episode, and their timestamps
     ep_data_id_from = item["episode_data_index_from"].item()
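For reviewers, the behavior described in the docstring can be illustrated with a minimal pure-Python sketch. This is not the implementation in this hunk: `query_closest_frames` and its arguments are hypothetical names, plain lists stand in for tensors, and the padding rule (a frame is marked as padding when its nearest timestamp is not within `tol`) is an assumption based on the docstring's `tol > difference` wording.

```python
def query_closest_frames(ep_timestamps, current_ts, deltas, tol):
    """Hypothetical helper: nearest-frame lookup with a tolerance check.

    ep_timestamps: sorted timestamps (seconds) of the frames in one episode.
    Returns (indices, is_pad), one entry per delta.
    """
    first_ts, last_ts = ep_timestamps[0], ep_timestamps[-1]
    indices, is_pad = [], []
    for delta in deltas:
        query_ts = current_ts + delta
        # Index of the episode frame whose timestamp is closest to the query.
        idx = min(range(len(ep_timestamps)), key=lambda i: abs(ep_timestamps[i] - query_ts))
        within_tol = abs(ep_timestamps[idx] - query_ts) < tol
        outside_episode = query_ts < first_ts or query_ts > last_ts
        # A tolerance violation is only acceptable outside the episode range.
        assert within_tol or outside_episode, (
            f"No frame within tol={tol} of query timestamp {query_ts}"
        )
        indices.append(idx)
        is_pad.append(not within_tol)
    return indices, is_pad


# Docstring example: current timestamp 0.6 s, frames every 0.2 s from 0.0 to 1.0.
ep_timestamps = [i * 0.2 for i in range(6)]
indices, is_pad = query_closest_frames(ep_timestamps, 0.6, [-0.8, -0.2, 0, 0.2], tol=0.04)
print(indices)  # [0, 2, 3, 4]
print(is_pad)   # [True, False, False, False]
```

The query at -0.2 s falls before the episode start, so it resolves to the first frame and is flagged as padding instead of raising. A gap inside the episode (for example a missing frame at 0.4 s) would trigger the assertion.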