forked from tangger/lerobot
fix: support cuda:0, cuda:1 in string selection (#2256)
* fix
* update func 2
* update nightly
* fix quality
* ignore test_dynamixel
.github/workflows/nightly.yml
```diff
@@ -189,5 +189,5 @@ jobs:
           python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
 
       - name: Run multi-GPU training tests
-        run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
+        run: pytest tests -vv --maxfail=10 --ignore=tests/motors/test_dynamixel.py
         timeout-minutes: 10
```
```diff
@@ -58,7 +58,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):  # type: igno
     input_features: dict[str, PolicyFeature] = field(default_factory=dict)
     output_features: dict[str, PolicyFeature] = field(default_factory=dict)
 
-    device: str | None = None  # cuda | cpu | mp
+    device: str | None = None  # e.g. "cuda", "cuda:0", "cpu", or "mps"
     # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
     # automatic gradient scaling is used.
     use_amp: bool = False
```
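For illustration, a minimal sketch of what the widened comment allows: the config field can now carry any torch device string, including an explicit GPU index. `ToyConfig` is a hypothetical stand-in, not the real `PreTrainedConfig` subclass hierarchy.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyConfig:
    # Hypothetical stand-in for PreTrainedConfig.device; accepts any
    # torch device string, including an explicit GPU index.
    device: str | None = None  # e.g. "cuda", "cuda:0", "cpu", or "mps"


cfg = ToyConfig(device="cuda:1")
# torch.device parses the index without requiring CUDA to be present;
# availability is checked later, in get_safe_torch_device.
print(torch.device(cfg.device))  # cuda:1
```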
```diff
@@ -57,25 +57,23 @@ def auto_select_torch_device() -> torch.device:
 def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
     """Given a string, return a torch.device with checks on whether the device is available."""
     try_device = str(try_device)
-    match try_device:
-        case "cuda":
-            assert torch.cuda.is_available()
-            device = torch.device("cuda")
-        case "mps":
-            assert torch.backends.mps.is_available()
-            device = torch.device("mps")
-        case "xpu":
-            assert torch.xpu.is_available()
-            device = torch.device("xpu")
-        case "cpu":
-            device = torch.device("cpu")
-            if log:
-                logging.warning("Using CPU, this will be slow.")
-        case _:
-            device = torch.device(try_device)
-            if log:
-                logging.warning(f"Using custom {try_device} device.")
+    if try_device.startswith("cuda"):
+        assert torch.cuda.is_available()
+        device = torch.device(try_device)
+    elif try_device == "mps":
+        assert torch.backends.mps.is_available()
+        device = torch.device("mps")
+    elif try_device == "xpu":
+        assert torch.xpu.is_available()
+        device = torch.device("xpu")
+    elif try_device == "cpu":
+        device = torch.device("cpu")
+        if log:
+            logging.warning("Using CPU, this will be slow.")
+    else:
+        device = torch.device(try_device)
+        if log:
+            logging.warning(f"Using custom {try_device} device.")
     return device
```
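A standalone sketch of the new selection behavior (`pick_device` is a hypothetical name, simplified to the cuda/cpu/fallback branches; the real function is `get_safe_torch_device` above). A string such as "cuda:1" now takes the CUDA branch, is validated against `torch.cuda.is_available()`, and keeps its index instead of falling into the wildcard case.

```python
import logging

import torch


def pick_device(try_device: str, log: bool = False) -> torch.device:
    """Simplified sketch of the post-fix device selection logic."""
    try_device = str(try_device)
    if try_device.startswith("cuda"):
        # "cuda", "cuda:0", "cuda:1", ... all validate against CUDA
        # availability and preserve their index.
        assert torch.cuda.is_available()
        device = torch.device(try_device)
    elif try_device == "cpu":
        device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
    else:
        device = torch.device(try_device)
    return device


print(pick_device("cpu"))        # cpu
# print(pick_device("cuda:1"))   # cuda:1 on a multi-GPU machine
```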
```diff
@@ -108,7 +106,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
 def is_torch_device_available(try_device: str) -> bool:
     try_device = str(try_device)  # Ensure try_device is a string
-    if try_device == "cuda":
+    if try_device.startswith("cuda"):
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
```
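Similarly, a sketch of the updated availability check (`is_available` is a hypothetical name, and the non-cuda/mps tail is simplified relative to the real function). Before the fix, "cuda:1" failed the exact match against "cuda" and was never gated on CUDA availability.

```python
import torch


def is_available(try_device: str) -> bool:
    # Sketch of the updated check: any "cuda*" string is gated on CUDA
    # availability rather than failing the exact match against "cuda".
    try_device = str(try_device)
    if try_device.startswith("cuda"):
        return torch.cuda.is_available()
    elif try_device == "mps":
        return torch.backends.mps.is_available()
    # Simplified tail: treat only "cpu" as trivially available here.
    return try_device == "cpu"


for name in ("cpu", "cuda", "cuda:0", "cuda:1"):
    print(name, is_available(name))
```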