fix update

Quentin Gallouédec
2024-04-25 12:05:33 +02:00
parent 1ffc0e0d94
commit cd76980d50


@@ -24,13 +24,11 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
-def update_policy(policy, batch, optimizer, lr_scheduler=None):
+def update_policy(cfg, policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()
-    # Diffusion
-    policy.diffusion.train()
-    # Act
-    policy.train()
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
+    model.train()
     batch = policy.normalize_inputs(batch)
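For context (not part of the commit), a minimal sketch of the hasattr fallback introduced above, using hypothetical stand-in modules rather than the real lerobot policies: a policy that wraps its network in a .diffusion attribute and a policy that is itself the network can both be put in train mode through the same two lines.

```python
from torch import nn

# Hedged sketch: WrappedPolicy mimics a policy that keeps its network under
# `.diffusion` (as the old `policy.diffusion.train()` branch assumed), while a
# plain nn.Module mimics a policy that is trained directly (the old `policy.train()` branch).
class WrappedPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.diffusion = nn.Linear(4, 4)  # stand-in for the inner diffusion model

for policy in (WrappedPolicy(), nn.Linear(4, 4)):
    model = policy.diffusion if hasattr(policy, "diffusion") else policy
    model.train()  # same call path for both policy flavours
```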
@@ -43,7 +41,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
         model.parameters(),
-        policy.cfg.grad_clip_norm,
+        cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )
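As a side note (a hedged sketch, not from the diff): `torch.nn.utils.clip_grad_norm_` rescales gradients in place and returns the total norm computed before clipping, which is the value logged as `grad_norm` in the info dict below.

```python
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

grad_clip_norm = 10.0  # stand-in for cfg.grad_clip_norm
grad_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    grad_clip_norm,
    error_if_nonfinite=False,  # keep going even if the norm is inf/nan
)
print(float(grad_norm))  # pre-clipping norm; gradients are now scaled so their norm is <= 10.0
```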
@@ -58,7 +56,7 @@ def update_policy(policy, batch, optimizer, lr_scheduler=None):
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
+        "lr": optimizer.param_groups[0]['lr'],
         "update_s": time.time() - start_time,
     }
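The new "lr" entry works because PyTorch schedulers write the current learning rate back into `optimizer.param_groups`, so reading `param_groups[0]['lr']` is valid with or without a scheduler. A small sketch under that assumption, with illustrative values:

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

optimizer.step()   # normally preceded by a backward pass
scheduler.step()   # halves the learning rate

# Both reads agree after the scheduler step; the param_groups read also works
# when no scheduler is attached, which is what the simplified logging relies on.
assert optimizer.param_groups[0]["lr"] == scheduler.get_last_lr()[0] == 5e-4
```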
@@ -283,10 +281,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
             {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
             {
                 "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
-                "lr": policy.cfg.lr_backbone,
+                "lr": cfg.lr_backbone,
             },
         ]
-        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay)
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.lr, weight_decay=cfg.weight_decay)
         lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
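The parameter-group pattern in the hunk above gives the backbone its own learning rate. A hedged sketch with stand-in modules and illustrative values (the real values come from cfg.lr and cfg.lr_backbone):

```python
import torch
from torch import nn

backbone = nn.Linear(16, 16)  # stand-in for the policy's vision backbone
head = nn.Linear(16, 4)       # stand-in for the rest of the policy

# AdamW applies its default lr= to any group that does not set "lr" itself,
# so the backbone group can train at a smaller rate than the rest of the model.
optimizer_params_dicts = [
    {"params": [p for p in head.parameters() if p.requires_grad]},                  # uses lr=1e-4
    {"params": [p for p in backbone.parameters() if p.requires_grad], "lr": 1e-5},  # lr_backbone
]
optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=1e-4, weight_decay=1e-4)
print([group["lr"] for group in optimizer.param_groups])  # [0.0001, 1e-05]
```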
@@ -364,7 +362,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
-        train_info = update_policy(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(cfg, policy, batch, optimizer, lr_scheduler)
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0: