Address comments
This commit is contained in:
@@ -7,9 +7,9 @@ import tqdm
|
||||
|
||||
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[torch.Tensor],
|
||||
data_dict: dict[torch.Tensor],
|
||||
delta_timestamps: dict[list[float]],
|
||||
item: dict[str, torch.Tensor],
|
||||
data_dict: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tol: float = 0.04,
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
@@ -35,12 +35,12 @@ def load_previous_and_future_frames(
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_data_id_from = item["episode_data_id_from"].item()
|
||||
ep_data_id_to = item["episode_data_id_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to + 1, 1)
|
||||
ep_data_id_from = item["episode_data_index_from"].item()
|
||||
ep_data_id_to = item["episode_data_index_to"].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from : ep_data_id_to + 1]["timestamp"]
|
||||
ep_timestamps = data_dict.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
|
||||
@@ -215,8 +215,8 @@ def eval_policy(
|
||||
"timestamp": torch.arange(0, num_frames, 1) / fps,
|
||||
"next.done": dones[ep_id, :num_frames],
|
||||
"next.reward": rewards[ep_id, :num_frames].type(torch.float32),
|
||||
"episode_data_id_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_id_to": torch.tensor([idx_from + num_frames - 1] * num_frames),
|
||||
"episode_data_index_from": torch.tensor([idx_from] * num_frames),
|
||||
"episode_data_index_to": torch.tensor([idx_from + num_frames] * num_frames),
|
||||
}
|
||||
for key in observations:
|
||||
ep_dict[key] = observations[key][ep_id][:num_frames]
|
||||
|
||||
@@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
|
||||
online_dataset.data_dict = data_dict
|
||||
else:
|
||||
# find episode index and data frame indices according to previous episode in online_dataset
|
||||
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
|
||||
start_index = online_dataset.data_dict["index"][-1].item() + 1
|
||||
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
|
||||
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
|
||||
|
||||
def shift_indices(example):
|
||||
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
|
||||
example["episode_id"] += start_episode
|
||||
example["index"] += start_index
|
||||
example["episode_data_id_from"] += start_index
|
||||
example["episode_data_id_to"] += start_index
|
||||
example["episode_data_index_from"] += start_index
|
||||
example["episode_data_index_to"] += start_index
|
||||
return example
|
||||
|
||||
disable_progress_bar() # map has a tqdm progress bar
|
||||
|
||||
@@ -77,7 +77,7 @@ def render_dataset(dataset, out_dir, max_num_episodes):
|
||||
# add current frame to list of frames to render
|
||||
frames[im_key].append(item[im_key])
|
||||
|
||||
end_of_episode = item["index"].item() == item["episode_data_id_to"].item()
|
||||
end_of_episode = item["index"].item() == item["episode_data_index_to"].item()
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
for im_key in dataset.image_keys:
|
||||
|
||||
Reference in New Issue
Block a user