Make sure targets are normalized too (#106)

This commit is contained in:
Alexander Soare
2024-04-26 11:18:39 +01:00
committed by GitHub
parent b980c5dd9e
commit 45f351c618
8 changed files with 116 additions and 92 deletions

View File

@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
if cfg is None:
cfg = DiffusionConfig()
self.cfg = cfg
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
self.unnormalize_outputs = Unnormalize(
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
)
# queues are populated during rollout of the policy, they contain the n latest observations and actions
self._queues = None
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
self.diffusion.train()
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)
loss = self.forward(batch)["loss"]
loss.backward()