This commit is contained in:
Alexander Soare
2024-04-17 16:21:37 +01:00
parent 63e5ec6483
commit 2298ddf226
3 changed files with 26 additions and 22 deletions

View File

@@ -24,20 +24,20 @@ class Policy(Protocol):
Does things like clearing caches.
"""
def forward(self, batch: dict[str, Tensor], **kwargs):
"""Wired to `select_action`."""
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
def select_action(self, batch: dict[str, Tensor], **kwargs):
Returns a dictionary with "loss" and maybe other information.
"""
def select_action(self, batch: dict[str, Tensor]):
"""Return one action to run in the environment (potentially in batch mode).
When the model uses a history of observations, or outputs a sequence of actions, this method deals
with caching.
"""
def compute_loss(self, batch: dict[str, Tensor], **kwargs):
"""Runs the batch through the model and computes the loss for training or validation."""
def update(self, batch, **kwargs):
def update(self, batch):
"""Does compute_loss then an optimization step.
TODO(alexander-soare): We will move the optimization step back into the training loop, so this will