For Pusht: use hf datasets to train, rename load_data_with_delta_timestamps -> load_previous_and_future_frames
This commit is contained in:
@@ -34,18 +34,15 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def load_data_with_delta_timestamps(
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[torch.Tensor],
|
||||
data_dict: dict[torch.Tensor],
|
||||
data_ids_per_episode: dict[torch.Tensor],
|
||||
delta_timestamps: list[float],
|
||||
key: str,
|
||||
current_ts: float,
|
||||
episode: int,
|
||||
delta_timestamps: dict[list[float]],
|
||||
tol: float = 0.04,
|
||||
):
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
|
||||
this function computes the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
|
||||
this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
|
||||
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
|
||||
@@ -54,56 +51,57 @@ def load_data_with_delta_timestamps(
|
||||
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
|
||||
|
||||
Parameters:
|
||||
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
|
||||
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
|
||||
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
|
||||
- episode (int): The identifier of the episode from which frames are to be retrieved.
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
|
||||
|
||||
Returns:
|
||||
- tuple: A tuple containing two elements:
|
||||
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
|
||||
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
|
||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated with the episode, and their timestamps
|
||||
ep_data_ids = data_ids_per_episode[episode]
|
||||
ep_timestamps = data_dict["timestamp"][ep_data_ids]
|
||||
ep_data_id_from = item["episode_data_id_from"].item()
|
||||
ep_data_id_to = item["episode_data_id_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
current_ts = item["timestamp"].item()
|
||||
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
for key in delta_timestamps:
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# get the indices of the data that are closest to the query timestamps
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
# closest_ts = ep_timestamps[argmin_]
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
# get the data
|
||||
data = data_dict[key][data_ids].clone()
|
||||
is_pad = min_ > tol
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
is_pad = min_ > tol
|
||||
# get dataset indices corresponding to frames to be loaded
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
# load frames modality
|
||||
item[key] = data_dict.select_columns(key)[data_ids][key]
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return data, is_pad
|
||||
return item
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset):
|
||||
|
||||
Reference in New Issue
Block a user