From 1ffc0e0d9413d9637204d686adf2848607fcd7ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:46:17 +0200
Subject: [PATCH] format

---
 lerobot/scripts/train.py | 32 ++++++++++----------------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index 22204b859..ef2d019b3 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -16,10 +16,10 @@ from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.utils.utils import (
-    format_big_number,
-    get_safe_torch_device,
-    init_logging,
-    set_global_seed,
+    format_big_number,
+    get_safe_torch_device,
+    init_logging,
+    set_global_seed,
 )
 from lerobot.scripts.eval import eval_policy
 
@@ -40,7 +40,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     loss.backward()
 
     # Diffusion
-    model = policy.diffusion if hasattr(policy, "diffusion") else policy # TODO: hacky, remove this line
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
         model.parameters(),
         policy.cfg.grad_clip_norm,
@@ -54,7 +54,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
 
     if hasattr(policy, "ema") and policy.ema is not None:
         policy.ema.step(model)
-    
+
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
@@ -280,21 +280,13 @@ def train(cfg: dict, out_dir=None, job_name=None):
     # Temporary hack to move optimizer out of policy
     if isinstance(policy, ActPolicy):
         optimizer_params_dicts = [
+            {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
             {
-                "params": [
-                    p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad
-                ]
-            },
-            {
-                "params": [
-                    p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
-                ],
+                "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
                 "lr": policy.cfg.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
-        )
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
         lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
@@ -313,8 +305,6 @@ def train(cfg: dict, out_dir=None, job_name=None):
             last_epoch=global_step - 1,
         )
 
-
-
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
@@ -397,9 +387,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         # create dataloader for online training
         concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
         weights = [1.0] * len(concat_dataset)
-        sampler = torch.utils.data.WeightedRandomSampler(
-            weights, num_samples=len(concat_dataset), replacement=True
-        )
+        sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(concat_dataset), replacement=True)
         dataloader = torch.utils.data.DataLoader(
             concat_dataset,
             num_workers=4,