#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Train a policy on an offline dataset, periodically evaluating it in an environment and checkpointing it.

Online training is not implemented yet: `train` raises if `cfg.training.online_steps > 0`.
"""

import logging
import time
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig
from torch.cuda.amp import GradScaler

from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy


def make_optimizer_and_scheduler(cfg, policy):
    """Build the optimizer and optional learning-rate scheduler for `policy`.

    Dispatches on `cfg.policy.name`:
    - "act": AdamW; parameters whose name starts with "backbone" use `cfg.training.lr_backbone`,
      all other trainable parameters use `cfg.training.lr`. No scheduler.
    - "diffusion": Adam over `policy.diffusion` parameters plus a `diffusers` scheduler selected by
      `cfg.training.lr_scheduler`.
    - "tdmpc": Adam over all policy parameters. No scheduler.

    Returns:
        `(optimizer, lr_scheduler)`; `lr_scheduler` is None when the policy uses no scheduler.

    Raises:
        NotImplementedError: if `cfg.policy.name` is not one of the names above.
    """
    policy_name = cfg.policy.name
    if policy_name == "act":
        optimizer_params_dicts = [
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if not n.startswith("backbone") and p.requires_grad
                ]
            },
            {
                "params": [
                    p
                    for n, p in policy.named_parameters()
                    if n.startswith("backbone") and p.requires_grad
                ],
                "lr": cfg.training.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(
            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
        )
        lr_scheduler = None
    elif policy_name == "diffusion":
        optimizer = torch.optim.Adam(
            policy.diffusion.parameters(),
            cfg.training.lr,
            cfg.training.adam_betas,
            cfg.training.adam_eps,
            cfg.training.adam_weight_decay,
        )
        # Imported locally, presumably so `diffusers` is only needed when training this policy.
        from diffusers.optimization import get_scheduler

        lr_scheduler = get_scheduler(
            cfg.training.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=cfg.training.lr_warmup_steps,
            num_training_steps=cfg.training.offline_steps,
        )
    elif policy_name == "tdmpc":
        # Dispatch on the config (like the branches above) rather than on an attribute of the policy
        # object, so unknown policies always reach the NotImplementedError below.
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None
    else:
        raise NotImplementedError(f"No optimizer/scheduler is defined for policy '{policy_name}'.")

    return optimizer, lr_scheduler


def update_policy(
    policy,
    batch,
    optimizer,
    grad_clip_norm,
    grad_scaler: GradScaler,
    lr_scheduler=None,
    use_amp: bool = False,
):
    """Run one gradient update of `policy` on `batch`.

    Runs the forward pass (under `torch.autocast` when `use_amp`), a scaled backward pass, gradient
    clipping to `grad_clip_norm`, an optimizer step through `grad_scaler`, an optional scheduler step,
    and finally `policy.update()` for policies implementing `PolicyWithUpdate`.

    Returns:
        A dictionary of items for logging: "loss", "grad_norm", "lr" (of the first param group),
        "update_s" (wall-clock seconds spent in this call), plus every non-"loss" entry returned by
        `policy.forward`.
    """
    start_time = time.perf_counter()
    device = get_device_from_parameters(policy)
    policy.train()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        output_dict = policy.forward(batch)
        # TODO(rcadene): policy.unnormalize_outputs(out_dict)
        loss = output_dict["loss"]
    grad_scaler.scale(loss).backward()

    # Unscale the gradients of the optimizer's assigned params in-place **prior to gradient clipping**.
    grad_scaler.unscale_(optimizer)

    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
    grad_scaler.step(optimizer)
    # Updates the scale for next iteration.
    grad_scaler.update()

    optimizer.zero_grad()

    if lr_scheduler is not None:
        lr_scheduler.step()

    if isinstance(policy, PolicyWithUpdate):
        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
        policy.update()

    info = {
        "loss": loss.item(),
        "grad_norm": float(grad_norm),
        "lr": optimizer.param_groups[0]["lr"],
        "update_s": time.perf_counter() - start_time,
        **{k: v for k, v in output_dict.items() if k != "loss"},
    }

    return info


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: DictConfig):
    """Command-line entry point: train using Hydra's run directory and job name."""
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    train(cfg, out_dir=hydra_cfg.run.dir, job_name=hydra_cfg.job.name)


