Add resume training (#205)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare
2024-05-28 12:04:23 +01:00
committed by GitHub
parent 7ec76ee235
commit e3b9f1c19b
15 changed files with 486 additions and 191 deletions

View File

@@ -18,13 +18,16 @@ import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from pprint import pformat
import hydra
import torch
from omegaconf import DictConfig
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch.cuda.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
@@ -34,6 +37,7 @@ from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
@@ -140,24 +144,6 @@ def update_policy(
return info
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
loss = info["loss"]
grad_norm = info["grad_norm"]
@@ -237,15 +223,60 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
init_logging()
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# to check for any differences between the provided config and the checkpoint's config.
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists."
)
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
if cfg.training.online_steps > 0:
raise NotImplementedError("Online training is not implemented yet.")
set_global_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
logging.info("make_dataset")
offline_dataset = make_dataset(cfg)
@@ -254,19 +285,25 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
eval_env = make_env(cfg)
logging.info("make_policy")
policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step = logger.load_last_training_state(optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
# log metrics to terminal and wandb
logger = Logger(out_dir, job_name, cfg)
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
@@ -294,12 +331,15 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.training.save_model and step % cfg.training.save_freq == 0:
if cfg.training.save_checkpoint and step % cfg.training.save_freq == 0:
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_model(
logger.save_checkpont(
step,
policy,
optimizer,
lr_scheduler,
identifier=str(step).zfill(
max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
),
@@ -319,7 +359,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
policy.train()
is_offline = True
for step in range(cfg.training.offline_steps):
for _ in range(step, cfg.training.offline_steps):
if step == 0:
logging.info("Start offline training on a fixed dataset")
batch = next(dl_iter)
@@ -337,7 +377,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
use_amp=cfg.use_amp,
)
# TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
if step % cfg.training.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)
@@ -345,6 +384,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# so we pass in step + 1.
evaluate_and_checkpoint_if_needed(step + 1)
step += 1
# create an empty online dataset similar to offline dataset
online_dataset = deepcopy(offline_dataset)
online_dataset.hf_dataset = {}
@@ -369,5 +410,23 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__":
train_cli()