diff --git a/lerobot/common/utils/logging_utils.py b/lerobot/common/utils/logging_utils.py
index b99c348f7..192f646bd 100644
--- a/lerobot/common/utils/logging_utils.py
+++ b/lerobot/common/utils/logging_utils.py
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Callable
 
 from lerobot.common.utils.utils import format_big_number
 
@@ -93,12 +93,14 @@ class MetricsTracker:
         num_episodes: int,
         metrics: dict[str, AverageMeter],
         initial_step: int = 0,
+        accelerator: Callable = None,
     ):
         self.__dict__.update({k: None for k in self.__keys__})
         self._batch_size = batch_size
         self._num_frames = num_frames
         self._avg_samples_per_ep = num_frames / num_episodes
         self.metrics = metrics
+        self.accelerator = accelerator
         self.steps = initial_step
 
         # A sample is an (observation,action) pair, where observation and action
@@ -128,7 +130,7 @@ class MetricsTracker:
         Updates metrics that depend on 'step' for one step.
         """
         self.steps += 1
-        self.samples += self._batch_size
+        self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
         self.episodes = self.samples / self._avg_samples_per_ep
         self.epochs = self.samples / self._num_frames
 
diff --git a/lerobot/common/utils/random_utils.py b/lerobot/common/utils/random_utils.py
index 3d9bf4dd8..e5d59cd46 100644
--- a/lerobot/common/utils/random_utils.py
+++ b/lerobot/common/utils/random_utils.py
@@ -16,7 +16,7 @@
 import random
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Generator
+from typing import Any, Callable, Generator
 
 import numpy as np
 import torch
@@ -163,14 +163,17 @@ def set_rng_state(random_state_dict: dict[str, Any]):
         torch.cuda.random.set_rng_state(random_state_dict["torch_cuda_random_state"])
 
 
-def set_seed(seed) -> None:
+def set_seed(seed, accelerator: Callable = None) -> None:
     """Set seed for reproducibility."""
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)
+    if accelerator:
+        from accelerate.utils import set_seed as accelerate_set_seed
+        accelerate_set_seed(seed)
 
 
 @contextmanager
 def seeded_context(seed: int) -> Generator[None, None, None]:
diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py
index d0c12b30c..d316c0ac2 100644
--- a/lerobot/common/utils/utils.py
+++ b/lerobot/common/utils/utils.py
@@ -20,10 +20,11 @@ import platform
 from copy import copy
 from datetime import datetime, timezone
 from pathlib import Path
+from typing import Any, Callable
 
 import numpy as np
 import torch
 
 def none_or_int(value):
     if value == "None":
         return None
@@ -50,12 +51,12 @@ def auto_select_torch_device() -> torch.device:
     return torch.device("cpu")
 
 
-def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
+def get_safe_torch_device(try_device: str, log: bool = False, accelerator: Callable = None) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     match try_device:
         case "cuda":
             assert torch.cuda.is_available()
-            device = torch.device("cuda")
+            device = accelerator.device if accelerator else torch.device("cuda")
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
@@ -103,7 +104,7 @@ def is_amp_available(device: str):
         raise ValueError(f"Unknown device '{device}.")
 
 
-def init_logging():
+def init_logging(accelerator: Callable = None):
     def custom_format(record):
         dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
         fnameline = f"{record.pathname}:{record.lineno}"
@@ -120,7 +121,11 @@ def init_logging():
     console_handler = logging.StreamHandler()
     console_handler.setFormatter(formatter)
     logging.getLogger().addHandler(console_handler)
+    if accelerator is not None and not accelerator.is_main_process:
+        # Disable duplicate logging on non-main processes.
+        logging.info(f"Setting logging level on non-main process {accelerator.process_index} to WARNING.")
+        logging.getLogger().setLevel(logging.WARNING)
 
 
 def format_big_number(num, precision=0):
     suffixes = ["", "K", "M", "B", "T", "Q"]
@@ -216,3 +221,21 @@ def is_valid_numpy_dtype_string(dtype_str: str) -> bool:
     except TypeError:
         # If a TypeError is raised, the string is not a valid dtype
         return False
+
+
+def is_launched_with_accelerate() -> bool:
+    # `accelerate launch` exports this variable to every process it spawns.
+    return "ACCELERATE_MIXED_PRECISION" in os.environ
+
+
+def get_accelerate_config(accelerator: Callable = None) -> dict[str, Any]:
+    config = {}
+    if not accelerator:
+        return config
+    config["num_processes"] = accelerator.num_processes
+    config["device"] = str(accelerator.device)
+    config["distributed_type"] = str(accelerator.distributed_type)
+    config["mixed_precision"] = accelerator.mixed_precision
+    config["gradient_accumulation_steps"] = accelerator.gradient_accumulation_steps
+
+    return config
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index f3c57fe28..47afa62fb 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -17,7 +17,7 @@ import logging
 import time
 from contextlib import nullcontext
 from pprint import pformat
-from typing import Any
+from typing import Any, Callable
 
 import torch
 from termcolor import colored
@@ -46,6 +46,8 @@ from lerobot.common.utils.utils import (
+    get_accelerate_config,
     get_safe_torch_device,
     has_method,
     init_logging,
+    is_launched_with_accelerate,
 )
 from lerobot.common.utils.wandb_utils import WandBLogger
 from lerobot.configs import parser
@@ -63,30 +65,41 @@ def update_policy(
     lr_scheduler=None,
     use_amp: bool = False,
     lock=None,
+    accelerator: Callable = None,
 ) -> tuple[MetricsTracker, dict]:
     start_time = time.perf_counter()
     device = get_device_from_parameters(policy)
     policy.train()
-    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
+    with torch.autocast(device_type=device.type) if use_amp and accelerator is None else nullcontext():
         loss, output_dict = policy.forward(batch)
         # TODO(rcadene): policy.unnormalize_outputs(out_dict)
-    grad_scaler.scale(loss).backward()
 
-    # Unscale the graident of the optimzer's assigned params in-place **prior to gradient clipping**.
-    grad_scaler.unscale_(optimizer)
+    if accelerator:
+        accelerator.backward(loss)
+        accelerator.unscale_gradients(optimizer=optimizer)
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
+        optimizer.step()
+    else:
+        grad_scaler.scale(loss).backward()
+        # Unscale the gradient of the optimizer's assigned params in-place **prior to gradient clipping**.
+        grad_scaler.unscale_(optimizer)
 
-    grad_norm = torch.nn.utils.clip_grad_norm_(
-        policy.parameters(),
-        grad_clip_norm,
-        error_if_nonfinite=False,
-    )
+        grad_norm = torch.nn.utils.clip_grad_norm_(
+            policy.parameters(),
+            grad_clip_norm,
+            error_if_nonfinite=False,
+        )
 
-    # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
-    # although it still skips optimizer.step() if the gradients contain infs or NaNs.
-    with lock if lock is not None else nullcontext():
-        grad_scaler.step(optimizer)
-    # Updates the scale for next iteration.
-    grad_scaler.update()
+        # Optimizer's gradients are already unscaled, so scaler.step does not unscale them,
+        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
+        with lock if lock is not None else nullcontext():
+            grad_scaler.step(optimizer)
+        # Updates the scale for next iteration.
+        grad_scaler.update()
 
     optimizer.zero_grad()
 
@@ -94,9 +107,13 @@
     if lr_scheduler is not None:
         lr_scheduler.step()
 
-    if has_method(policy, "update"):
-        # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
-        policy.update()
+    if accelerator:
+        if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):  # FIXME(mshukor): avoid accelerator.unwrap_model ?
+            accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
+    else:
+        if has_method(policy, "update"):
+            # To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
+            policy.update()
 
     train_metrics.loss = loss.item()
     train_metrics.grad_norm = grad_norm.item()
@@ -106,10 +123,14 @@
     return train_metrics, output_dict
 
 @parser.wrap()
-def train(cfg: TrainPipelineConfig):
+def train(cfg: TrainPipelineConfig, accelerator: Callable = None):
     cfg.validate()
     logging.info(pformat(cfg.to_dict()))
 
+    if accelerator and not accelerator.is_main_process:
+        # Disable WandB logging on non-main processes.
+        cfg.wandb.enable = False
+
     if cfg.wandb.enable and cfg.wandb.project:
         wandb_logger = WandBLogger(cfg)
     else:
@@ -117,10 +138,10 @@
         logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
 
     if cfg.seed is not None:
-        set_seed(cfg.seed)
+        set_seed(cfg.seed, accelerator=accelerator)
 
     # Check device is available
-    device = get_safe_torch_device(cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True, accelerator=accelerator)
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
 
@@ -141,7 +162,8 @@
         device=device,
         ds_meta=dataset.meta,
     )
+    policy.to(device)
 
     logging.info("Creating optimizer and scheduler")
     optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
     grad_scaler = GradScaler(device, enabled=cfg.use_amp)
@@ -184,6 +206,10 @@
         pin_memory=device.type != "cpu",
         drop_last=False,
     )
+    if accelerator:
+        policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
+            policy, optimizer, dataloader, lr_scheduler
+        )
     dl_iter = cycle(dataloader)
 
     policy.train()
@@ -197,7 +223,7 @@
     }
 
     train_tracker = MetricsTracker(
-        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step
+        cfg.batch_size, dataset.num_frames, dataset.num_episodes, train_metrics, initial_step=step, accelerator=accelerator
     )
 
     logging.info("Start offline training on a fixed dataset")
@@ -219,6 +245,7 @@
             grad_scaler=grad_scaler,
             lr_scheduler=lr_scheduler,
             use_amp=cfg.use_amp,
+            accelerator=accelerator,
         )
 
         # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
@@ -238,21 +265,26 @@
                 wandb_logger.log_dict(wandb_log_dict, step)
             train_tracker.reset_averages()
 
-        if cfg.save_checkpoint and is_saving_step:
+        if cfg.save_checkpoint and is_saving_step and (not accelerator or accelerator.is_main_process):
             logging.info(f"Checkpoint policy after step {step}")
             checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
-            save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
+            save_checkpoint(checkpoint_dir, step, cfg, policy if not accelerator else accelerator.unwrap_model(policy), optimizer, lr_scheduler)
             update_last_checkpoint(checkpoint_dir)
             if wandb_logger:
                 wandb_logger.log_policy(checkpoint_dir)
+        if accelerator:
+            accelerator.wait_for_everyone()
 
         if cfg.env and is_eval_step:
             step_id = get_step_identifier(step, cfg.steps)
             logging.info(f"Eval policy at step {step}")
-            with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
+            with (
+                torch.no_grad(),
+                torch.autocast(device_type=device.type) if cfg.use_amp and not accelerator else nullcontext(),
+            ):
                 eval_info = eval_policy(
                     eval_env,
-                    policy,
+                    policy if not accelerator else accelerator.unwrap_model(policy),
                     cfg.eval.n_episodes,
                     videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
                     max_episodes_rendered=4,
@@ -265,7 +297,7 @@
                 "eval_s": AverageMeter("eval_s", ":.3f"),
             }
             eval_tracker = MetricsTracker(
-                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step
+                cfg.batch_size, dataset.num_frames, dataset.num_episodes, eval_metrics, initial_step=step, accelerator=None
             )
             eval_tracker.eval_s = eval_info["aggregated"].pop("eval_s")
             eval_tracker.avg_sum_reward = eval_info["aggregated"].pop("avg_sum_reward")
@@ -283,4 +315,13 @@
 
 if __name__ == "__main__":
-    init_logging()
-    train()
+    if is_launched_with_accelerate():
+        import accelerate
+
+        # step_scheduler_with_optimizer=False prevents accelerate from adjusting
+        # the lr_scheduler steps based on num_processes.
+        accelerator = accelerate.Accelerator(step_scheduler_with_optimizer=False)
+        init_logging(accelerator=accelerator)
+        train(accelerator=accelerator)
+    else:
+        init_logging()
+        train()
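
--
For reference, the accelerator branch of update_policy() above follows the standard
Hugging Face Accelerate training step. Below is a minimal, self-contained sketch of
that path; the toy linear model and AdamW optimizer stand in for lerobot's policy
and optimizer, and the snippet is illustrative only, not code from this patch:

    import torch
    from accelerate import Accelerator

    # Same flag as the entry point above: keep accelerate from adjusting
    # lr_scheduler steps for the number of processes.
    accelerator = Accelerator(step_scheduler_with_optimizer=False)

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer = accelerator.prepare(model, optimizer)

    x = torch.randn(8, 4, device=accelerator.device)
    loss = model(x).pow(2).mean()

    accelerator.backward(loss)  # replaces grad_scaler.scale(loss).backward()
    accelerator.unscale_gradients(optimizer=optimizer)  # replaces grad_scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0, error_if_nonfinite=False)
    optimizer.step()
    optimizer.zero_grad()

Run directly, this executes as a single process; under
`accelerate launch --num_processes=2 ...` the same script runs once per device,
which is why MetricsTracker now scales its per-step sample count by
accelerator.num_processes.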