From 0a33a414fb604b516e2a35756150149308c19a07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 25 Apr 2024 12:12:06 +0200 Subject: [PATCH] grad_clip_norm as arg of update policy --- lerobot/scripts/train.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index d3e0caa2d..367fd8afe 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -24,7 +24,7 @@ from lerobot.common.utils.utils import ( from lerobot.scripts.eval import eval_policy -def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None): +def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None): start_time = time.time() model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line @@ -41,7 +41,7 @@ def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None): model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line grad_norm = torch.nn.utils.clip_grad_norm_( model.parameters(), - cfg.grad_clip_norm, + grad_clip_norm, error_if_nonfinite=False, ) @@ -275,8 +275,9 @@ def train(cfg: dict, out_dir=None, job_name=None): logging.info("make_policy") policy = make_policy(cfg, dataset_stats=offline_dataset.stats) + # Create optimizer and scheduler # Temporary hack to move optimizer out of policy - if isinstance(policy, ActPolicy): + if cfg.policy.name == "act": optimizer_params_dicts = [ {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]}, { @@ -286,7 +287,7 @@ def train(cfg: dict, out_dir=None, job_name=None): ] optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay) lr_scheduler = None - elif isinstance(policy, DiffusionPolicy): + elif cfg.policy.name == "diffusion": optimizer = torch.optim.Adam( policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay ) @@ -362,7 +363,7 @@ def train(cfg: dict, out_dir=None, job_name=None): for key in batch: batch[key] = batch[key].to(cfg.device, non_blocking=True) - train_info = update_policy(cfg, policy, batch, optimizer, lr_scheduler) + train_info = update_policy(policy, batch, optimizer, cfg.grad_clip_norm, lr_scheduler) # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done? if step % cfg.log_freq == 0: