import logging import random from datetime import datetime import numpy as np import torch def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: match cfg_device: case "cuda": assert torch.cuda.is_available() device = torch.device("cuda") case "mps": assert torch.backends.mps.is_available() device = torch.device("mps") case "cpu": device = torch.device("cpu") if log: logging.warning("Using CPU, this will be slow.") case _: device = torch.device(cfg_device) if log: logging.warning(f"Using custom {cfg_device} device.") return device def set_global_seed(seed): """Set seed for reproducibility.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def init_logging(): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) formatter = logging.Formatter() formatter.format = custom_format console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) def format_big_number(num): suffixes = ["", "K", "M", "B", "T", "Q"] divisor = 1000.0 for suffix in suffixes: if abs(num) < divisor: return f"{num:.0f}{suffix}" num /= divisor return num