feat(utils): add support for Intel XPU backend (#2233)

* feat: add support for Intel XPU backend in device selection

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: update is_amp_available to include xpu as a valid device

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* fix: remove unused return and add comments on fp64 fallback handling

* fix(utils): return dtype in case xpu has fp64

---------

Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Lim, Xiang Yang <xiang.yang.lim@intel.com>
Co-authored-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
Steven Palma committed via GitHub on 2025-10-17 19:30:25 +02:00
parent 79137f58d1  commit 92b6254473
2 changed files with 26 additions and 3 deletions
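
For context, a minimal usage sketch of the updated helpers on an Intel GPU machine. This is illustrative only: the import path lerobot.utils.utils and the selection order (cuda, then mps, then xpu, then cpu) are assumptions inferred from the commit message and the hunks below.

    # Sketch only: assumes the helpers live in lerobot.utils.utils and that
    # an Intel GPU is visible to PyTorch.
    import torch
    from lerobot.utils.utils import auto_select_torch_device, get_safe_dtype, is_amp_available

    device = auto_select_torch_device()            # cuda > mps > xpu > cpu fallback (assumed order)
    dtype = get_safe_dtype(torch.float64, device)  # float32 on XPU parts without FP64
    use_amp = is_amp_available(device.type)        # True for "cuda", "xpu" and "cpu"
    x = torch.ones(4, 4, device=device, dtype=dtype)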

View File: src/lerobot/utils/utils.py

@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
     elif torch.backends.mps.is_available():
         logging.info("Metal backend detected, using mps.")
         return torch.device("mps")
+    elif torch.xpu.is_available():
+        logging.info("Intel XPU backend detected, using xpu.")
+        return torch.device("xpu")
     else:
         logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
         return torch.device("cpu")
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
+        case "xpu":
+            assert torch.xpu.is_available()
+            device = torch.device("xpu")
         case "cpu":
             device = torch.device("cpu")
             if log:
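
A hypothetical call site for the new case "xpu" branch, assuming the signature shown in the hunk header and the import path used above:

    # Sketch only: "xpu" now routes through its own case and asserts availability.
    from lerobot.utils.utils import get_safe_torch_device  # assumed import path

    device = get_safe_torch_device("xpu", log=True)  # AssertionError if no XPU is available
    print(device)  # device(type='xpu')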
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
         device = device.type
     if device == "mps" and dtype == torch.float64:
         return torch.float32
+    if device == "xpu" and dtype == torch.float64:
+        if hasattr(torch.xpu, "get_device_capability"):
+            device_capability = torch.xpu.get_device_capability()
+            # NOTE: Some Intel XPU devices do not support double precision (FP64).
+            # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
+            # when available; if False, we fall back to float32 for compatibility.
+            if not device_capability.get("has_fp64", False):
+                logging.warning(f"Device {device} does not support float64, using float32 instead.")
+                return torch.float32
+        else:
+            logging.warning(
+                f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
+            )
+            return torch.float32
+        return dtype
     else:
         return dtype
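
A quick way to exercise the new fallback, assuming the same import path as above: when torch.xpu.get_device_capability() reports has_fp64 as False, the requested float64 is downgraded, otherwise it is returned unchanged.

    # Sketch only: demonstrates the FP64 -> FP32 fallback on XPU.
    import torch
    from lerobot.utils.utils import get_safe_dtype  # assumed import path

    dtype = get_safe_dtype(torch.float64, "xpu")
    print(dtype)  # torch.float32 on devices without FP64, torch.float64 otherwise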
@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
+    elif try_device == "xpu":
+        return torch.xpu.is_available()
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")


 def is_amp_available(device: str):
-    if device in ["cuda", "cpu"]:
+    if device in ["cuda", "xpu", "cpu"]:
         return True
     elif device == "mps":
         return False
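
With is_amp_available now returning True for "xpu", autocast can be gated the same way as on CUDA. A hedged sketch, assuming a recent PyTorch where torch.autocast accepts device_type="xpu" and the import path used above:

    # Sketch only: gate mixed precision on the updated helper.
    import torch
    from lerobot.utils.utils import is_amp_available  # assumed import path

    device_type = "xpu"
    with torch.autocast(device_type=device_type, enabled=is_amp_available(device_type)):
        y = torch.ones(2, 2, device=device_type) @ torch.ones(2, 2, device=device_type)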

View File

@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
     # Check that all expected keys produce tensors (device placement handled by preprocessor later)
     for key, value in observation.items():
         if isinstance(value, torch.Tensor):
-            assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
+            assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"


 def test_raw_observation_to_observation_deterministic():