Add get_safe_torch_device in policies
This commit is contained in:
@@ -7,6 +7,7 @@ import torchvision.transforms as transforms
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.act.detr_vae import build
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
def build_act_model_and_optimizer(cfg):
|
||||
@@ -45,7 +46,7 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
super().__init__(n_action_steps)
|
||||
self.cfg = cfg
|
||||
self.n_action_steps = n_action_steps
|
||||
self.device = device
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.model, self.optimizer = build_act_model_and_optimizer(cfg)
|
||||
self.kl_weight = self.cfg.kl_weight
|
||||
logging.info(f"KL Weight {self.kl_weight}")
|
||||
|
||||
@@ -8,6 +8,7 @@ from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.diffusion.diffusion_unet_image_policy import DiffusionUnetImagePolicy
|
||||
from lerobot.common.policies.diffusion.model.lr_scheduler import get_scheduler
|
||||
from lerobot.common.policies.diffusion.model.multi_image_obs_encoder import MultiImageObsEncoder
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
|
||||
class DiffusionPolicy(AbstractPolicy):
|
||||
@@ -62,9 +63,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.device = torch.device(cfg_device)
|
||||
if torch.cuda.is_available() and cfg_device == "cuda":
|
||||
self.diffusion.cuda()
|
||||
self.device = get_safe_torch_device(cfg_device)
|
||||
self.diffusion.to(self.device)
|
||||
|
||||
self.ema = None
|
||||
if self.cfg.use_ema:
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch.nn as nn
|
||||
|
||||
import lerobot.common.policies.tdmpc.helper as h
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.utils import get_safe_torch_device
|
||||
|
||||
FIRST_FRAME = 0
|
||||
|
||||
@@ -94,9 +95,10 @@ class TDMPC(AbstractPolicy):
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device(device)
|
||||
self.device = get_safe_torch_device(device)
|
||||
self.std = h.linear_schedule(cfg.std_schedule, 0)
|
||||
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
|
||||
self.model = TOLD(cfg)
|
||||
self.model.to(self.device)
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
|
||||
@@ -6,6 +6,26 @@ import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
|
||||
match cfg_device:
|
||||
case "cuda":
|
||||
assert torch.cuda.is_available()
|
||||
device = torch.device("cuda")
|
||||
case "mps":
|
||||
assert torch.backends.mps.is_available()
|
||||
device = torch.device("mps")
|
||||
case "cpu":
|
||||
device = torch.device("cpu")
|
||||
if log:
|
||||
logging.warning("Using CPU, this will be slow.")
|
||||
case _:
|
||||
device = torch.device(cfg_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {cfg_device} device.")
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def set_seed(seed):
|
||||
"""Set seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
|
||||
Reference in New Issue
Block a user