merge update functions

Quentin Gallouédec
2024-04-25 11:42:20 +02:00
parent 0bca982fca
commit 2a9ea01d5a


@@ -9,7 +9,6 @@ import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 from diffusers.optimization import get_scheduler
-from torch import Tensor
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -25,67 +24,47 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
-def update_diffusion(policy, batch: dict[str, Tensor], optimizer, lr_scheduler) -> dict:
-    """Run the model in train mode, compute the loss, and do an optimization step."""
+def update_policy(policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()
-    # Diffusion
-    policy.diffusion.train()
+    # Act
+    policy.train()
     batch = policy.normalize_inputs(batch)
-    loss = policy.forward(batch)["loss"]
+    output_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = output_dict["loss"]
     loss.backward()
-    # TODO(rcadene): self.unnormalize_outputs(out_dict)
+    # Diffusion
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.diffusion.parameters(),
+        model.parameters(),
         policy.cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )
     optimizer.step()
     optimizer.zero_grad()
-    lr_scheduler.step()
-    if policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if lr_scheduler is not None:
+        lr_scheduler.step()
+    if hasattr(policy, "ema") and policy.ema is not None:
+        policy.ema.step(model)
     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": lr_scheduler.get_last_lr()[0],
+        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
         "update_s": time.time() - start_time,
     }
     return info
-def update_act(policy, batch: dict[str, Tensor], optimizer) -> dict:
-    start_time = time.time()
-    policy.train()
-    batch = policy.normalize_inputs(batch)
-    loss_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = loss_dict["loss"]
-    loss.backward()
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
-    )
-    optimizer.step()
-    optimizer.zero_grad()
-    train_info = {
-        "loss": loss.item(),
-        "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr,
-        "update_s": time.time() - start_time,
-    }
-    return train_info
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
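For context, the merged update step reduces both policies to one generic routine: put the policy in train mode, normalize inputs, forward to get a loss, backpropagate, clip gradients on the right parameter set (the diffusion submodule when present), step the optimizer, optionally step the LR scheduler, and step the EMA copy if the policy defines one. Below is a minimal self-contained sketch of that flow; `DummyPolicy` and the `Cfg` values are hypothetical stand-ins for the real ACT/Diffusion policies, not lerobot code.

```python
# Self-contained sketch of the merged update step. DummyPolicy/Cfg are
# illustrative stand-ins; the update logic mirrors the diff above.
import time
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class Cfg:
    lr: float = 1e-4
    grad_clip_norm: float = 10.0


class DummyPolicy(nn.Module):
    """Stand-in policy: normalize_inputs + forward returning a loss dict."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)
        self.cfg = Cfg()

    def normalize_inputs(self, batch):
        return batch  # real policies rescale observations here

    def forward(self, batch):
        return {"loss": self.net(batch["observation"]).mean()}


def update_policy(policy, batch, optimizer, lr_scheduler=None):
    start_time = time.time()
    policy.train()
    batch = policy.normalize_inputs(batch)
    loss = policy.forward(batch)["loss"]
    loss.backward()

    # Clip on the diffusion submodule if the policy has one, else the whole policy.
    model = policy.diffusion if hasattr(policy, "diffusion") else policy
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
    )
    optimizer.step()
    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()
    if hasattr(policy, "ema") and policy.ema is not None:
        policy.ema.step(model)

    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
        "update_s": time.time() - start_time,
    }


policy = DummyPolicy()
optimizer = torch.optim.AdamW(policy.parameters(), lr=policy.cfg.lr)
print(update_policy(policy, {"observation": torch.randn(8, 4)}, optimizer))
```

Dispatching on `hasattr(policy, "diffusion")` keeps the call site uniform across policies; the in-diff comment ("TODO: hacky, remove this line") flags it as a stopgap until the policies expose a common interface.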
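The `policy.ema.step(model)` call maintains an exponential moving average of the weights after each optimizer step; the averaged copy is typically the one used at evaluation time for the diffusion policy. The following hand-rolled EMA only illustrates the `.step(model)` call pattern and the update rule; it is not lerobot's EMA class.

```python
# Minimal EMA illustration: ema = decay * ema + (1 - decay) * current.
import copy

import torch
from torch import nn


class SimpleEMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        self.averaged_model = copy.deepcopy(model).eval()
        for p in self.averaged_model.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def step(self, model: nn.Module):
        # Blend the live weights into the averaged copy.
        for ema_p, p in zip(self.averaged_model.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1 - self.decay)


net = nn.Linear(4, 1)
ema = SimpleEMA(net)
net.weight.data.add_(1.0)  # pretend an optimizer step changed the weights
ema.step(net)              # the EMA copy drifts slowly toward the new weights
```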
@@ -316,6 +295,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         optimizer = torch.optim.AdamW(
             optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
         )
+        lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
             policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
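The `lr_scheduler = None` line added to the ACT branch is what lets the shared `update_policy` skip scheduler stepping and report `policy.cfg.lr` instead. A small runnable illustration of the two cases, using the same `get_scheduler` helper the script imports from diffusers; the "cosine" schedule and the warmup/step counts here are illustrative, not the script's actual config values.

```python
# Optional LR scheduler: step and read it only when one is provided.
import torch
from diffusers.optimization import get_scheduler

net = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)

# Diffusion-style setup: warmup + cosine decay (illustrative values).
lr_scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=10, num_training_steps=100
)

for step in range(3):
    loss = net(torch.randn(8, 4)).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:  # an ACT-style run would pass None and skip this
        lr_scheduler.step()
    lr = 1e-4 if lr_scheduler is None else lr_scheduler.get_last_lr()[0]
    print(step, lr)
```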
@@ -394,11 +374,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        # Temporary hack to move update outside of policy
-        if isinstance(policy, ActPolicy):
-            train_info = update_act(policy, batch, optimizer)
-        elif isinstance(policy, DiffusionPolicy):
-            train_info = update_diffusion(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(policy, batch, optimizer, lr_scheduler)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
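Downstream, the inner loop body collapses to a single call. A toy version of that loop, reusing the hypothetical `DummyPolicy` and `update_policy` from the sketch above; the device handling mirrors the diff, while `log_freq` and the step count are illustrative.

```python
# Toy training loop: move the batch to the device, take one update, log periodically.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
log_freq = 50

policy = DummyPolicy().to(device)
optimizer = torch.optim.AdamW(policy.parameters(), lr=policy.cfg.lr)

for step in range(200):
    batch = {"observation": torch.randn(8, 4)}
    for key in batch:
        batch[key] = batch[key].to(device, non_blocking=True)

    train_info = update_policy(policy, batch, optimizer, lr_scheduler=None)

    if step % log_freq == 0:
        print(f"step={step} loss={train_info['loss']:.4f} lr={train_info['lr']:.2e}")
```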