Use PyTorchModelHubMixin to save models as safetensors (#125)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:17:18 +01:00
committed by GitHub
parent 01d5490d44
commit a4891095e4
18 changed files with 556 additions and 527 deletions

View File

@@ -265,7 +265,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
logging.info("make_policy")
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
@@ -340,7 +340,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
if cfg.training.save_model and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
logger.save_model(policy, identifier=step)
# Note: Save with the step as the identifier, zero-padded to at least 6 digits, or more if the total
# number of offline + online steps needs them (6 is the minimum, for consistency without being overkill).
logger.save_model(
policy,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
)
logging.info("Resume training")
# create dataloader for offline training