Loads episode_data_index and stats during dataset __init__ (#85)

Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-23 14:13:25 +02:00
committed by GitHub
parent e2168163cd
commit 1030ea0070
89 changed files with 1008 additions and 432 deletions

View File

@@ -136,6 +136,7 @@ def add_episodes_inplace(
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
episode_data_index: dict[str, torch.Tensor],
pc_online_samples: float,
):
"""
@@ -151,13 +152,15 @@ def add_episodes_inplace(
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
They indicate the start index and end index of each episode in the dataset.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
@@ -167,21 +170,22 @@ def add_episodes_inplace(
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
def shift_indices(example):
# note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
example["episode_id"] += start_episode
# note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
example["episode_index"] += start_episode
example["index"] += start_index
example["episode_data_index_from"] += start_index
example["episode_data_index_to"] += start_index
return example
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index
# extend online dataset
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
seed=cfg.seed,
)
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
online_dataset,
concat_dataset,
sampler,
hf_dataset=eval_info["episodes"]["hf_dataset"],
episode_data_index=eval_info["episodes"]["episode_data_index"],
pc_online_samples=cfg.get("demo_schedule", 0.5),
)
for _ in range(cfg.policy.utd):