Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:40:04 +01:00
committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions

View File

@@ -43,15 +43,16 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
def __init__(
self,
config: DiffusionConfig | None = None,
dataset_stats=None,
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
the configuration class is used.
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
that they will be passed with a call to `load_state_dict` before the policy is used.
"""
super().__init__()
# TODO(alexander-soare): LR scheduler will be removed.
if config is None:
config = DiffusionConfig()
self.config = config
@@ -88,7 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
}
@torch.no_grad
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.
This method handles caching a history of observations and an action trajectory generated by the
@@ -136,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = self.normalize_targets(batch)