Add run on cpu-only compatibility

This commit is contained in:
Simon Alibert
2024-03-03 12:47:26 +01:00
parent 661bda45ea
commit b33ec5a630
5 changed files with 100 additions and 95 deletions

View File

@@ -1,4 +1,5 @@
import logging
import warnings
import hydra
import numpy as np
@@ -115,7 +116,11 @@ def train(cfg: dict, out_dir=None, job_name=None):
init_logging()
assert torch.cuda.is_available()
if cfg.device == "cuda":
assert torch.cuda.is_available()
else:
warnings.warn("Using CPU, this will be slow.", UserWarning, stacklevel=1)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_seed(cfg.seed)