feat(utils): add support for Intel XPU backend (#2233)
* feat: add support for Intel XPU backend in device selection

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: update is_amp_available to include xpu as a valid device

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* Update src/lerobot/utils/utils.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>

* fix: remove unused return and add comments on fp64 fallback handling

* fix(utils): return dtype in case xpu has fp64

---------

Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Lim, Xiang Yang <xiang.yang.lim@intel.com>
Co-authored-by: Lim Xiang Yang <xiangyang95@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jade Choghari <chogharijade@gmail.com>
This commit is contained in:
@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
|
||||
elif torch.backends.mps.is_available():
|
||||
logging.info("Metal backend detected, using mps.")
|
||||
return torch.device("mps")
|
||||
elif torch.xpu.is_available():
|
||||
logging.info("Intel XPU backend detected, using xpu.")
|
||||
return torch.device("xpu")
|
||||
else:
|
||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||
return torch.device("cpu")
|
||||
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
case "xpu":
|
||||
assert torch.xpu.is_available()
|
||||
device = torch.device("xpu")
|
||||
case "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
||||
device = device.type
|
||||
if device == "mps" and dtype == torch.float64:
|
||||
return torch.float32
|
||||
if device == "xpu" and dtype == torch.float64:
|
||||
if hasattr(torch.xpu, "get_device_capability"):
|
||||
device_capability = torch.xpu.get_device_capability()
|
||||
# NOTE: Some Intel XPU devices do not support double precision (FP64).
|
||||
# The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
|
||||
# when available; if False, we fall back to float32 for compatibility.
|
||||
if not device_capability.get("has_fp64", False):
|
||||
logging.warning(f"Device {device} does not support float64, using float32 instead.")
|
||||
return torch.float32
|
||||
else:
|
||||
logging.warning(
|
||||
f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
|
||||
)
|
||||
return torch.float32
|
||||
return dtype
|
||||
else:
|
||||
return dtype
|
||||
|
||||
@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
|
||||
return torch.cuda.is_available()
|
||||
elif try_device == "mps":
|
||||
return torch.backends.mps.is_available()
|
||||
elif try_device == "xpu":
|
||||
return torch.xpu.is_available()
|
||||
elif try_device == "cpu":
|
||||
return True
|
||||
else:
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")
|
||||
|
||||
|
||||
def is_amp_available(device: str):
|
||||
if device in ["cuda", "cpu"]:
|
||||
if device in ["cuda", "xpu", "cpu"]:
|
||||
return True
|
||||
elif device == "mps":
|
||||
return False
|
||||
|
||||
@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
|
||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||
for key, value in observation.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
|
||||
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
|
||||
|
||||
|
||||
def test_raw_observation_to_observation_deterministic():
|
||||
|
||||
Reference in New Issue
Block a user