Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2025-02-11 10:36:06 +01:00
committed by GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions

View File

@@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
Args:
batch (dict[str, Tensor]): _description_
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError