forked from tangger/lerobot
feat(utils): add support for Intel XPU backend (#2233)
* feat: add support for Intel XPU backend in device selection * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: update is_amp_available to include xpu as a valid device * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com> * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com> * fix: remove unused return and add comments on fp64 fallback handling * fix(utils): return dtype in case xpu has fp64 --------- Signed-off-by: Lim Xiang Yang <xiangyang95@gmail.com> Co-authored-by: Lim, Xiang Yang <xiang.yang.lim@intel.com> Co-authored-by: Lim Xiang Yang <xiangyang95@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jade Choghari <chogharijade@gmail.com>
This commit is contained in:
@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
|
|||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
logging.info("Metal backend detected, using mps.")
|
logging.info("Metal backend detected, using mps.")
|
||||||
return torch.device("mps")
|
return torch.device("mps")
|
||||||
|
elif torch.xpu.is_available():
|
||||||
|
logging.info("Intel XPU backend detected, using xpu.")
|
||||||
|
return torch.device("xpu")
|
||||||
else:
|
else:
|
||||||
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
|||||||
case "mps":
|
case "mps":
|
||||||
assert torch.backends.mps.is_available()
|
assert torch.backends.mps.is_available()
|
||||||
device = torch.device("mps")
|
device = torch.device("mps")
|
||||||
|
case "xpu":
|
||||||
|
assert torch.xpu.is_available()
|
||||||
|
device = torch.device("xpu")
|
||||||
case "cpu":
|
case "cpu":
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if log:
|
if log:
|
||||||
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
|||||||
device = device.type
|
device = device.type
|
||||||
if device == "mps" and dtype == torch.float64:
|
if device == "mps" and dtype == torch.float64:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
|
if device == "xpu" and dtype == torch.float64:
|
||||||
|
if hasattr(torch.xpu, "get_device_capability"):
|
||||||
|
device_capability = torch.xpu.get_device_capability()
|
||||||
|
# NOTE: Some Intel XPU devices do not support double precision (FP64).
|
||||||
|
# The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
|
||||||
|
# when available; if False, we fall back to float32 for compatibility.
|
||||||
|
if not device_capability.get("has_fp64", False):
|
||||||
|
logging.warning(f"Device {device} does not support float64, using float32 instead.")
|
||||||
|
return torch.float32
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
|
||||||
|
)
|
||||||
|
return torch.float32
|
||||||
|
return dtype
|
||||||
else:
|
else:
|
||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
|
|||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
elif try_device == "mps":
|
elif try_device == "mps":
|
||||||
return torch.backends.mps.is_available()
|
return torch.backends.mps.is_available()
|
||||||
|
elif try_device == "xpu":
|
||||||
|
return torch.xpu.is_available()
|
||||||
elif try_device == "cpu":
|
elif try_device == "cpu":
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
|
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")
|
||||||
|
|
||||||
|
|
||||||
def is_amp_available(device: str):
|
def is_amp_available(device: str):
|
||||||
if device in ["cuda", "cpu"]:
|
if device in ["cuda", "xpu", "cpu"]:
|
||||||
return True
|
return True
|
||||||
elif device == "mps":
|
elif device == "mps":
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
|
|||||||
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
|
||||||
for key, value in observation.items():
|
for key, value in observation.items():
|
||||||
if isinstance(value, torch.Tensor):
|
if isinstance(value, torch.Tensor):
|
||||||
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
|
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
|
||||||
|
|
||||||
|
|
||||||
def test_raw_observation_to_observation_deterministic():
|
def test_raw_observation_to_observation_deterministic():
|
||||||
|
|||||||
Reference in New Issue
Block a user