Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:40:04 +01:00
committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions

View File

@@ -31,11 +31,17 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
name = "act"
def __init__(self, config: ACTConfig | None = None, dataset_stats=None):
def __init__(
self,
config: ACTConfig | None = None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
if config is None:
@@ -58,7 +64,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
@@ -81,7 +87,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch, **_) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)