diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index ef2d019b3..d3e0caa2d 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -24,13 +24,11 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy


-def update_policy(policy, batch, optimizer, lr_scheduler=None):
+def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()
-    # Diffusion
-    policy.diffusion.train()
-    # Act
-    policy.train()
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
+    model.train()

     batch = policy.normalize_inputs(batch)

@@ -43,7 +41,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
         model.parameters(),
-        policy.cfg.grad_clip_norm,
+        cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )

@@ -58,7 +56,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
+        "lr": optimizer.param_groups[0]['lr'],
         "update_s": time.time() - start_time,
     }

@@ -283,10 +281,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
             {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
             {
                 "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
-                "lr": policy.cfg.lr_backbone,
+                "lr": cfg.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
         lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
@@ -364,7 +362,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = update_policy(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(cfg, policy, batch, optimizer, lr_scheduler)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
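
Note on the "lr" logging change above: reading the learning rate directly from optimizer.param_groups[0]["lr"] gives the current value whether or not a scheduler is attached, because PyTorch schedulers mutate the optimizer's param groups in place. A minimal standalone sketch of that idiom (the dummy parameter and StepLR scheduler below are illustrative only, not part of this patch):

    import torch

    # Hypothetical single-parameter setup, just to exercise the optimizer/scheduler API.
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    optimizer.step()
    scheduler.step()

    # The scheduler has already written the decayed lr into the param group,
    # so no branching on "lr_scheduler is None" is needed when logging.
    print(optimizer.param_groups[0]["lr"])  # ~1e-05, decayed from 1e-4 by the scheduler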