Improves Type Annotations (#252)

This commit is contained in:
Wael Karkoub
2024-06-10 19:09:48 +01:00
committed by GitHub
parent a06598678c
commit 54c9776bde
7 changed files with 54 additions and 23 deletions

View File

@@ -57,7 +57,7 @@ class Policy(Protocol):
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals

View File

@@ -134,7 +134,7 @@ class TDMPCPolicy(nn.Module, PyTorchModelHubMixin):
self._prev_mean: torch.Tensor | None = None
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]):
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
batch = self.normalize_inputs(batch)
batch["observation.image"] = batch[self.input_image_key]