This commit is contained in:
Alexander Soare
2024-04-17 16:21:37 +01:00
parent 63e5ec6483
commit 2298ddf226
3 changed files with 26 additions and 22 deletions

View File

@@ -43,12 +43,12 @@ class DiffusionPolicy(nn.Module):
name = "diffusion"
def __init__(self, cfg: DiffusionConfig | None, lr_scheduler_num_training_steps: int = 0):
super().__init__()
"""
Args:
cfg: Policy configuration class instance or None, in which case the default instantiation of the
configuration class is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
assert lr_scheduler_num_training_steps > 0
if cfg is None:
@@ -140,12 +140,18 @@ class DiffusionPolicy(nn.Module):
action = self._queues["action"].popleft()
return action
def forward(self, batch, **_):
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
def update(self, batch: dict[str, Tensor], **_) -> dict:
"""Run the model in train mode, compute the loss, and do an optimization step."""
start_time = time.time()
self.diffusion.train()
loss = self.diffusion.compute_loss(batch)
loss = self.forward(batch)["loss"]
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(