import logging import os.path as osp import random from datetime import datetime from pathlib import Path import hydra import numpy as np import torch from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: match cfg_device: case "cuda": assert torch.cuda.is_available() device = torch.device("cuda") case "mps": assert torch.backends.mps.is_available() device = torch.device("mps") case "cpu": device = torch.device("cpu") if log: logging.warning("Using CPU, this will be slow.") case _: device = torch.device(cfg_device) if log: logging.warning(f"Using custom {cfg_device} device.") return device def set_global_seed(seed): """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def init_logging(): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) formatter = logging.Formatter() formatter.format = custom_format console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) def format_big_number(num): suffixes = ["", "K", "M", "B", "T", "Q"] divisor = 1000.0 for suffix in suffixes: if abs(num) < divisor: return f"{num:.0f}{suffix}" num /= divisor return num def _relative_path_between(path1: Path, path2: Path) -> Path: """Returns path1 relative to path2.""" path1 = path1.absolute() path2 = path2.absolute() try: return path1.relative_to(path2) except ValueError: # most likely because path1 is not a subpath of path2 common_parts = Path(osp.commonpath([path1, path2])).parts return Path( "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) ) def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig: """Initialize a Hydra config given only the path to the relevant config file. For config resolution, it is assumed that the config file's parent is the Hydra config dir. """ # TODO(alexander-soare): Resolve configs without Hydra initialization. hydra.core.global_hydra.GlobalHydra.instance().clear() # Hydra needs a path relative to this file. hydra.initialize( str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)) ) cfg = hydra.compose(Path(config_path).stem, overrides) return cfg def print_cuda_memory_usage(): import gc gc.collect() # Also clear the cache if you want to fully release the memory torch.cuda.empty_cache() print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2))