Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
This commit is contained in:
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

View File

@@ -19,14 +19,13 @@ import os.path as osp
import platform
import random
from contextlib import contextmanager
from copy import copy
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def none_or_int(value):
@@ -41,9 +40,22 @@ def inside_slurm():
return "SLURM_JOB_ID" in os.environ
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
def auto_select_torch_device() -> torch.device:
"""Tries to select automatically a torch device."""
if torch.cuda.is_available():
logging.info("Cuda backend detected, using cuda.")
return torch.device("cuda")
elif torch.backends.mps.is_available():
logging.info("Metal backend detected, using cuda.")
return torch.device("mps")
else:
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
return torch.device("cpu")
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:
match try_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
@@ -55,13 +67,33 @@ def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
if log:
logging.warning("Using CPU, this will be slow.")
case _:
device = torch.device(cfg_device)
device = torch.device(try_device)
if log:
logging.warning(f"Using custom {cfg_device} device.")
logging.warning(f"Using custom {try_device} device.")
return device
def is_torch_device_available(try_device: str) -> bool:
if try_device == "cuda":
return torch.cuda.is_available()
elif try_device == "mps":
return torch.backends.mps.is_available()
elif try_device == "cpu":
return True
else:
raise ValueError(f"Unknown device '{try_device}.")
def is_amp_available(device: str):
if device in ["cuda", "cpu"]:
return True
elif device == "mps":
return False
else:
raise ValueError(f"Unknown device '{device}.")
def get_global_random_state() -> dict[str, Any]:
"""Get the random state for `random`, `numpy`, and `torch`."""
random_state_dict = {
@@ -159,22 +191,6 @@ def _relative_path_between(path1: Path, path2: Path) -> Path:
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# TODO(alexander-soare): Resolve configs without Hydra initialization.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)),
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
def print_cuda_memory_usage():
"""Use this function to locate and debug memory leak."""
import gc
@@ -217,3 +233,17 @@ def log_say(text, play_sounds, blocking=False):
if play_sounds:
say(text, blocking)
def get_channel_first_image_shape(image_shape: tuple) -> tuple:
shape = copy(image_shape)
if shape[2] < shape[0] and shape[2] < shape[1]: # (h, w, c) -> (c, h, w)
shape = (shape[2], shape[0], shape[1])
elif not (shape[0] < shape[1] and shape[0] < shape[2]):
raise ValueError(image_shape)
return shape
def has_method(cls: object, method_name: str):
return hasattr(cls, method_name) and callable(getattr(cls, method_name))