forked from tangger/lerobot
Refactor TD-MPC (#103)
Co-authored-by: Cadene <re.cadene@gmail.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -43,15 +43,16 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: DiffusionConfig | None = None,
|
||||
dataset_stats=None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
# TODO(alexander-soare): LR scheduler will be removed.
|
||||
if config is None:
|
||||
config = DiffusionConfig()
|
||||
self.config = config
|
||||
@@ -88,7 +89,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
}
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor], **_) -> Tensor:
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method handles caching a history of observations and an action trajectory generated by the
|
||||
@@ -136,7 +137,7 @@ class DiffusionPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], **_) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
Reference in New Issue
Block a user