Add get_safe_torch_device in policies

This commit is contained in:
Simon Alibert
2024-03-20 18:38:55 +01:00
parent ec536ef0fa
commit 4631d36c05
6 changed files with 39 additions and 18 deletions

View File

@@ -6,6 +6,26 @@ import numpy as np
import torch
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
match cfg_device:
case "cuda":
assert torch.cuda.is_available()
device = torch.device("cuda")
case "mps":
assert torch.backends.mps.is_available()
device = torch.device("mps")
case "cpu":
device = torch.device("cpu")
if log:
logging.warning("Using CPU, this will be slow.")
case _:
device = torch.device(cfg_device)
if log:
logging.warning(f"Using custom {cfg_device} device.")
return device
def set_seed(seed):
"""Set seed for reproducibility."""
random.seed(seed)