Add run on cpu-only compatibility
This commit is contained in:
@@ -88,14 +88,14 @@ class TOLD(nn.Module):
|
||||
class TDMPC(nn.Module):
|
||||
"""Implementation of TD-MPC learning + inference."""
|
||||
|
||||
def __init__(self, cfg):
|
||||
def __init__(self, cfg, device):
|
||||
super().__init__()
|
||||
self.action_dim = cfg.action_dim
|
||||
|
||||
self.cfg = cfg
|
||||
self.device = torch.device("cuda")
|
||||
self.device = torch.device(device)
|
||||
self.std = h.linear_schedule(cfg.std_schedule, 0)
|
||||
self.model = TOLD(cfg).cuda()
|
||||
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
|
||||
self.model_target = deepcopy(self.model)
|
||||
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
|
||||
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)
|
||||
|
||||
Reference in New Issue
Block a user