fix: support cuda:0, cuda:1 in string selection (#2256)

* fix
* update func 2
* update nightly
* fix quality
* ignore test_dynamixel
This commit is contained in:
2
.github/workflows/nightly.yml
vendored
2
.github/workflows/nightly.yml
vendored
@@ -189,5 +189,5 @@ jobs:
|
|||||||
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
python -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}'); print(f'Number of GPUs: {torch.cuda.device_count()}')"
|
||||||
|
|
||||||
- name: Run multi-GPU training tests
|
- name: Run multi-GPU training tests
|
||||||
run: pytest tests/training/test_multi_gpu.py -vv --maxfail=3
|
run: pytest tests -vv --maxfail=10 --ignore=tests/motors/test_dynamixel.py
|
||||||
timeout-minutes: 10
|
timeout-minutes: 10
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
|||||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||||
|
|
||||||
device: str | None = None # cuda | cpu | mp
|
device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps"
|
||||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||||
# automatic gradient scaling is used.
|
# automatic gradient scaling is used.
|
||||||
use_amp: bool = False
|
use_amp: bool = False
|
||||||
|
|||||||
@@ -57,25 +57,23 @@ def auto_select_torch_device() -> torch.device:
|
|||||||
def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
    """Given a string, return a torch.device with checks on whether the device is available.

    Accepts indexed CUDA devices such as "cuda:0" or "cuda:1" in addition to
    plain "cuda", plus "mps", "xpu", and "cpu". Any other string is passed
    through to torch.device() unchecked, with an optional warning.

    Args:
        try_device: Device identifier, e.g. "cuda", "cuda:1", "mps", "xpu", "cpu".
            Non-string values are coerced with str() first.
        log: If True, emit a warning when falling back to CPU or when using a
            device string outside the known set.

    Returns:
        A torch.device for the requested device.

    Raises:
        AssertionError: If the requested accelerator backend is not available
            (cuda/mps/xpu branches only).
    """
    try_device = str(try_device)
    # startswith("cuda") matches both bare "cuda" and indexed forms like
    # "cuda:0" — this is the fix that enables per-GPU string selection.
    if try_device.startswith("cuda"):
        assert torch.cuda.is_available()
        device = torch.device(try_device)
    elif try_device == "mps":
        assert torch.backends.mps.is_available()
        device = torch.device("mps")
    elif try_device == "xpu":
        assert torch.xpu.is_available()
        device = torch.device("xpu")
    elif try_device == "cpu":
        device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
    else:
        # Unknown identifier: let torch.device() validate it, but surface
        # a warning so a typo does not go unnoticed.
        device = torch.device(try_device)
        if log:
            logging.warning(f"Using custom {try_device} device.")

    return device
|
|
||||||
|
|
||||||
@@ -108,7 +106,7 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
|
|||||||
|
|
||||||
def is_torch_device_available(try_device: str) -> bool:
|
def is_torch_device_available(try_device: str) -> bool:
|
||||||
try_device = str(try_device) # Ensure try_device is a string
|
try_device = str(try_device) # Ensure try_device is a string
|
||||||
if try_device == "cuda":
|
if try_device.startswith("cuda"):
|
||||||
return torch.cuda.is_available()
|
return torch.cuda.is_available()
|
||||||
elif try_device == "mps":
|
elif try_device == "mps":
|
||||||
return torch.backends.mps.is_available()
|
return torch.backends.mps.is_available()
|
||||||
|
|||||||
Reference in New Issue
Block a user