def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    """Notebook entry point: compose the Hydra config `config_name` from `config_path` and train with it.

    Any existing global Hydra state is cleared first so this can be called repeatedly. `out_dir` and
    `job_name` must be provided; `train` raises NotImplementedError otherwise.
    """
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)


def _samples_to_episodes_and_epochs(num_samples, dataset):
    """Express `num_samples` training samples as a number of dataset episodes and dataset epochs."""
    avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
    num_episodes = num_samples / avg_samples_per_ep
    num_epochs = num_samples / dataset.num_samples
    return num_episodes, num_epochs


def log_train_info(logger: Logger, info, step, cfg, dataset, is_offline):
    """Log training metrics for the update at 0-indexed `step` to the terminal and to `logger`.

    `info` (as returned by `update_policy`) is mutated in place: "step", "num_samples", "num_episodes",
    "num_epochs" and "is_offline" are added before it is passed to `logger.log_dict` in "train" mode.
    """
    loss = info["loss"]
    grad_norm = info["grad_norm"]
    lr = info["lr"]
    update_s = info["update_s"]

    # A sample is an (observation,action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
    # `step` is 0-indexed here, so `step + 1` updates have completed.
    num_samples = (step + 1) * cfg.training.batch_size
    num_episodes, num_epochs = _samples_to_episodes_and_epochs(num_samples, dataset)

    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of time all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"loss:{loss:.3f}",
        f"grdn:{grad_norm:.3f}",
        f"lr:{lr:0.1e}",
        # in seconds
        f"updt_s:{update_s:.3f}",
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_offline"] = is_offline

    logger.log_dict(info, step, mode="train")


def log_eval_info(logger, info, step, cfg, dataset, is_offline):
    """Log evaluation metrics to the terminal and to `logger`.

    `step` is the number of training updates completed before this evaluation (the training loop
    calls the evaluation helper with `step + 1`). `info` must hold "eval_s", "avg_sum_reward" and
    "pc_success"; it is mutated in place with the same progress keys as `log_train_info` and then
    passed to `logger.log_dict` in "eval" mode.
    """
    eval_s = info["eval_s"]
    avg_sum_reward = info["avg_sum_reward"]
    pc_success = info["pc_success"]

    # A sample is an (observation,action) pair, where observation and action
    # can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
    # Unlike `log_train_info`, `step` already counts completed updates here, so adding 1 would
    # over-count by one batch.
    num_samples = step * cfg.training.batch_size
    num_episodes, num_epochs = _samples_to_episodes_and_epochs(num_samples, dataset)

    log_items = [
        f"step:{format_big_number(step)}",
        # number of samples seen during training
        f"smpl:{format_big_number(num_samples)}",
        # number of episodes seen during training
        f"ep:{format_big_number(num_episodes)}",
        # number of time all unique samples are seen
        f"epch:{num_epochs:.2f}",
        f"∑rwrd:{avg_sum_reward:.3f}",
        f"success:{pc_success:.1f}%",
        f"eval_s:{eval_s:.3f}",
    ]
    logging.info(" ".join(log_items))

    info["step"] = step
    info["num_samples"] = num_samples
    info["num_episodes"] = num_episodes
    info["num_epochs"] = num_epochs
    info["is_offline"] = is_offline

    logger.log_dict(info, step, mode="eval")


def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    """Train a policy offline according to `cfg`, evaluating and checkpointing it periodically.

    Args:
        cfg: the composed Hydra config.
        out_dir: output directory, passed to `Logger`; evaluation videos are written to `out_dir/eval`.
        job_name: job name, passed to `Logger`.

    Raises:
        NotImplementedError: if `out_dir` or `job_name` is None, or if `cfg.training.online_steps > 0`.
    """
    if out_dir is None:
        raise NotImplementedError("`out_dir` must be provided.")
    if job_name is None:
        raise NotImplementedError("`job_name` must be provided.")

    init_logging()

    if cfg.training.online_steps > 0:
        raise NotImplementedError("Online training is not implemented yet.")

    # Check device is available
    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    set_global_seed(cfg.seed)

    logging.info("make_dataset")
    offline_dataset = make_dataset(cfg)

    logging.info("make_env")
    eval_env = make_env(cfg)

    # Close the env even if anything after its creation fails (e.g. when `train_notebook` is called
    # repeatedly in the same process).
    try:
        _train_offline(cfg, out_dir, job_name, device, offline_dataset, eval_env)
    finally:
        eval_env.close()

    logging.info("End of training")


