Co-authored-by: Simon Alibert <simon.alibert@huggingface.co>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
Co-authored-by: Pablo <pablo.montalvo.leroux@gmail.com>
This commit is contained in:
Remi
2025-02-04 18:01:04 +01:00
committed by GitHub
parent dd974529cf
commit 638d411cd3
26 changed files with 2365 additions and 92 deletions

View File

@@ -74,6 +74,18 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
return device
def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
"""
mps is currently not compatible with float64
"""
if isinstance(device, torch.device):
device = device.type
if device == "mps" and dtype == torch.float64:
return torch.float32
else:
return dtype
def is_torch_device_available(try_device: str) -> bool:
if try_device == "cuda":
return torch.cuda.is_available()