fix: support cuda:0, cuda:1 in string selection (#2256)

* fix

* update func 2

* update nightly

* fix quality

* ignore test_dynamixel
This commit is contained in:
Jade Choghari
2025-10-20 23:29:05 +02:00
committed by GitHub
parent 502fdc0630
commit 5f6f476f32
3 changed files with 20 additions and 22 deletions

View File

@@ -189,5 +189,5 @@ jobs:
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
- name: Run multi-GPU training tests
-        run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+        run: pytest tests -vv --maxfail=10 --ignore=tests/motors/test_dynamixel.py
timeout-minutes: 10

View File

@@ -58,7 +58,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
-    device: str | None = None  # cuda | cpu | mp
+    device: str | None = None  # e.g. "cuda", "cuda:0", "cpu", or "mps"
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
# automatic gradient scaling is used.
use_amp: bool = False

View File

@@ -57,25 +57,23 @@ def auto_select_torch_device() -> torch.device:
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available.

    Args:
        try_device: Device spec such as "cuda", "cuda:0", "mps", "xpu", or "cpu".
            Any other string is passed straight to ``torch.device``.
        log: When True, warn about the CPU fallback or a custom device choice.

    Returns:
        The corresponding ``torch.device``.

    Raises:
        AssertionError: If the requested accelerator backend is not available.
    """
    try_device = str(try_device)
    # `startswith` covers both bare "cuda" and indexed forms like "cuda:0", "cuda:1".
    if try_device.startswith("cuda"):
        assert torch.cuda.is_available()
        device = torch.device(try_device)
    elif try_device == "mps":
        assert torch.backends.mps.is_available()
        device = torch.device("mps")
    elif try_device == "xpu":
        assert torch.xpu.is_available()
        device = torch.device("xpu")
    elif try_device == "cpu":
        device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
    else:
        # Unknown spec: defer validation to torch.device itself.
        device = torch.device(try_device)
        if log:
            logging.warning(f"Using custom {try_device} device.")
    return device
@@ -108,7 +106,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
def is_torch_device_available(try_device: str) -> bool:
try_device = str(try_device) # Ensure try_device is a string
-    if try_device == "cuda":
+    if try_device.startswith("cuda"):
return torch.cuda.is_available()
elif try_device == "mps":
return torch.backends.mps.is_available()