def _train_offline(cfg, out_dir, job_name, device, offline_dataset, eval_env):
    """Build policy, optimizer and logger, then run the offline training loop (helper of `train`)."""
    logging.info("make_policy")
    policy = make_policy(hydra_cfg=cfg, dataset_stats=offline_dataset.stats)

    # Create optimizer and scheduler
    # Temporary hack to move optimizer out of policy
    optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
    grad_scaler = GradScaler(enabled=cfg.use_amp)

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    # log metrics to terminal and wandb
    logger = Logger(out_dir, job_name, cfg)

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
    logging.info(f"{cfg.training.online_steps=}")
    logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
    logging.info(f"{offline_dataset.num_episodes=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    is_offline = True

    # Note: this helper will be used in offline and online training loops.
    def evaluate_and_checkpoint_if_needed(step):
        """Evaluate every `eval_freq` steps and checkpoint every `save_freq` steps.

        `step` is the number of training updates completed so far.
        """
        if step % cfg.training.eval_freq == 0:
            logging.info(f"Eval policy at step {step}")
            amp_ctx = torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext()
            with torch.no_grad(), amp_ctx:
                eval_info = eval_policy(
                    eval_env,
                    policy,
                    cfg.eval.n_episodes,
                    video_dir=Path(out_dir) / "eval",
                    max_episodes_rendered=4,
                    start_seed=cfg.seed,
                )
            log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_offline)
            # Guard against an evaluation that rendered no video.
            if cfg.wandb.enable and eval_info["video_paths"]:
                logger.log_video(eval_info["video_paths"][0], step, mode="eval")
            logging.info("Resume training")

        if cfg.training.save_model and step % cfg.training.save_freq == 0:
            logging.info(f"Checkpoint policy after step {step}")
            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
            # needed (choose 6 as a minimum for consistency without being overkill).
            total_steps = cfg.training.offline_steps + cfg.training.online_steps
            logger.save_model(
                policy,
                identifier=str(step).zfill(max(6, len(str(total_steps)))),
            )
            logging.info("Resume training")

    # create dataloader for offline training
    dataloader = torch.utils.data.DataLoader(
        offline_dataset,
        num_workers=4,
        batch_size=cfg.training.batch_size,
        shuffle=True,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )
    dl_iter = cycle(dataloader)

    policy.train()
    for step in range(cfg.training.offline_steps):
        if step == 0:
            logging.info("Start offline training on a fixed dataset")
        batch = next(dl_iter)
        batch = {key: value.to(device, non_blocking=True) for key, value in batch.items()}

        train_info = update_policy(
            policy,
            batch,
            optimizer,
            cfg.training.grad_clip_norm,
            grad_scaler=grad_scaler,
            lr_scheduler=lr_scheduler,
            use_amp=cfg.use_amp,
        )

        # TODO(rcadene): is it ok if step_t=0 = 0 and not 1 as previously done?
        if step % cfg.training.log_freq == 0:
            log_train_info(logger, train_info, step, cfg, offline_dataset, is_offline)

        # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
        # so we pass in step + 1.
        evaluate_and_checkpoint_if_needed(step + 1)

    # NOTE(review): online training is not implemented (`train` raises when online_steps > 0), so the
    # objects below are currently unused scaffolding; the deepcopy of the offline dataset is overhead
    # until an online loop consumes them.
    # create an empty online dataset similar to offline dataset
    online_dataset = deepcopy(offline_dataset)
    online_dataset.hf_dataset = {}
    online_dataset.episode_data_index = {}

    # create dataloader for online training
    concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
    weights = [1.0] * len(concat_dataset)
    sampler = torch.utils.data.WeightedRandomSampler(
        weights, num_samples=len(concat_dataset), replacement=True
    )
    dataloader = torch.utils.data.DataLoader(
        concat_dataset,
        num_workers=4,
        batch_size=cfg.training.batch_size,
        sampler=sampler,
        pin_memory=device.type != "cpu",
        drop_last=False,
    )


if __name__ == "__main__":
    train_cli()