backup wip
@@ -90,7 +90,7 @@ class TDMPC(AbstractPolicy):
     """Implementation of TD-MPC learning + inference."""
 
     def __init__(self, cfg, device):
-        super().__init__()
+        super().__init__(None)
         self.action_dim = cfg.action_dim
 
         self.cfg = cfg
@@ -125,7 +125,10 @@ class TDMPC(AbstractPolicy):
         self.model_target.load_state_dict(d["model_target"])
 
     @torch.no_grad()
-    def select_action(self, observation, step_count):
+    def select_actions(self, observation, step_count):
+        if observation["image"].shape[0] != 1:
+            raise NotImplementedError("Batch size > 1 not handled")
+
         t0 = step_count.item() == 0
 
         obs = {
@@ -133,7 +136,8 @@ class TDMPC(AbstractPolicy):
             "rgb": observation["image"].contiguous(),
             "state": observation["state"].contiguous(),
         }
-        action = self.act(obs, t0=t0, step=self.step.item())
+        # Note: unsqueeze needed because `act` still uses non-batch logic.
+        action = self.act(obs, t0=t0, step=self.step.item()).unsqueeze(0)
         return action
 
     @torch.no_grad()
@@ -144,7 +148,7 @@ class TDMPC(AbstractPolicy):
         if self.cfg.mpc:
             a = self.plan(z, t0=t0, step=step)
         else:
-            a = self.model.pi(z, self.cfg.min_std * self.model.training)
+            a = self.model.pi(z, self.cfg.min_std * self.model.training).squeeze(0)
         return a
 
     @torch.no_grad()
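Not part of the commit: a minimal, self-contained sketch of the batch handling that the renamed `select_actions` introduces, where the observation carries a batch dimension of 1 and the underlying `act` still returns an unbatched action. The stand-in `act_single`, the tensor shapes, and `action_dim = 4` are illustrative assumptions, not code from the repository.

import torch


def act_single(obs: dict, t0: bool) -> torch.Tensor:
    """Stand-in for TDMPC.act: takes a batch-of-one observation but returns
    an unbatched action of shape (action_dim,), as in the diff."""
    action_dim = 4  # illustrative assumption
    return torch.zeros(action_dim)


def select_actions(observation: dict, step_count: torch.Tensor) -> torch.Tensor:
    """Mirrors the batch handling in the diff: batched in, batched out."""
    if observation["image"].shape[0] != 1:
        raise NotImplementedError("Batch size > 1 not handled")
    t0 = step_count.item() == 0  # first step of the episode
    obs = {
        "rgb": observation["image"].contiguous(),
        "state": observation["state"].contiguous(),
    }
    # `act` still uses non-batch logic on the output side, so unsqueeze
    # re-adds the batch dimension before returning.
    return act_single(obs, t0=t0).unsqueeze(0)  # shape (1, action_dim)


obs = {"image": torch.zeros(1, 3, 84, 84), "state": torch.zeros(1, 4)}
print(select_actions(obs, torch.tensor([0])).shape)  # torch.Size([1, 4])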