diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 3e8359b8c..22204b859 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -9,7 +9,6 @@ import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 from diffusers.optimization import get_scheduler
-from torch import Tensor
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -25,67 +24,47 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
-def update_diffusion(policy, batch: dict[str, Tensor], optimizer, lr_scheduler) -> dict:
-    """Run the model in train mode, compute the loss, and do an optimization step."""
+def update_policy(policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()
+
+    # Diffusion
     policy.diffusion.train()
+    # Act
+    policy.train()
+
     batch = policy.normalize_inputs(batch)
-    loss = policy.forward(batch)["loss"]
+
+    output_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = output_dict["loss"]
     loss.backward()
-    # TODO(rcadene): self.unnormalize_outputs(out_dict)
-
+    # Diffusion
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.diffusion.parameters(),
+        model.parameters(),
         policy.cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )
 
     optimizer.step()
     optimizer.zero_grad()
-    lr_scheduler.step()
-
-    if policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if lr_scheduler is not None:
+        lr_scheduler.step()
+    if hasattr(policy, "ema") and policy.ema is not None:
+        policy.ema.step(model)
+
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": lr_scheduler.get_last_lr()[0],
+        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
         "update_s": time.time() - start_time,
     }
 
     return info
 
 
-def update_act(policy, batch: dict[str, Tensor], optimizer) -> dict:
-    start_time = time.time()
-    policy.train()
-    batch = policy.normalize_inputs(batch)
-    loss_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = loss_dict["loss"]
-    loss.backward()
-
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
-    )
-
-    optimizer.step()
-    optimizer.zero_grad()
-
-    train_info = {
-        "loss": loss.item(),
-        "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr,
-        "update_s": time.time() - start_time,
-    }
-
-    return train_info
-
-
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -316,6 +295,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         optimizer = torch.optim.AdamW(
             optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
         )
+        lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
             policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
@@ -394,11 +374,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        # Temporary hack to move update outside of policy
-        if isinstance(policy, ActPolicy):
-            train_info = update_act(policy, batch, optimizer)
-        elif isinstance(policy, DiffusionPolicy):
-            train_info = update_diffusion(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(policy, batch, optimizer, lr_scheduler)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
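Review note (not part of the patch): a minimal smoke-test sketch of the merged helper's call contract, assuming the patched module is importable as lerobot.scripts.train. DummyPolicy, its cfg fields, and the synthetic batch are made-up stand-ins for illustration; a real run would use a policy built by make_policy. Note that update_policy, as written in this diff, still calls policy.diffusion.train() unconditionally, so the stand-in mimics a diffusion-style policy.

    # Sketch only: exercises update_policy once with lr_scheduler=None,
    # so "lr" in the returned info falls back to policy.cfg.lr.
    from types import SimpleNamespace

    import torch
    from torch import nn

    from lerobot.scripts.train import update_policy  # patched module from this diff


    class DummyPolicy(nn.Module):
        """Stand-in exposing only what update_policy touches:
        .diffusion, .cfg, .normalize_inputs(), and a forward() returning {"loss": ...}."""

        def __init__(self):
            super().__init__()
            self.diffusion = nn.Linear(10, 2)  # plays the role of policy.diffusion
            self.cfg = SimpleNamespace(lr=1e-4, grad_clip_norm=10.0)

        def normalize_inputs(self, batch):
            return batch  # identity normalization for the sketch

        def forward(self, batch):
            pred = self.diffusion(batch["observation"])
            return {"loss": nn.functional.mse_loss(pred, batch["action"])}


    policy = DummyPolicy()
    optimizer = torch.optim.Adam(policy.diffusion.parameters(), lr=policy.cfg.lr)
    batch = {"observation": torch.randn(8, 10), "action": torch.randn(8, 2)}

    info = update_policy(policy, batch, optimizer, lr_scheduler=None)
    print(info)  # {"loss": ..., "grad_norm": ..., "lr": 1e-4, "update_s": ...}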