For Pusht: use hf datasets to train, rename load_data_with_delta_timestamps -> load_previous_and_future_frames

This commit is contained in:
Cadene
2024-04-15 10:08:10 +00:00
parent 4ed55c3ba3
commit c6aca7fe44
5 changed files with 920 additions and 113 deletions

View File

@@ -34,18 +34,15 @@ def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
return False
def load_data_with_delta_timestamps(
def load_previous_and_future_frames(
item: dict[torch.Tensor],
data_dict: dict[torch.Tensor],
data_ids_per_episode: dict[torch.Tensor],
delta_timestamps: list[float],
key: str,
current_ts: float,
episode: int,
delta_timestamps: dict[list[float]],
tol: float = 0.04,
):
) -> dict[torch.Tensor]:
"""
Given a current timestamp (e.g. current_ts=0.6) and a list of timestamps differences (e.g. delta_timestamps=[-0.8, -0.2, 0, 0.2]),
this function compute the query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames of the specified modality (e.g. key="observation.image").
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}),
this function computes for each given modality a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest frames in the dataset.
Importantly, when no frame can be found around a query timestamp within a specified tolerance window (e.g. tol=0.04), this function raises an AssertionError.
When a timestamp is queried before the first available timestamp of the episode or after the last available timestamp,
@@ -54,56 +51,57 @@ def load_data_with_delta_timestamps(
or to pad the observations in a specific way. Note that by default the observation frames before the start of the episode are the same as the first frame of the episode.
Parameters:
- data_dict (dict): A dictionary containing the data, where each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- data_ids_per_episode (dict): A dictionary where keys are episode identifiers and values are lists of indices corresponding to frames associated with each episode.
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible key to be retrieved. These deltas are added to the current_ts to form the query timestamps.
- key (str): The key specifying which data modality is to be retrieved from the data_dict.
- current_ts (float): The current timestamp to which the delta timestamps are added to form the query timestamps.
- episode (int): The identifier of the episode from which frames are to be retrieved.
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- data_dict (dict): A dictionary containing the full dataset. Each key corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be retrieved. These deltas are added to the item timestamp to form the query timestamps.
- tol (float, optional): The tolerance level used to determine if a data point is close enough to the query timestamp. Defaults to 0.04.
Returns:
- tuple: A tuple containing two elements:
- The first element is the data retrieved from the specified modality based on the closest match to the query timestamps.
- The second element is a boolean array indicating which frames were considered as padding (True if the distance to the closest timestamp was greater than the tolerance level).
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for each modality (e.g. "observation.image_is_pad").
Raises:
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
"""
# get indices of the frames associated to the episode, and their timestamps
ep_data_ids = data_ids_per_episode[episode]
ep_timestamps = data_dict["timestamp"][ep_data_ids]
ep_data_id_from = item["episode_data_id_from"].item()
ep_data_id_to = item["episode_data_id_to"].item()
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
# load timestamps
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
# we make the assumption that the timestamps are sorted
ep_first_ts = ep_timestamps[0]
ep_last_ts = ep_timestamps[-1]
current_ts = item["timestamp"].item()
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
for key in delta_timestamps:
# get timestamps used as query to retrieve data of previous/future frames
delta_ts = delta_timestamps[key]
query_ts = current_ts + torch.tensor(delta_ts)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
min_, argmin_ = dist.min(1)
# get the indices of the data that are closest to the query timestamps
data_ids = ep_data_ids[argmin_]
# closest_ts = ep_timestamps[argmin_]
# TODO(rcadene): synchronize timestamps + interpolation if needed
# get the data
data = data_dict[key][data_ids].clone()
is_pad = min_ > tol
# TODO(rcadene): synchronize timestamps + interpolation if needed
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
is_pad = min_ > tol
# get dataset indices corresponding to frames to be loaded
data_ids = ep_data_ids[argmin_]
# check violated query timestamps are all outside the episode range
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tol=}) inside episode range."
"This might be due to synchronization issues with timestamps during data collection."
)
# load frames modality
item[key] = data_dict.select_columns(key)[data_ids][key]
item[f"{key}_is_pad"] = is_pad
return data, is_pad
return item
def get_stats_einops_patterns(dataset):