backup wip

This commit is contained in:
Alexander Soare
2024-03-19 18:50:04 +00:00
parent ea17f4ce50
commit 896a11f60e
16 changed files with 169 additions and 138 deletions

View File

@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
"""Implementation of TD-MPC learning + inference."""
def __init__(self, cfg, device):
super().__init__()
super().__init__(None)
self.action_dim = cfg.action_dim
self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
self.model_target.load_state_dict(d["model_target"])
@torch.no_grad()
def select_action(self, observation, step_count):
def select_actions(self, observation, step_count):
if observation["image"].shape[0] != 1:
raise NotImplementedError("Batch size > 1 not handled")
t0 = step_count.item() == 0
obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
"rgb": observation["image"].contiguous(),
"state": observation["state"].contiguous(),
}
action = self.act(obs, t0=t0, step=self.step.item())
# Note: unsqueeze needed because `act` still uses non-batch logic.
action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
return action
@torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
if self.cfg.mpc:
a = self.plan(z, t0=t0, step=step)
else:
a = self.model.pi(z, self.cfg.min_std * self.model.training)
a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
return a
@torch.no_grad()