diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 4447a1fc..46c15d92 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
     elif torch.backends.mps.is_available():
         logging.info("Metal backend detected, using mps.")
         return torch.device("mps")
+    elif torch.xpu.is_available():
+        logging.info("Intel XPU backend detected, using xpu.")
+        return torch.device("xpu")
     else:
         logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
         return torch.device("cpu")
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
         case "mps":
            assert torch.backends.mps.is_available()
            device = torch.device("mps")
+        case "xpu":
+            assert torch.xpu.is_available()
+            device = torch.device("xpu")
         case "cpu":
             device = torch.device("cpu")
             if log:
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
         device = device.type
     if device == "mps" and dtype == torch.float64:
         return torch.float32
+    if device == "xpu" and dtype == torch.float64:
+        if hasattr(torch.xpu, "get_device_capability"):
+            device_capability = torch.xpu.get_device_capability()
+            # NOTE: Some Intel XPU devices do not support double precision (FP64).
+            # `torch.xpu.get_device_capability()` reports this via the `has_fp64`
+            # flag; if it is False, fall back to float32 for compatibility.
+            if not device_capability.get("has_fp64", False):
+                logging.warning(f"Device {device} does not support float64, using float32 instead.")
+                return torch.float32
+        else:
+            logging.warning(
+                f"Device {device} capability query is unavailable. Assuming no support for float64, using float32 instead."
+            )
+            return torch.float32
+        return dtype
     else:
         return dtype
 
@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
+    elif try_device == "xpu":
+        return torch.xpu.is_available()
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")
 
 
 def is_amp_available(device: str):
-    if device in ["cuda", "cpu"]:
+    if device in ["cuda", "xpu", "cpu"]:
         return True
     elif device == "mps":
         return False
diff --git a/tests/async_inference/test_helpers.py b/tests/async_inference/test_helpers.py
index 1e2d1e31..a9e53200 100644
--- a/tests/async_inference/test_helpers.py
+++ b/tests/async_inference/test_helpers.py
@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
     # Check that all expected keys produce tensors (device placement handled by preprocessor later)
     for key, value in observation.items():
         if isinstance(value, torch.Tensor):
-            assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
+            assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
 
 
 def test_raw_observation_to_observation_deterministic():
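
Reviewer note: below is a minimal sketch of how the new helpers compose end to end. It assumes a PyTorch build that ships the `torch.xpu` module (2.4+, or an IPEX install); on such builds, `torch.xpu.is_available()` simply returns False on machines without an Intel GPU, so the existing cuda/mps/cpu branches are unaffected.

```python
# Sketch only: exercises the new xpu code paths added in this PR,
# assuming torch.xpu exists (PyTorch >= 2.4). Not part of the patch.
import torch

from lerobot.utils.utils import (
    auto_select_torch_device,
    get_safe_dtype,
    get_safe_torch_device,
    is_amp_available,
    is_torch_device_available,
)

# Picks cuda > mps > xpu > cpu, matching the order in auto_select_torch_device.
device = auto_select_torch_device()
print(f"auto-selected: {device}")

if is_torch_device_available("xpu"):
    xpu = get_safe_torch_device("xpu", log=True)
    # On XPU parts without FP64 units, get_safe_dtype downgrades
    # float64 to float32 and logs a warning.
    print(get_safe_dtype(torch.float64, xpu))  # torch.float32 or torch.float64
    print(is_amp_available("xpu"))  # True: autocast is allowed on xpu
```

The dtype handling mirrors the existing mps behavior, except it is gated on the per-device `has_fp64` capability rather than applied unconditionally.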