Improve dataset examples (#82)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-18 11:43:16 +02:00
committed by GitHub
parent d5c4b0c344
commit 0928afd37d
15 changed files with 274 additions and 165 deletions

View File

@@ -241,7 +241,7 @@ def eval_policy(
data_dict["index"] = torch.arange(0, total_frames, 1)
data_dict = Dataset.from_dict(data_dict).with_format("torch")
hf_dataset = Dataset.from_dict(data_dict).with_format("torch")
if max_episodes_rendered > 0:
batch_stacked_frames = np.stack(ep_frames, 1) # (b, t, *)
@@ -292,7 +292,7 @@ def eval_policy(
"eval_s": time.time() - start,
"eval_ep_s": (time.time() - start) / num_episodes,
},
"episodes": data_dict,
"episodes": hf_dataset,
}
if max_episodes_rendered > 0:
info["videos"] = videos

View File

@@ -2,10 +2,11 @@ import logging
from copy import deepcopy
from pathlib import Path
import datasets
import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils.logging import disable_progress_bar
from datasets.utils import disable_progress_bars, enable_progress_bars
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -130,15 +131,40 @@ def calculate_online_sample_weight(n_off: int, n_on: int, pc_on: float):
return -(n_off * pc_on) / (n_on * (pc_on - 1))
def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_online_samples):
first_episode_id = data_dict.select_columns("episode_id")[0]["episode_id"].item()
first_index = data_dict.select_columns("index")[0]["index"].item()
def add_episodes_inplace(
online_dataset: torch.utils.data.Dataset,
concat_dataset: torch.utils.data.ConcatDataset,
sampler: torch.utils.data.WeightedRandomSampler,
hf_dataset: datasets.Dataset,
pc_online_samples: float,
):
"""
Modifies the online_dataset, concat_dataset, and sampler in place by integrating
new episodes from hf_dataset into the online_dataset, updating the concatenated
dataset's structure and adjusting the sampling strategy based on the specified
percentage of online samples.
Parameters:
- online_dataset (torch.utils.data.Dataset): The existing online dataset to be updated.
- concat_dataset (torch.utils.data.ConcatDataset): The concatenated dataset that combines
offline and online datasets, used for sampling purposes.
- sampler (torch.utils.data.WeightedRandomSampler): A sampler that will be updated to
reflect changes in the dataset sizes and specified sampling weights.
- hf_dataset (datasets.Dataset): A Hugging Face dataset containing the new episodes to be added.
- pc_online_samples (float): The target percentage of samples that should come from
the online dataset during sampling operations.
Raises:
- AssertionError: If the first episode_id or index in hf_dataset is not 0
"""
first_episode_id = hf_dataset.select_columns("episode_id")[0]["episode_id"].item()
first_index = hf_dataset.select_columns("index")[0]["index"].item()
assert first_episode_id == 0, f"We expect the first episode_id to be 0 and not {first_episode_id}"
assert first_index == 0, f"We expect the first first_index to be 0 and not {first_index}"
if len(online_dataset) == 0:
# initialize online dataset
online_dataset.data_dict = data_dict
online_dataset.hf_dataset = hf_dataset
else:
# find episode index and data frame indices according to previous episode in online_dataset
start_episode = online_dataset.select_columns("episode_id")[-1]["episode_id"].item() + 1
@@ -152,11 +178,12 @@ def add_episodes_inplace(data_dict, online_dataset, concat_dataset, sampler, pc_
example["episode_data_index_to"] += start_index
return example
disable_progress_bar() # map has a tqdm progress bar
data_dict = data_dict.map(shift_indices)
disable_progress_bars() # map has a tqdm progress bar
hf_dataset = hf_dataset.map(shift_indices)
enable_progress_bars()
# extend online dataset
online_dataset.data_dict = concatenate_datasets([online_dataset.data_dict, data_dict])
online_dataset.hf_dataset = concatenate_datasets([online_dataset.hf_dataset, hf_dataset])
# update the concatenated dataset length used during sampling
concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets)
@@ -274,7 +301,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.data_dict = {}
online_dataset.hf_dataset = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
@@ -308,7 +335,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
online_pc_sampling = cfg.get("demo_schedule", 0.5)
add_episodes_inplace(
eval_info["episodes"], online_dataset, concat_dataset, sampler, online_pc_sampling
online_dataset, concat_dataset, sampler, eval_info["episodes"], online_pc_sampling
)
for _ in range(cfg.policy.utd):