Remove update method from the policy (#99)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
committed by GitHub
parent 5b4fd8891d
commit 508bd92d03
@@ -1,4 +1,5 @@
 import logging
+import time
 from copy import deepcopy
 from pathlib import Path
 
@@ -7,6 +8,7 @@ import hydra
 import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
+from diffusers.optimization import get_scheduler
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -22,6 +24,37 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy
 
 
+def update_policy(policy, batch, optimizer, grad_clip_norm, lr_scheduler=None):
+    start_time = time.time()
+    policy.train()
+    output_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = output_dict["loss"]
+    loss.backward()
+    grad_norm = torch.nn.utils.clip_grad_norm_(
+        policy.parameters(),
+        grad_clip_norm,
+        error_if_nonfinite=False,
+    )
+
+    optimizer.step()
+    optimizer.zero_grad()
+    if lr_scheduler is not None:
+        lr_scheduler.step()
+
+    if hasattr(policy, "ema") and policy.ema is not None:
+        policy.ema.step(policy.diffusion)
+
+    info = {
+        "loss": loss.item(),
+        "grad_norm": float(grad_norm),
+        "lr": optimizer.param_groups[0]["lr"],
+        "update_s": time.time() - start_time,
+    }
+
+    return info
+
+
 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
@@ -234,6 +267,36 @@ def train(cfg: dict, out_dir=None, job_name=None):
     logging.info("make_policy")
     policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
 
+    # Create optimizer and scheduler
+    # Temporary hack to move optimizer out of policy
+    if cfg.policy.name == "act":
+        optimizer_params_dicts = [
+            {"params": [p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad]},
+            {
+                "params": [p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad],
+                "lr": cfg.policy.lr_backbone,
+            },
+        ]
+        optimizer = torch.optim.AdamW(optimizer_params_dicts, lr=cfg.policy.lr, weight_decay=cfg.policy.weight_decay)
+        lr_scheduler = None
+    elif cfg.policy.name == "diffusion":
+        optimizer = torch.optim.Adam(
+            policy.diffusion.parameters(), cfg.policy.lr, cfg.policy.adam_betas, cfg.policy.adam_eps, cfg.policy.adam_weight_decay
+        )
+        # TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
+        # configure lr scheduler
+        lr_scheduler = get_scheduler(
+            cfg.policy.lr_scheduler,
+            optimizer=optimizer,
+            num_warmup_steps=cfg.policy.lr_warmup_steps,
+            num_training_steps=cfg.offline_steps,
+            # pytorch assumes stepping LRScheduler every epoch
+            # however huggingface diffusers steps it every batch
+            last_epoch=-1,
+        )
+    elif policy.name == "tdmpc":
+        raise NotImplementedError("TD-MPC not implemented yet.")
+
     num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
     num_total_params = sum(p.numel() for p in policy.parameters())
 
@@ -292,7 +355,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step=step)
+        train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
 
         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
@@ -358,7 +421,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)
 
-        train_info = policy.update(batch, step)
+        train_info = update_policy(policy, batch, optimizer, cfg.policy.grad_clip_norm, lr_scheduler)
 
         if step % cfg.log_freq == 0:
             log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
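
For context on how the refactor is used: after this change the training script owns the optimizer and calls update_policy once per gradient step, instead of asking the policy to update itself. Below is a minimal, self-contained sketch of that calling pattern. The toy policy, batch contents, and hyperparameters are illustrative assumptions (not part of this commit), and update_policy is assumed to be in scope as defined in the diff above.

import time

import torch

# Toy stand-in for a policy: forward() must return a dict with a "loss" entry,
# which is the only contract update_policy relies on.
class ToyPolicy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(4, 2)

    def forward(self, batch):
        pred = self.net(batch["observation"])
        return {"loss": torch.nn.functional.mse_loss(pred, batch["action"])}


policy = ToyPolicy()
# The optimizer now lives in the training script instead of inside the policy.
optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4)

for step in range(5):
    # Hypothetical batch; a real LeRobot batch would come from the dataloader.
    batch = {"observation": torch.randn(8, 4), "action": torch.randn(8, 2)}
    train_info = update_policy(policy, batch, optimizer, grad_clip_norm=10.0)
    print(step, train_info["loss"], train_info["grad_norm"])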