Release cleanup (#132)

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Cadene <re.cadene@gmail.com>
This commit is contained in:
Simon Alibert
2024-05-06 03:03:14 +02:00
committed by GitHub
parent 6eaffbef1d
commit f5e76393eb
19 changed files with 312 additions and 237 deletions

View File

@@ -8,7 +8,6 @@ import hydra
import torch
from datasets import concatenate_datasets
from datasets.utils import disable_progress_bars, enable_progress_bars
from diffusers.optimization import get_scheduler
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
@@ -55,6 +54,8 @@ def make_optimizer_and_scheduler(cfg, policy):
cfg.training.adam_weight_decay,
)
assert cfg.training.online_steps == 0, "Diffusion Policy does not handle online training."
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
@@ -336,7 +337,7 @@ def train(cfg: dict, out_dir=None, job_name=None):
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def _maybe_eval_and_maybe_save(step):
def evaluate_and_checkpoint_if_needed(step):
if step % cfg.training.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
eval_info = eval_policy(
@@ -392,9 +393,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass in
# step + 1.
_maybe_eval_and_maybe_save(step + 1)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
@@ -460,9 +461,9 @@ def train(cfg: dict, out_dir=None, job_name=None):
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_offline)
# Note: _maybe_eval_and_maybe_save happens **after** the `step`th training update has completed, so we pass
# in step + 1.
_maybe_eval_and_maybe_save(step + 1)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
online_step += 1