From 33362dbd173be13568cc8e166a8e38d4a712d08b Mon Sep 17 00:00:00 2001
From: amandip7
Date: Tue, 4 Jun 2024 21:32:05 +0530
Subject: [PATCH] Adding parameter dataloading_s to console logs and wandb for
 tracking… (#243)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Remi
---
 lerobot/scripts/train.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 860412b..22f212a 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -150,6 +150,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
     grad_norm = info["grad_norm"]
     lr = info["lr"]
     update_s = info["update_s"]
+    dataloading_s = info["dataloading_s"]
 
     # A sample is an (observation,action) pair, where observation and action
     # can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
@@ -170,6 +171,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
         f"lr:{lr:0.1e}",
         # in seconds
         f"updt_s:{update_s:.3f}",
+        f"data_s:{dataloading_s:.3f}",  # if not ~0, you are bottlenecked by cpu or io
     ]
     logging.info(" ".join(log_items))
 
@@ -382,7 +384,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     for _ in range(step, cfg.training.offline_steps):
         if step == 0:
             logging.info("Start offline training on a fixed dataset")
+
+        start_time = time.perf_counter()
         batch = next(dl_iter)
+        dataloading_s = time.perf_counter() - start_time
 
         for key in batch:
             batch[key] = batch[key].to(device, non_blocking=True)
@@ -397,6 +402,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
             use_amp=cfg.use_amp,
         )
 
+        train_info["dataloading_s"] = dataloading_s
+
        if step % cfg.training.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline=True)
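
Note: the sketch below is a minimal, standalone illustration of the timing pattern this patch adds around next(dl_iter) — it is not lerobot code. The TensorDataset/DataLoader setup and the 4-step loop are dummy assumptions chosen only to make the example runnable; in train.py the value is instead stored in train_info["dataloading_s"] and printed as "data_s:..." by log_train_info.

import logging
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

logging.basicConfig(level=logging.INFO)

# Illustrative stand-in for the training dataset and dataloader used in train.py.
dataset = TensorDataset(torch.randn(256, 8), torch.randn(256, 2))
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
dl_iter = iter(dataloader)

for step in range(4):
    # Time only the fetch of the next batch, as the patch does around next(dl_iter).
    start_time = time.perf_counter()
    batch = next(dl_iter)
    dataloading_s = time.perf_counter() - start_time

    # ... forward/backward pass and optimizer update would go here ...

    # Surface the measurement next to the other per-step timings, mirroring the
    # "data_s:..." entry added to log_items. A value that stays well above ~0
    # indicates the step is bottlenecked by CPU-side dataloading or I/O.
    logging.info(f"step:{step} data_s:{dataloading_s:.3f}")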