Add get_safe_torch_device in policies
@@ -12,7 +12,7 @@ from lerobot.common.datasets.factory import make_offline_buffer
 from lerobot.common.envs.factory import make_env
 from lerobot.common.logger import Logger, log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.utils import format_big_number, init_logging, set_seed
+from lerobot.common.utils import format_big_number, get_safe_torch_device, init_logging, set_seed
 from lerobot.scripts.eval import eval_policy
@@ -117,10 +117,8 @@ def train(cfg: dict, out_dir=None, job_name=None):

     init_logging()

-    if cfg.device == "cuda":
-        assert torch.cuda.is_available()
-    else:
-        logging.warning("Using CPU, this will be slow.")
+    # Check device is available
+    get_safe_torch_device(cfg.device, log=True)

     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
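For reference, a minimal sketch of what get_safe_torch_device could look like in lerobot.common.utils. The helper's body is not part of this diff, so its exact behavior is an assumption inferred from the call site above and from the checks it replaces in train.py:

import logging

import torch


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
    # Sketch only: mirrors the checks removed from train.py; the real helper
    # in lerobot.common.utils may cover more devices or differ in details.
    if cfg_device == "cuda":
        # Fail fast instead of silently falling back to CPU.
        assert torch.cuda.is_available()
        device = torch.device("cuda")
    elif cfg_device == "cpu":
        device = torch.device("cpu")
        if log:
            logging.warning("Using CPU, this will be slow.")
    else:
        # Pass any other device string straight through to torch.device.
        device = torch.device(cfg_device)
        if log:
            logging.warning(f"Using custom {cfg_device} device.")
    return device

In train.py the return value is discarded: the call only verifies that cfg.device is usable before training starts, while other call sites could keep the returned torch.device and move tensors to it.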