Add run on cpu-only compatibility

This commit is contained in:
Simon Alibert
2024-03-03 12:47:26 +01:00
parent 661bda45ea
commit b33ec5a630
5 changed files with 100 additions and 95 deletions

View File

@@ -88,14 +88,14 @@ class TOLD(nn.Module):
class TDMPC(nn.Module):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg):
def __init__(self, cfg, device):
super().__init__()
self.action_dim = cfg.action_dim
self.cfg = cfg
self.device = torch.device("cuda")
self.device = torch.device(device)
self.std = h.linear_schedule(cfg.std_schedule, 0)
self.model = TOLD(cfg).cuda()
self.model = TOLD(cfg).cuda() if torch.cuda.is_available() and device == "cuda" else TOLD(cfg)
self.model_target = deepcopy(self.model)
self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr)
self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=self.cfg.lr)