make unit tests pass

This commit is contained in:
Cadene
2024-04-23 21:39:39 +00:00
parent 42ed7bb670
commit 0660f71556
13 changed files with 79 additions and 38 deletions

View File

@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
# If you're doing something different, you will likely need to change at least some of the defaults.
cfg = DiffusionConfig()
# TODO(alexander-soare): Remove LR scheduler from the policy.
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps)
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
policy.train()
policy.to(device)