forked from tangger/lerobot
Use PytorchModelHubMixin to save models as safetensors (#125)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -34,7 +34,7 @@ dataset = make_dataset(hydra_cfg)
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
# TODO(alexander-soare): Remove LR scheduler from the policy.
|
||||
policy = DiffusionPolicy(cfg, lr_scheduler_num_training_steps=training_steps, dataset_stats=dataset.stats)
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
@@ -69,6 +69,7 @@ while not done:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy and configuration for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
# Save the policy.
|
||||
policy.save_pretrained(output_directory)
|
||||
# Save the Hydra configuration so we have the environment configuration for eval.
|
||||
OmegaConf.save(hydra_cfg, output_directory / "config.yaml")
|
||||
|
||||
Reference in New Issue
Block a user