forked from tangger/lerobot
move optimizer and scheduler outside policies
This commit is contained in:
@@ -127,25 +127,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
|||||||
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
|
self.action_head = nn.Linear(cfg.d_model, cfg.action_dim)
|
||||||
|
|
||||||
self._reset_parameters()
|
self._reset_parameters()
|
||||||
self._create_optimizer()
|
|
||||||
|
|
||||||
def _create_optimizer(self):
|
|
||||||
optimizer_params_dicts = [
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"params": [
|
|
||||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
|
||||||
],
|
|
||||||
"lr": self.cfg.lr_backbone,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
self.optimizer = torch.optim.AdamW(
|
|
||||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reset_parameters(self):
|
def _reset_parameters(self):
|
||||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
import torchvision
|
import torchvision
|
||||||
from diffusers.optimization import get_scheduler
|
|
||||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
from robomimic.models.base_nets import SpatialSoftmax
|
from robomimic.models.base_nets import SpatialSoftmax
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
@@ -70,25 +69,7 @@ class DiffusionPolicy(nn.Module):
|
|||||||
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
self.ema_diffusion = copy.deepcopy(self.diffusion)
|
||||||
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
self.ema = _EMA(cfg, model=self.ema_diffusion)
|
||||||
|
|
||||||
# TODO(alexander-soare): Move optimizer out of policy.
|
|
||||||
self.optimizer = torch.optim.Adam(
|
|
||||||
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(alexander-soare): Move LR scheduler out of policy.
|
|
||||||
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
|
||||||
self.global_step = 0
|
|
||||||
|
|
||||||
# configure lr scheduler
|
|
||||||
self.lr_scheduler = get_scheduler(
|
|
||||||
cfg.lr_scheduler,
|
|
||||||
optimizer=self.optimizer,
|
|
||||||
num_warmup_steps=cfg.lr_warmup_steps,
|
|
||||||
num_training_steps=lr_scheduler_num_training_steps,
|
|
||||||
# pytorch assumes stepping LRScheduler every epoch
|
|
||||||
# however huggingface diffusers steps it every batch
|
|
||||||
last_epoch=self.global_step - 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import hydra
|
|||||||
import torch
|
import torch
|
||||||
from datasets import concatenate_datasets
|
from datasets import concatenate_datasets
|
||||||
from datasets.utils import disable_progress_bars, enable_progress_bars
|
from datasets.utils import disable_progress_bars, enable_progress_bars
|
||||||
|
from diffusers.optimization import get_scheduler
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.common.datasets.factory import make_dataset
|
from lerobot.common.datasets.factory import make_dataset
|
||||||
@@ -24,7 +25,7 @@ from lerobot.common.utils.utils import (
|
|||||||
from lerobot.scripts.eval import eval_policy
|
from lerobot.scripts.eval import eval_policy
|
||||||
|
|
||||||
|
|
||||||
def update_diffusion(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
def update_diffusion(policy, batch: dict[str, Tensor], optimizer, lr_scheduler) -> dict:
|
||||||
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
"""Run the model in train mode, compute the loss, and do an optimization step."""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
policy.diffusion.train()
|
policy.diffusion.train()
|
||||||
@@ -40,9 +41,9 @@ def update_diffusion(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
|||||||
error_if_nonfinite=False,
|
error_if_nonfinite=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
policy.optimizer.step()
|
optimizer.step()
|
||||||
policy.optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
policy.lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
|
||||||
if policy.ema is not None:
|
if policy.ema is not None:
|
||||||
policy.ema.step(policy.diffusion)
|
policy.ema.step(policy.diffusion)
|
||||||
@@ -50,7 +51,7 @@ def update_diffusion(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
|||||||
info = {
|
info = {
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
"grad_norm": float(grad_norm),
|
"grad_norm": float(grad_norm),
|
||||||
"lr": policy.lr_scheduler.get_last_lr()[0],
|
"lr": lr_scheduler.get_last_lr()[0],
|
||||||
"update_s": time.time() - start_time,
|
"update_s": time.time() - start_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ def update_diffusion(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
def update_act(policy, batch: dict[str, Tensor], optimizer) -> dict:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
policy.train()
|
policy.train()
|
||||||
batch = policy.normalize_inputs(batch)
|
batch = policy.normalize_inputs(batch)
|
||||||
@@ -71,8 +72,8 @@ def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
|||||||
policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
|
policy.parameters(), policy.cfg.grad_clip_norm, error_if_nonfinite=False
|
||||||
)
|
)
|
||||||
|
|
||||||
policy.optimizer.step()
|
optimizer.step()
|
||||||
policy.optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
train_info = {
|
train_info = {
|
||||||
"loss": loss.item(),
|
"loss": loss.item(),
|
||||||
@@ -83,6 +84,8 @@ def update_act(self, policy, batch: dict[str, Tensor], **_) -> dict:
|
|||||||
|
|
||||||
return train_info
|
return train_info
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
@hydra.main(version_base=None, config_name="default", config_path="../configs")
|
||||||
def train_cli(cfg: dict):
|
def train_cli(cfg: dict):
|
||||||
train(
|
train(
|
||||||
@@ -295,6 +298,43 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
logging.info("make_policy")
|
logging.info("make_policy")
|
||||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||||
|
|
||||||
|
# Temporary hack to move optimizer out of policy
|
||||||
|
if isinstance(policy, ActPolicy):
|
||||||
|
optimizer_params_dicts = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p for n, p in policy.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p for n, p in policy.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||||
|
],
|
||||||
|
"lr": policy.cfg.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
optimizer_params_dicts, lr=policy.cfg.lr, weight_decay=policy.cfg.weight_decay
|
||||||
|
)
|
||||||
|
elif isinstance(policy, DiffusionPolicy):
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
policy.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
|
||||||
|
)
|
||||||
|
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
|
||||||
|
global_step = 0
|
||||||
|
# configure lr scheduler
|
||||||
|
lr_scheduler = get_scheduler(
|
||||||
|
cfg.lr_scheduler,
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=cfg.lr_warmup_steps,
|
||||||
|
num_training_steps=cfg.offline_steps,
|
||||||
|
# pytorch assumes stepping LRScheduler every epoch
|
||||||
|
# however huggingface diffusers steps it every batch
|
||||||
|
last_epoch=global_step - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
|
|
||||||
@@ -355,11 +395,10 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
|||||||
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
batch[key] = batch[key].to(cfg.device, non_blocking=True)
|
||||||
|
|
||||||
# Temporary hack to move update outside of policy
|
# Temporary hack to move update outside of policy
|
||||||
if isinstance(policy, DiffusionPolicy):
|
if isinstance(policy, ActPolicy):
|
||||||
train_info = update_diffusion(policy, batch)
|
train_info = update_act(policy, batch, optimizer)
|
||||||
elif isinstance(policy, ActPolicy):
|
elif isinstance(policy, DiffusionPolicy):
|
||||||
train_info = update_act(policy, batch)
|
train_info = update_diffusion(policy, batch, optimizer, lr_scheduler)
|
||||||
|
|
||||||
|
|
||||||
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
|
||||||
if step % cfg.log_freq == 0:
|
if step % cfg.log_freq == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user