forked from tangger/lerobot
Use PytorchModelHubMixin to save models as safetensors (#125)
Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
@@ -265,7 +265,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
env = make_env(cfg, num_parallel_envs=cfg.eval.n_episodes)
|
||||
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(cfg, dataset_stats=offline_dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
|
||||
|
||||
# Create optimizer and scheduler
|
||||
# Temporary hack to move optimizer out of policy
|
||||
@@ -340,7 +340,14 @@ def train(cfg: dict, out_dir=None, job_name=None):
|
||||
|
||||
if cfg.training.save_model and step % cfg.training.save_freq == 0:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
logger.save_model(policy, identifier=step)
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
logger.save_model(
|
||||
policy,
|
||||
identifier=str(step).zfill(
|
||||
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
|
||||
),
|
||||
)
|
||||
logging.info("Resume training")
|
||||
|
||||
# create dataloader for offline training
|
||||
|
||||
Reference in New Issue
Block a user