Add MultiLerobotDataset for training with multiple LeRobotDatasets (#229)

This commit is contained in:
Alexander Soare
2024-05-30 16:12:21 +01:00
committed by GitHub
parent 265b0ec44d
commit 111cd58f8a
8 changed files with 352 additions and 72 deletions

View File

@@ -16,7 +16,6 @@
import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
@@ -28,6 +27,7 @@ from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -280,6 +280,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -330,7 +335,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
max_episodes_rendered=4,
start_seed=cfg.seed,
)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline=True)
if cfg.wandb.enable:
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
@@ -362,7 +367,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
dl_iter = cycle(dataloader)
policy.train()
is_offline = True
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
@@ -382,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
@@ -390,41 +394,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
step += 1
logging.info("End of offline training")
if cfg.training.online_steps == 0:
if cfg.training.eval_freq > 0:
eval_env.close()
return
# create an env dedicated to online episodes collection from policy rollout
online_training_env = make_env(cfg, n_envs=1)
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
online_dataset.episode_data_index = {}
# create dataloader for online training
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
weights = [1.0] * len(concat_dataset)
sampler = torch.utils.data.WeightedRandomSampler(
weights, num_samples=len(concat_dataset), replacement=True
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
num_workers=4,
batch_size=cfg.training.batch_size,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=False,
)
logging.info("End of online training")
if cfg.training.eval_freq > 0:
eval_env.close()
online_training_env.close()
eval_env.close()
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")