forked from tangger/lerobot
merge update functions
@@ -9,7 +9,6 @@ import torch
 from datasets import concatenate_datasets
 from datasets.utils import disable_progress_bars, enable_progress_bars
 from diffusers.optimization import get_scheduler
-from torch import Tensor

 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
@@ -25,67 +24,47 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.eval import eval_policy


-def update_diffusion(policy, batch: dict[str, Tensor], optimizer, lr_scheduler) -> dict:
-    """Run the model in train mode, compute the loss, and do an optimization step."""
+def update_policy(policy, batch, optimizer, lr_scheduler=None):
     start_time = time.time()

+    # Diffusion
     policy.diffusion.train()
+    # Act
+    policy.train()

     batch = policy.normalize_inputs(batch)
-    loss = policy.forward(batch)["loss"]
+    output_dict = policy.forward(batch)
+    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
+    loss = output_dict["loss"]
     loss.backward()

-    # TODO(rcadene): self.unnormalize_outputs(out_dict)
+    # Diffusion
+    model = policy.diffusion if hasattr(policy, "diffusion") else policy  # TODO: hacky, remove this line
     grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.diffusion.parameters(),
+        model.parameters(),
         policy.cfg.grad_clip_norm,
         error_if_nonfinite=False,
     )

     optimizer.step()
     optimizer.zero_grad()
-    lr_scheduler.step()
+    if lr_scheduler is not None:
+        lr_scheduler.step()

-    if policy.ema is not None:
-        policy.ema.step(policy.diffusion)
+    if hasattr(policy, "ema") and policy.ema is not None:
+        policy.ema.step(model)

     info = {
         "loss": loss.item(),
         "grad_norm": float(grad_norm),
-        "lr": lr_scheduler.get_last_lr()[0],
+        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
         "update_s": time.time() - start_time,
     }

     return info


-
-def update_act(policy, batch: dict[str, Tensor], optimizer) -> dict:
-    start_time = time.time()
-    policy.train()
-    batch = policy.normalize_inputs(batch)
-    loss_dict = policy.forward(batch)
-    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    loss = loss_dict["loss"]
-    loss.backward()
-
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
-    )
-
-    optimizer.step()
-    optimizer.zero_grad()
-
-    train_info = {
-        "loss": loss.item(),
-        "grad_norm": float(grad_norm),
-        "lr": policy.cfg.lr,
-        "update_s": time.time() - start_time,
-    }
-
-    return train_info
-
-
+
 @hydra.main(version_base=None, config_name="default", config_path="../configs")
 def train_cli(cfg: dict):
     train(
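
Read as plain code, the merged update step introduced by this hunk amounts to the sketch below. It is assembled from the lines above with two liberties taken for clarity: `policy.train()` already puts any diffusion submodule in train mode, so the separate `policy.diffusion.train()` call is folded into it, and `getattr(policy, "diffusion", policy)` stands in for the `hasattr` one-liner; indentation and the `time`/`torch` imports are assumed.

```python
import time

import torch


def update_policy(policy, batch, optimizer, lr_scheduler=None):
    """One optimization step shared by ACT and diffusion policies (sketch)."""
    start_time = time.time()
    policy.train()  # train mode propagates to submodules, including policy.diffusion if present

    batch = policy.normalize_inputs(batch)
    output_dict = policy.forward(batch)  # expected to return a dict containing a "loss" tensor
    # TODO(rcadene): policy.unnormalize_outputs(out_dict)
    loss = output_dict["loss"]
    loss.backward()

    # Diffusion policies clip gradients on the inner diffusion module; ACT clips the whole policy.
    model = getattr(policy, "diffusion", policy)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
    )

    optimizer.step()
    optimizer.zero_grad()
    if lr_scheduler is not None:  # only the diffusion branch passes a scheduler
        lr_scheduler.step()

    # An exponential moving average of the weights exists only for the diffusion policy.
    if getattr(policy, "ema", None) is not None:
        policy.ema.step(model)

    return {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": policy.cfg.lr if lr_scheduler is None else lr_scheduler.get_last_lr()[0],
        "update_s": time.time() - start_time,
    }
```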
@@ -316,6 +295,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         optimizer = torch.optim.AdamW(
             optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
         )
+        lr_scheduler = None
     elif isinstance(policy, DiffusionPolicy):
         optimizer = torch.optim.Adam(
             policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
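
The single added line matters because `update_policy` now takes an `lr_scheduler` argument for every policy type: the ACT branch has no scheduler, so it must explicitly bind `lr_scheduler = None` before the shared call site runs. A sketch of the whole selection follows, wrapped in a hypothetical `make_optimizer_and_scheduler` helper. The `if isinstance(policy, ActPolicy)` framing is inferred from the `elif` shown, and the `get_scheduler` call, the `cfg.lr_scheduler` / `cfg.lr_warmup_steps` / `cfg.offline_steps` keys, and the use of `policy.parameters()` instead of the per-group `optimizer_params_dicts` are assumptions, not lines from this commit.

```python
import torch
from diffusers.optimization import get_scheduler


def make_optimizer_and_scheduler(cfg, policy):
    """Hypothetical helper: every branch must define lr_scheduler, even if it is None."""
    # ActPolicy / DiffusionPolicy are the policy classes already imported by the surrounding script.
    if isinstance(policy, ActPolicy):
        optimizer = torch.optim.AdamW(
            policy.parameters(), lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
        )
        lr_scheduler = None  # ACT trains at a fixed learning rate
    elif isinstance(policy, DiffusionPolicy):
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
        )
        lr_scheduler = get_scheduler(  # stepped inside update_policy when not None
            cfg.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.lr_warmup_steps,
            num_training_steps=cfg.offline_steps,
        )
    else:
        raise NotImplementedError(f"No optimizer defined for {type(policy).__name__}")
    return optimizer, lr_scheduler
```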
@@ -394,11 +374,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
         for key in batch:
             batch[key] = batch[key].to(cfg.device, non_blocking=True)

-        # Temporary hack to move update outside of policy
-        if isinstance(policy, ActPolicy):
-            train_info = update_act(policy, batch, optimizer)
-        elif isinstance(policy, DiffusionPolicy):
-            train_info = update_diffusion(policy, batch, optimizer, lr_scheduler)
+        train_info = update_policy(policy, batch, optimizer, lr_scheduler)

         # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
         if step % cfg.log_freq == 0:
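
With the per-policy dispatch gone, one offline-training iteration reduces to moving the batch to the device, calling `update_policy`, and logging. A minimal sketch of that loop body is below; `dl_iter`, `cfg.offline_steps`, and logging via `print` are stand-ins for the surrounding `train()` wiring, not code from this commit.

```python
# Sketch of the inner offline-training loop after this change.
dl_iter = cycle(dataloader)  # `cycle` is imported from lerobot.common.datasets.utils above
for step in range(cfg.offline_steps):  # cfg.offline_steps is an assumed config key
    batch = next(dl_iter)
    for key in batch:
        batch[key] = batch[key].to(cfg.device, non_blocking=True)

    # One call now covers both policy types; lr_scheduler is simply None for ACT.
    train_info = update_policy(policy, batch, optimizer, lr_scheduler)

    if step % cfg.log_freq == 0:
        print(train_info)  # the real script logs through its own logger
```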