From 0bca982fca5d44bc44ac8e5f3f86ee93c5c4dd7d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:26:40 +0200
Subject: [PATCH] move optimizer and scheduler outside policies

---
 lerobot/common/policies/act/modeling_act.py  |  19 ----
 .../policies/diffusion/modeling_diffusion.py |  19 ----
 lerobot/scripts/train.py                     | 105 ++++++++++++------
 3 files changed, 72 insertions(+), 71 deletions(-)

diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py
index c727988b5..597a5cb16 100644
--- a/lerobot/common/policies/act/modeling_act.py
+++ b/lerobot/common/policies/act/modeling_act.py
@@ -127,25 +127,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
         self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
 
         self._reset_parameters()
-        self._create_optimizer()
-
-    def _create_optimizer(self):
-        optimizer_params_dicts = [
-            {
-                "params": [
-                    p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
-                "lr": self.cfg.lr_backbone,
-            },
-        ]
-        self.optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
-        )
 
     def _reset_parameters(self):
         """Xavier-uniform initialization of the transformer parameters as in the original code."""
diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py
index 4f488c015..d34db1c66 100644
--- a/lerobot/common/policies/diffusion/modeling_diffusion.py
+++ b/lerobot/common/policies/diffusion/modeling_diffusion.py
@@ -18,7 +18,6 @@ import einops
 import torch
 import torch.nn.functional as F  # noqa: N812
 import torchvision
-from diffusers.optimization import get_scheduler
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from robomimic.models.base_nets import SpatialSoftmax
 from torch import Tensor, nn
@@ -70,25 +69,7 @@ class DiffusionPolicy(nn.Module):
 
         self.ema_diffusion = copy.deepcopy(self.diffusion)
         self.ema = _EMA(cfg, model=self.ema_diffusion)
 
-        # TODO(alexander-soare): Move optimizer out of policy.
-        self.optimizer = torch.optim.Adam(
-            self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
-        )
-        # TODO(alexander-soare): Move LR scheduler out of policy.
-        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
-        self.global_step = 0
-
-        # configure lr scheduler
-        self.lr_scheduler = get_scheduler(
-            cfg.lr_scheduler,
-            optimizer=self.optimizer,
-            num_warmup_steps=cfg.lr_warmup_steps,
-            num_training_steps=lr_scheduler_num_training_steps,
-            # pytorch assumes stepping LRScheduler every epoch
-            # however huggingface diffusers steps it every batch
-            last_epoch=self.global_step - 1,
-        )
 
     def reset(self):
         """
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 4ff295733..3e8359b8c 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -8,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from diffusers.optimization import get_scheduler
 from torch import Tensor
 
 from lerobot.common.datasets.factory import make_dataset
@@ -24,41 +25,41 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
-def update_diffusion(self, policy, batch: dict[str, Tensor], **_) -> dict:
-    """Run the model in train mode, compute the loss, and do an optimization step."""
-    start_time = time.time()
-    policy.diffusion.train()
-    batch = policy.normalize_inputs(batch)
-    loss = policy.forward(batch)["loss"]
-    loss.backward()
+def update_diffusion(policy, batch: dict[str, Tensor], optimizer, lr_scheduler) -> dict:
+    """Run the model in train mode, compute the loss, and do an optimization step."""
+    start_time = time.time()
+    policy.diffusion.train()
+    batch = policy.normalize_inputs(batch)
+    loss = policy.forward(batch)["loss"]
+    loss.backward()
 
-    # TODO(rcadene): self.unnormalize_outputs(out_dict)
+    # TODO(rcadene): self.unnormalize_outputs(out_dict)
 
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.diffusion.parameters(),
-        policy.cfg.grad_clip_norm,
-        error_if_nonfinite=False,
-    )
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.diffusion.parameters(),
+        policy.cfg.grad_clip_norm,
+        error_if_nonfinite=False,
+    )
 
-    policy.optimizer.step()
-    policy.optimizer.zero_grad()
-    policy.lr_scheduler.step()
+    optimizer.step()
+    optimizer.zero_grad()
+    lr_scheduler.step()
 
-    if policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if policy.ema is not None:
+        policy.ema.step(policy.diffusion)
 
-    info = {
-        "loss": loss.item(),
-        "grad_norm": float(grad_norm),
-        "lr": policy.lr_scheduler.get_last_lr()[0],
-        "update_s": time.time() - start_time,
-    }
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": lr_scheduler.get_last_lr()[0],
+        "update_s": time.time() - start_time,
+    }
 
-    return info
+    return info
 
 
-def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
+def update_act(policy, batch: dict[str, Tensor], optimizer) -> dict:
     start_time = time.time()
     policy.train()
     batch = policy.normalize_inputs(batch)
@@ -71,8 +72,8 @@ def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
         policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
     )
 
-    policy.optimizer.step()
-    policy.optimizer.zero_grad()
+    optimizer.step()
+    optimizer.zero_grad()
 
     train_info = {
         "loss": loss.item(),
@@ -83,6 +84,8 @@ def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
 
     return train_info
 
+
+
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -295,6 +298,43 @@ def train(cfg: dict, out_dir=None, job_name=None):
 
     logging.info("make_policy")
     policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
+    # Temporary hack to move optimizer out of policy
+    if isinstance(policy, ActPolicy):
+        optimizer_params_dicts = [
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad
+                ]
+            },
+            {
+                "params": [
+                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
+                ],
+                "lr": policy.cfg.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(
+            optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
+        )
+    elif isinstance(policy, DiffusionPolicy):
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(), policy.cfg.lr, policy.cfg.adam_betas, policy.cfg.adam_eps, policy.cfg.adam_weight_decay
+        )
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
+        global_step = 0
+        # configure lr scheduler
+        lr_scheduler = get_scheduler(
+            policy.cfg.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=policy.cfg.lr_warmup_steps,
+            num_training_steps=cfg.offline_steps,
+            # pytorch assumes stepping LRScheduler every epoch
+            # however huggingface diffusers steps it every batch
+            last_epoch=global_step - 1,
+        )
+
+
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
@@ -355,11 +395,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
             # Temporary hack to move update outside of policy
-            if isinstance(policy, DiffusionPolicy):
-                train_info = update_diffusion(policy, batch)
-            elif isinstance(policy, ActPolicy):
-                train_info = update_act(policy, batch)
-
+            if isinstance(policy, ActPolicy):
+                train_info = update_act(policy, batch, optimizer)
+            elif isinstance(policy, DiffusionPolicy):
+                train_info = update_diffusion(policy, batch, optimizer, lr_scheduler)
             # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
 
             if step % cfg.log_freq == 0:
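
For reference, the pattern this patch moves train.py toward is sketched below in a minimal, self-contained form: the optimizer (with a lower learning rate for `backbone` parameters, as in the ACT branch above) and the diffusers LR scheduler are built in the training script and passed into the update function instead of living on the policy. `ToyPolicy`, the hyperparameter values, and the synthetic batch are illustrative stand-ins, not lerobot code.

# Minimal sketch of the "optimizer/scheduler outside the policy" pattern (illustrative only).
import torch
from diffusers.optimization import get_scheduler
from torch import nn


class ToyPolicy(nn.Module):
    """Stand-in for a policy with a vision backbone and a head (hypothetical, not lerobot code)."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 4)

    def forward(self, x):
        return self.head(torch.relu(self.backbone(x)))


def update(policy, batch, optimizer, lr_scheduler) -> dict:
    """Optimization step that receives the optimizer/scheduler as arguments, not policy attributes."""
    policy.train()
    loss = torch.nn.functional.mse_loss(policy(batch["obs"]), batch["target"])
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), 10.0)
    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()  # diffusers schedulers are stepped once per optimization step
    return {"loss": loss.item(), "grad_norm": float(grad_norm), "lr": lr_scheduler.get_last_lr()[0]}


policy = ToyPolicy()

# Separate parameter group with a lower learning rate for the backbone, as in the ACT branch.
optimizer = torch.optim.AdamW(
    [
        {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone")]},
        {"params": [p for n, p in policy.named_parameters() if n.startswith("backbone")], "lr": 1e-5},
    ],
    lr=1e-4,
    weight_decay=1e-6,
)

# Warmup + decay schedule created next to the training loop, as in the diffusion branch.
lr_scheduler = get_scheduler("cosine", optimizer=optimizer, num_warmup_steps=10, num_training_steps=100)

for step in range(100):
    batch = {"obs": torch.randn(32, 8), "target": torch.randn(32, 4)}
    info = update(policy, batch, optimizer, lr_scheduler)

Keeping the optimizer and scheduler in the training script lets one update function signature serve different policies and removes training-only dependencies (such as the diffusers scheduler import) from the policy modules, which is the point of this patch.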