Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2025-02-11 10:36:06 +01:00
committed by GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions

View File

@@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler: