Enable logging all the information returned by the forward methods of policies (#151)

This commit is contained in:
Alexander Soare
2024-05-10 07:45:32 +01:00
committed by GitHub
parent b187942db4
commit 1249aee3ac
5 changed files with 12 additions and 4 deletions

View File

@@ -38,7 +38,8 @@ class Policy(Protocol):
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
Returns a dictionary with "loss" and maybe other information.
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
"""
def select_action(self, batch: dict[str, Tensor]):