Remove update method from the policy (#99)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
committed by
GitHub
parent
5b4fd8891d
commit
508bd92d03
@@ -5,7 +5,6 @@ The majority of changes here involve removing unused code, unifying naming, and
|
||||
"""
|
||||
|
||||
import math
|
||||
import time
|
||||
from collections import deque
|
||||
from itertools import chain
|
||||
from typing import Callable
|
||||
@@ -135,25 +134,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
self.action_head = nn.Linear(cfg.d_model, cfg.output_shapes["action"][0])
|
||||
|
||||
self._reset_parameters()
|
||||
self._create_optimizer()
|
||||
|
||||
def _create_optimizer(self):
|
||||
optimizer_params_dicts = [
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if not n.startswith("backbone") and p.requires_grad
|
||||
]
|
||||
},
|
||||
{
|
||||
"params": [
|
||||
p for n, p in self.named_parameters() if n.startswith("backbone") and p.requires_grad
|
||||
],
|
||||
"lr": self.cfg.lr_backbone,
|
||||
},
|
||||
]
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
optimizer_params_dicts, lr=self.cfg.lr, weight_decay=self.cfg.weight_decay
|
||||
)
|
||||
|
||||
def _reset_parameters(self):
|
||||
"""Xavier-uniform initialization of the transformer parameters as in the original code."""
|
||||
@@ -191,6 +171,8 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
def forward(self, batch, **_) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self._forward(batch)
|
||||
|
||||
l1_loss = (
|
||||
@@ -213,34 +195,6 @@ class ActionChunkingTransformerPolicy(nn.Module):
|
||||
|
||||
return loss_dict
|
||||
|
||||
def update(self, batch, **_) -> dict:
    """Run the model in train mode, compute the loss, and do an optimization step.

    Args:
        batch: Training batch. It is passed through ``self.normalize_inputs``
            and ``self.normalize_targets`` and then handed to ``self.forward``.
        **_: Extra keyword arguments are accepted and ignored.

    Returns:
        A dict with the following entries:
            "loss": the scalar loss as a Python float.
            "grad_norm": the gradient norm returned by ``clip_grad_norm_``.
            "lr": the configured learning rate ``self.cfg.lr``.
            "update_s": seconds spent inside this call.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    self.train()

    # NOTE(review): forward() as shown elsewhere in this view also normalizes
    # inputs and targets; confirm the batch is not normalized twice.
    batch = self.normalize_inputs(batch)
    batch = self.normalize_targets(batch)

    loss_dict = self.forward(batch)
    # TODO(rcadene): self.unnormalize_outputs(out_dict)
    loss = loss_dict["loss"]
    loss.backward()

    # Clip in place; non-finite gradient norms are tolerated rather than raised.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False
    )

    self.optimizer.step()
    self.optimizer.zero_grad()

    info = {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": self.cfg.lr,
        "update_s": time.perf_counter() - start_time,
    }

    return info
|
||||
|
||||
def _stack_images(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Stacks all the images in a batch and puts them in a new key: "observation.images".
|
||||
|
||||
Reference in New Issue
Block a user