wip
@@ -24,20 +24,20 @@ class Policy(Protocol):
         Does things like clearing caches.
         """
 
-    def forward(self, batch: dict[str, Tensor], **kwargs):
-        """Wired to `select_action`."""
+    def forward(self, batch: dict[str, Tensor]) -> dict:
+        """Run the batch through the model and compute the loss for training or validation.
+
+        Returns a dictionary with "loss" and maybe other information.
+        """
 
-    def select_action(self, batch: dict[str, Tensor], **kwargs):
+    def select_action(self, batch: dict[str, Tensor]):
         """Return one action to run in the environment (potentially in batch mode).
 
         When the model uses a history of observations, or outputs a sequence of actions, this method deals
         with caching.
         """
 
-    def compute_loss(self, batch: dict[str, Tensor], **kwargs):
-        """Runs the batch through the model and computes the loss for training or validation."""
-
-    def update(self, batch, **kwargs):
+    def update(self, batch):
         """Does compute_loss then an optimization step.
 
         TODO(alexander-soare): We will move the optimization step back into the training loop, so this will
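To make the resulting interface concrete, here is a minimal usage sketch. It is not part of the commit: it assumes only the methods shown in the diff (forward, select_action, update), and the helper names (validation_step, training_step, rollout_step) and batch contents are hypothetical stand-ins.

    # Minimal usage sketch of the Policy protocol after this change.
    # Only forward(), select_action(), and update() from the diff are assumed;
    # everything else here is a hypothetical stand-in.
    import torch
    from torch import Tensor


    def validation_step(policy, batch: dict[str, Tensor]) -> dict:
        # Per the new contract, forward() runs the batch through the model
        # and returns a dict containing at least "loss".
        with torch.no_grad():
            out = policy.forward(batch)
        assert "loss" in out
        return out


    def training_step(policy, batch: dict[str, Tensor]) -> None:
        # update() does compute_loss then an optimization step internally.
        # (The TODO notes the optimization step may move back into the
        # training loop, at which point this would call forward() instead.)
        policy.update(batch)


    def rollout_step(policy, observation: dict[str, Tensor]):
        # select_action() returns one action to run in the environment
        # (potentially batched); any observation-history or action-sequence
        # caching is handled inside the policy. The return type is not
        # annotated in the diff, so none is assumed here.
        return policy.select_action(observation)

The split mirrors the diff: forward() becomes the single training/validation entry point returning a "loss" dict (absorbing the removed compute_loss), while select_action() keeps inference-time caching concerns out of the training path.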