Enable logging all the information returned by the forward methods of policies (#151)
This commit is contained in:
@@ -38,7 +38,8 @@ class Policy(Protocol):
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
|
||||
Returns a dictionary with "loss" and maybe other information.
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss", which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]):
|
||||
|
||||
Reference in New Issue
Block a user