Loads episode_data_index and stats during dataset __init__ (#85)
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
@@ -136,6 +136,7 @@ def add_episodes_inplace(
     concat_dataset: torch.utils.data.ConcatDataset,
     sampler: torch.utils.data.WeightedRandomSampler,
     hf_dataset: datasets.Dataset,
+    episode_data_index: dict[str, torch.Tensor],
     pc_online_samples: float,
 ):
     """
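Note (not part of the diff): the new episode_data_index argument maps each episode to the span of rows it occupies in the flat dataset. A minimal sketch of the expected structure, with hypothetical episode lengths and assuming "to" marks the exclusive end of each episode:

import torch

# Hypothetical index for 3 episodes of 50, 30, and 20 frames:
# "from" holds each episode's first row in the flattened frame table,
# "to" the row just past its last frame (assumed exclusive).
episode_data_index = {
    "from": torch.tensor([0, 50, 80]),
    "to": torch.tensor([50, 80, 100]),
}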
@@ -151,13 +152,15 @@ def add_episodes_inplace(
     - sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
       reflect changes in the dataset sizes and specified sampling weights.
     - hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
+    - episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
+      They indicate the start index and end index of each episode in the dataset.
     - pc_online_samples (float): The target percentage of samples that should come from
       the online dataset during sampling operations.
 
     Raises:
     - AssertionError: If the first episode_id or index in hf_dataset is not 0
     """
-    first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
+    first_episode_id = hf_dataset.select_columns("episode_index")[0]["episode_index"].item()
     first_index = hf_dataset.select_columns("index")[0]["index"].item()
     assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
     assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
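Given such an index, a single episode can be sliced out of the flat Hugging Face dataset without scanning the episode_index column. A self-contained sketch with toy data (the column values are invented for illustration):

import torch
from datasets import Dataset

# Toy flat dataset: 5 frames across 2 episodes.
hf_dataset = Dataset.from_dict(
    {
        "episode_index": [0, 0, 0, 1, 1],
        "index": [0, 1, 2, 3, 4],
    }
)
episode_data_index = {"from": torch.tensor([0, 3]), "to": torch.tensor([3, 5])}

# Select the rows of episode 1 via its start/end indices.
ep = 1
rows = hf_dataset.select(
    range(episode_data_index["from"][ep].item(), episode_data_index["to"][ep].item())
)
print(rows["index"])  # [3, 4]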
@@ -167,21 +170,22 @@ def add_episodes_inplace(
         online_dataset.hf_dataset = hf_dataset
     else:
         # find episode index and data frame indices according to previous episode in online_dataset
-        start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
+        start_episode = online_dataset.select_columns("episode_index")[-1]["episode_index"].item() + 1
         start_index = online_dataset.select_columns("index")[-1]["index"].item() + 1
 
         def shift_indices(example):
-            # note: we dont shift "frame_id" since it represents the index of the frame in the episode it belongs to
-            example["episode_id"] += start_episode
+            # note: we dont shift "frame_index" since it represents the index of the frame in the episode it belongs to
+            example["episode_index"] += start_episode
             example["index"] += start_index
-            example["episode_data_index_from"] += start_index
-            example["episode_data_index_to"] += start_index
             return example
 
         disable_progress_bars()  # map has a tqdm progress bar
         hf_dataset = hf_dataset.map(shift_indices)
         enable_progress_bars()
 
+        episode_data_index["from"] += start_index
+        episode_data_index["to"] += start_index
+
         # extend online dataset
         online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
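To see what the shifting above does end to end, here is a self-contained sketch with toy values (the starting offsets and column data are hypothetical): per-frame columns are shifted through datasets.map, while the episode_data_index tensors are shifted once, outside the map.

import torch
from datasets import Dataset

# Pretend the online dataset already holds 2 episodes and 5 frames,
# so new data must be offset by start_episode = 2 and start_index = 5.
start_episode, start_index = 2, 5

new_data = Dataset.from_dict(
    {
        "episode_index": [0, 0, 1],
        "frame_index": [0, 1, 0],  # position within its own episode: never shifted
        "index": [0, 1, 2],
    }
)

def shift_indices(example):
    example["episode_index"] += start_episode
    example["index"] += start_index
    return example

new_data = new_data.map(shift_indices)
print(new_data["episode_index"])  # [2, 2, 3]
print(new_data["index"])          # [5, 6, 7]

# The per-episode ranges are shifted in one vectorized step:
episode_data_index = {"from": torch.tensor([0, 2]), "to": torch.tensor([2, 3])}
episode_data_index["from"] += start_index
episode_data_index["to"] += start_index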
@@ -334,9 +338,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
             seed=cfg.seed,
         )
 
-        online_pc_sampling = cfg.get("demo_schedule", 0.5)
         add_episodes_inplace(
-            online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
+            online_dataset,
+            concat_dataset,
+            sampler,
+            hf_dataset=eval_info["episodes"]["hf_dataset"],
+            episode_data_index=eval_info["episodes"]["episode_data_index"],
+            pc_online_samples=cfg.get("demo_schedule", 0.5),
         )
 
         for _ in range(cfg.policy.utd):
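The keyword-argument call above implies that eval_info["episodes"] bundles the two pieces the function now needs. A hypothetical sketch of that structure (key names taken from the call site, contents illustrative only):

# eval_info["episodes"] as implied by the call above:
eval_info = {
    "episodes": {
        "hf_dataset": ...,          # datasets.Dataset of newly collected frames
        "episode_data_index": ...,  # dict with "from"/"to" torch.Tensors
    }
}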