forked from tangger/lerobot
Make sure targets are normalized too (#106)
This commit is contained in:
@@ -56,8 +56,11 @@ class DiffusionPolicy(nn.Module):
|
||||
if cfg is None:
|
||||
cfg = DiffusionConfig()
|
||||
self.cfg = cfg
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.normalize_input_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(cfg.output_shapes, cfg.unnormalize_output_modes, dataset_stats)
|
||||
self.normalize_inputs = Normalize(cfg.input_shapes, cfg.input_normalization_modes, dataset_stats)
|
||||
self.normalize_targets = Normalize(cfg.output_shapes, cfg.output_normalization_modes, dataset_stats)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
cfg.output_shapes, cfg.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
# queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
self._queues = None
|
||||
@@ -162,6 +165,7 @@ class DiffusionPolicy(nn.Module):
|
||||
self.diffusion.train()
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
loss = self.forward(batch)["loss"]
|
||||
loss.backward()
|
||||
|
||||
Reference in New Issue
Block a user