From cd76980d50805ed2eec6b4fa46ddba997a8f9af8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 25 Apr 2024 12:05:33 +0200
Subject: [PATCH] fix update

---
 lerobot/scripts/train.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index ef2d019b3..d3e0caa2d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -24,13 +24,11 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
-def update_policy(policy, batch, optimizer, lr_scheduler=None):
+def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()
 
-    # Diffusion
-    policy.diffusion.train()
-    # Act
-    policy.train()
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
+    model.train()
 
     batch = policy.normalize_inputs(batch)
 
@@ -43,7 +41,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
         model.parameters(),
-        policy.cfg.grad_clip_norm,
+        cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )
 
@@ -58,7 +56,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
+        "lr": optimizer.param_groups[0]['lr'],
         "update_s": time.time() - start_time,
     }
 
@@ -283,10 +281,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
             {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
             {
                 "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
-                "lr": policy.cfg.lr_backbone,
+                "lr": cfg.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
         lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
@@ -364,7 +362,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = update_policy(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(cfg, policy, batch, optimizer, lr_scheduler)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
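
The stand-alone sketch below is not part of the patch; it only illustrates, under assumed names, the pattern the change moves to: update_policy receives the config explicitly instead of reaching into policy.cfg, and the logged learning rate is read back from optimizer.param_groups, which stays correct with or without a scheduler. DummyPolicy and the cfg namespace are hypothetical stand-ins, not lerobot code.

# Minimal sketch, assuming a policy whose forward() returns {"loss": ...}.
import time
from types import SimpleNamespace

import torch


class DummyPolicy(torch.nn.Module):
    # Hypothetical stand-in for an lerobot policy.
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, batch):
        pred = self.layer(batch["obs"])
        return {"loss": torch.nn.functional.mse_loss(pred, batch["target"])}


def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
    start_time = time.time()
    policy.train()
    loss = policy(batch)["loss"]
    loss.backward()
    # Clip gradients using the value carried by the explicit cfg argument.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(), cfg.grad_clip_norm, error_if_nonfinite=False
    )
    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()
    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        # Read the current lr from the optimizer itself, scheduler or not.
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
    }


cfg = SimpleNamespace(lr=1e-4, grad_clip_norm=10.0)
policy = DummyPolicy()
optimizer = torch.optim.AdamW(policy.parameters(), lr=cfg.lr)
batch = {"obs": torch.randn(8, 4), "target": torch.randn(8, 1)}
print(update_policy(cfg, policy, batch, optimizer))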