From 63e61385fccfcbc5c2f633235eeabf9de18219f0 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Wed, 29 May 2024 11:09:16 +0000 Subject: [PATCH] Add num_workers and batch_size to default.yaml --- lerobot/configs/default.yaml | 2 ++ lerobot/scripts/train.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index ae36b3e2..1f4b6027 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -26,6 +26,8 @@ training: save_freq: ??? log_freq: 250 save_model: true + num_workers: 4 + batch_size: ??? eval: n_episodes: 1 diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index c81647f3..c6313792 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -386,7 +386,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # create dataloader for offline training dataloader = torch.utils.data.DataLoader( offline_dataset, - num_workers=4, + num_workers=cfg.training.num_workers, batch_size=cfg.training.batch_size, shuffle=True, pin_memory=cfg.device != "cpu",