From 5f6f476f32c3e4ccfc5b2bda0e7f328c796458aa Mon Sep 17 00:00:00 2001
From: Jade Choghari
Date: Mon, 20 Oct 2025 23:29:05 +0200
Subject: [PATCH] fix: support cuda:0, cuda:1 in string selection (#2256)

* fix

* update func 2

* update nightly

* fix quality

* ignore test_dynamixel
---
 .github/workflows/nightly.yml   |  2 +-
 src/lerobot/configs/policies.py |  2 +-
 src/lerobot/utils/utils.py      | 38 ++++++++++++++++++--------------------
 3 files changed, 20 insertions(+), 22 deletions(-)

diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
index f9fa0259..4904ed15 100644
--- a/.github/workflows/nightly.yml
+++ b/.github/workflows/nightly.yml
@@ -189,5 +189,5 @@ jobs:
           python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
 
       - name: Run multi-GPU training tests
-        run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+        run: pytest tests -vv --maxfail=10 --ignore=tests/motors/test_dynamixel.py
         timeout-minutes: 10

diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py
index b1cc19a4..0ecfa169 100644
--- a/src/lerobot/configs/policies.py
+++ b/src/lerobot/configs/policies.py
@@ -58,7 +58,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)
 
-    device: str | None = None  # cuda | cpu | mp
+    device: str | None = None  # e.g. "cuda", "cuda:0", "cpu", or "mps"
     # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
     # automatic gradient scaling is used.
     use_amp: bool = False

diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py
index 46c15d92..c7ad2bbd 100644
--- a/src/lerobot/utils/utils.py
+++ b/src/lerobot/utils/utils.py
@@ -57,25 +57,23 @@ def auto_select_torch_device() -> torch.device:
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     try_device = str(try_device)
-    match try_device:
-        case "cuda":
-            assert torch.cuda.is_available()
-            device = torch.device("cuda")
-        case "mps":
-            assert torch.backends.mps.is_available()
-            device = torch.device("mps")
-        case "xpu":
-            assert torch.xpu.is_available()
-            device = torch.device("xpu")
-        case "cpu":
-            device = torch.device("cpu")
-            if log:
-                logging.warning("Using CPU, this will be slow.")
-        case _:
-            device = torch.device(try_device)
-            if log:
-                logging.warning(f"Using custom {try_device} device.")
-
+    if try_device.startswith("cuda"):
+        assert torch.cuda.is_available()
+        device = torch.device(try_device)
+    elif try_device == "mps":
+        assert torch.backends.mps.is_available()
+        device = torch.device("mps")
+    elif try_device == "xpu":
+        assert torch.xpu.is_available()
+        device = torch.device("xpu")
+    elif try_device == "cpu":
+        device = torch.device("cpu")
+        if log:
+            logging.warning("Using CPU, this will be slow.")
+    else:
+        device = torch.device(try_device)
+        if log:
+            logging.warning(f"Using custom {try_device} device.")
     return device
 
 
@@ -108,7 +106,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
 def is_torch_device_available(try_device: str) -> bool:
     try_device = str(try_device)  # Ensure try_device is a string
-    if try_device == "cuda":
+    if try_device.startswith("cuda"):
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
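
A quick sketch of the behavior this patch enables, for reviewers. The snippet below is not part of the patch; it only uses the two helpers changed above (get_safe_torch_device, is_torch_device_available) and assumes a machine where torch.cuda.is_available() is true with at least two GPUs:

    from lerobot.utils.utils import get_safe_torch_device, is_torch_device_available

    # Previously an indexed string like "cuda:1" fell through to the
    # catch-all branch and was treated as a "custom" device (with a
    # warning when log=True); now any string starting with "cuda"
    # asserts CUDA availability and is passed straight to torch.device().
    device = get_safe_torch_device("cuda:1")
    print(device)  # -> cuda:1

    # The availability check accepts indexed CUDA strings too.
    print(is_torch_device_available("cuda:0"))  # -> True on a CUDA machine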