Remove update method from the policy (#99)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Author: Quentin Gallouédec
Date: 2024-04-29 12:27:58 +02:00
Committed by: GitHub
Parent: 5b4fd8891d
Commit: 508bd92d03
8 changed files with 84 additions and 122 deletions
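
In short, the optimization step moves out of the policy classes and into a standalone update_policy helper in the training script, which only assumes that policy.forward(batch) returns a dict containing a "loss" tensor. Below is a condensed, self-contained sketch of that pattern; ToyPolicy and the batch contents are illustrative stand-ins, not code from the repository.

import time

import torch
from torch import nn


class ToyPolicy(nn.Module):
    # Illustrative stand-in: any module whose forward() returns {"loss": ...}.
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, batch):
        pred = self.net(batch["observation"])
        return {"loss": nn.functional.mse_loss(pred, batch["action"])}


def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
    # One gradient step, owned by the training script rather than the policy.
    start_time = time.time()
    policy.train()
    loss = policy.forward(batch)["loss"]
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm, error_if_nonfinite=False)
    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()
    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.time() - start_time,
    }


policy = ToyPolicy()
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)
batch = {"observation": torch.randn(8, 4), "action": torch.randn(8, 2)}
print(update_policy(policy, batch, optimizer, grad_clip_norm=10.0))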


@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
@@ -7,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from diffusers.optimization import get_scheduler
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -22,6 +24,37 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy

+def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    start_time = time.time()
+    policy.train()
+    output_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = output_dict["loss"]
+    loss.backward()
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(),
+        grad_clip_norm,
+        error_if_nonfinite=False,
+    )
+    optimizer.step()
+    optimizer.zero_grad()
+
+    if lr_scheduler is not None:
+        lr_scheduler.step()
+
+    if hasattr(policy, "ema") and policy.ema is not None:
+        policy.ema.step(policy.diffusion)
+
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": optimizer.param_groups[0]["lr"],
+        "update_s": time.time() - start_time,
+    }
+
+    return info

 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -234,6 +267,36 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg, dataset_stats=offline_dataset.stats)

+    # Create optimizer and scheduler
+    # Temporary hack to move optimizer out of policy
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
+            {
+                "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
+                "lr": cfg.policy.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay)
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(), cfg.policy.lr, cfg.policy.adam_betas, cfg.policy.adam_eps, cfg.policy.adam_weight_decay
+        )
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
+        # configure lr scheduler
+        lr_scheduler = get_scheduler(
+            cfg.policy.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.policy.lr_warmup_steps,
+            num_training_steps=cfg.offline_steps,
+            # pytorch assumes stepping LRScheduler every epoch
+            # however huggingface diffusers steps it every batch
+            last_epoch=-1,
+        )
+    elif policy.name == "tdmpc":
+        raise NotImplementedError("TD-MPC not implemented yet.")

     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
@@ -292,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        train_info = policy.update(batch, step=step)
+        train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
@@ -358,7 +421,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
             for key in batch:
                 batch[key] = batch[key].to(cfg.device, non_blocking=True)

-            train_info = policy.update(batch, step)
+            train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)

             if step % cfg.log_freq == 0:
                 log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
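
The ACT branch above relies on PyTorch per-parameter-group options to give the vision backbone its own learning rate while the rest of the model uses the optimizer's default. A minimal illustration of that mechanism follows; the model and the learning-rate values are placeholders, not the repository's configuration.

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # stands in for the pretrained vision backbone
        self.head = nn.Linear(8, 2)


model = TinyModel()
param_groups = [
    {"params": [p for n, p in model.named_parameters() if not n.startswith("backbone")]},
    {"params": [p for n, p in model.named_parameters() if n.startswith("backbone")], "lr": 1e-5},
]
optimizer = torch.optim.AdamW(param_groups, lr=1e-4, weight_decay=1e-4)

for group in optimizer.param_groups:
    # The first group falls back to the default lr=1e-4; the backbone group overrides it with 1e-5.
    print(len(group["params"]), group["lr"])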