Move normalization to policy for act and diffusion (#90)
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
 # If you're doing something different, you will likely need to change at least some of the defaults.
 cfg = DiffusionConfig()
 # TODO(alexander-soare): Remove LR scheduler from the policy.
-policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
+policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
 policy.train()
 policy.to(device)

@@ -62,7 +62,6 @@ while not done:
             done = True
             break

-# Save the policy, configuration, and normalization stats for later use.
+# Save the policy and configuration for later use.
 policy.save(output_directory / "model.pt")
 OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
-torch.save(dataset.transform.transforms[-1].stats, output_directory / "stats.pth")
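The diff above moves ownership of the normalization statistics from a dataset transform into the policy itself: the stats are passed at construction via dataset_stats, travel with the model weights, and no separate stats.pth needs to be written. A minimal sketch of that pattern, assuming the stats are a nested dict of tensors; NormalizedPolicy and the stats keys here are hypothetical illustrations, not the actual lerobot implementation:

# Illustrative sketch only: a policy that owns its normalization stats instead of
# relying on a dataset transform. NormalizedPolicy and the stats keys are
# hypothetical; the real DiffusionPolicy/ACT policies handle this internally.
import torch
from torch import nn


class NormalizedPolicy(nn.Module):
    def __init__(self, dataset_stats: dict):
        super().__init__()
        # Buffers are saved/loaded with the state dict, which is what removes the
        # need to write a separate stats.pth alongside model.pt.
        self.register_buffer("obs_mean", dataset_stats["observation"]["mean"])
        self.register_buffer("obs_std", dataset_stats["observation"]["std"])
        self.net = nn.Linear(self.obs_mean.numel(), 2)  # stand-in for the real network

    def forward(self, observation: torch.Tensor) -> torch.Tensor:
        # Normalize inside the policy rather than in the dataloader pipeline.
        normalized = (observation - self.obs_mean) / (self.obs_std + 1e-8)
        return self.net(normalized)


# Example construction from precomputed dataset statistics.
stats = {"observation": {"mean": torch.zeros(10), "std": torch.ones(10)}}
policy = NormalizedPolicy(stats)
action = policy(torch.randn(1, 10))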