Address comments

This commit is contained in:
Cadene
2024-04-16 17:14:40 +00:00
parent b241ea46dd
commit 36d9e885ef
24 changed files with 100 additions and 94 deletions

View File

@@ -141,15 +141,15 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
online_dataset.data_dict = data_dict
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.data_dict["episode_id"][-1].item() + 1
start_index = online_dataset.data_dict["index"][-1].item() + 1
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
example["index"] += start_index
example["episode_data_id_from"] += start_index
example["episode_data_id_to"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bar() # map has a tqdm progress bar