Remove update method from the policy (#99)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Quentin Gallouédec
2024-04-29 12:27:58 +02:00
committed by GitHub
parent 5b4fd8891d
commit 508bd92d03
8 changed files with 84 additions and 122 deletions

View File

@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
"""
import math
import time
from collections import deque
from itertools import chain
from typing import Callable
@@ -135,25 +134,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
self._reset_parameters()
self._create_optimizer()
def _create_optimizer(self):
optimizer_params_dicts = [
{
"params": [
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
]
},
{
"params": [
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
],
"lr": self.cfg.lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
)
def _reset_parameters(self):
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
@@ -191,6 +171,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
def forward(self, batch, **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
l1_loss = (
@@ -213,34 +195,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
return loss_dict
def update(self, batch, **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.train()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss_dict = self.forward(batch)
# TODO(rcadene): self.unnormalize_outputs(out_dict)
loss = loss_dict["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
)
self.optimizer.step()
self.optimizer.zero_grad()
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.cfg.lr,
"update_s": time.time() - start_time,
}
return info
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Stacks all the images in a batch and puts them in a new key: "observation.images".

View File

@@ -11,7 +11,6 @@ TODO(alexander-soare):
import copy
import logging
import math
import time
from collections import deque
from typing import Callable
@@ -19,7 +18,6 @@ import einops
import torch
import torch.nn.functional as F # noqa: N812
import torchvision
from diffusers.optimization import get_scheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from robomimic.models.base_nets import SpatialSoftmax
from torch import Tensor, nn
@@ -74,26 +72,6 @@ class DiffusionPolicy(nn.Module):
self.ema_diffusion = copy.deepcopy(self.diffusion)
self.ema = _EMA(cfg, model=self.ema_diffusion)
# TODO(alexander-soare): Move optimizer out of policy.
self.optimizer = torch.optim.Adam(
self.diffusion.parameters(), cfg.lr, cfg.adam_betas, cfg.adam_eps, cfg.adam_weight_decay
)
# TODO(alexander-soare): Move LR scheduler out of policy.
# TODO(rcadene): modify lr scheduler so that it doesn't depend on epochs but steps
self.global_step = 0
# configure lr scheduler
self.lr_scheduler = get_scheduler(
cfg.lr_scheduler,
optimizer=self.optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=lr_scheduler_num_training_steps,
# pytorch assumes stepping LRScheduler every epoch
# however huggingface diffusers steps it every batch
last_epoch=self.global_step - 1,
)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
@@ -155,44 +133,10 @@ class DiffusionPolicy(nn.Module):
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def update(self, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss = self.forward(batch)["loss"]
loss.backward()
# TODO(rcadene): self.unnormalize_outputs(out_dict)
grad_norm = torch.nn.utils.clip_grad_norm_(
self.diffusion.parameters(),
self.cfg.grad_clip_norm,
error_if_nonfinite=False,
)
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
if self.ema is not None:
self.ema.step(self.diffusion)
info = {
"loss": loss.item(),
"grad_norm": float(grad_norm),
"lr": self.lr_scheduler.get_last_lr()[0],
"update_s": time.time() - start_time,
}
return info
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def save(self, fp):
torch.save(self.state_dict(), fp)

View File

@@ -36,10 +36,3 @@ class Policy(Protocol):
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
def update(self, batch):
"""Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
disappear.
"""