Remove offline training, refactor train.py and logging/checkpointing (#670)

Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
Simon Alibert
2025-02-11 10:36:06 +01:00
committed by GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions

View File

@@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
)
loss_dict["kld_loss"] = mean_kld.item()
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
loss = l1_loss + mean_kld * self.config.kl_weight
else:
loss_dict["loss"] = l1_loss
loss = l1_loss
return loss_dict
return loss, loss_dict
class ACTTemporalEnsembler:

View File

@@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
@@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
)
batch = self.normalize_targets(batch)
loss = self.diffusion.compute_loss(batch)
return {"loss": loss}
# no output_dict so returning None
return loss, None
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:

View File

@@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
"""
raise NotImplementedError
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
@abc.abstractmethod
def forward(self, batch: dict[str, Tensor]) -> dict:
"""Run the batch through the model and compute the loss for training or validation.
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
"""_summary_
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
other items should be logging-friendly, native Python types.
Args:
batch (dict[str, Tensor]): _description_
Returns:
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
is a Tensor, all other items should be logging-friendly, native Python types.
"""
raise NotImplementedError

View File

@@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
return G
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
@@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
"Q_value_loss": q_value_loss.item(),
"V_value_loss": v_value_loss.item(),
"pi_loss": pi_loss.item(),
"loss": loss,
"sum_loss": loss.item() * self.config.horizon,
}
)
@@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
batch[key] = batch[key].transpose(1, 0)
return info
return loss, info
def update(self):
"""Update the target model's parameters with an EMA step."""

View File

@@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
action = self._queues["action"].popleft()
return action
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
@@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
loss, n_different_codes, n_different_combinations, recon_l1_error = (
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
)
return {
"loss": loss,
return loss, {
"n_different_codes": n_different_codes,
"n_different_combinations": n_different_combinations,
"recon_l1_error": recon_l1_error,
}
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
_, loss_dict = self.vqbet(batch, rollout=False)
loss = loss_dict.pop("loss")
return loss_dict
return loss, loss_dict
class SpatialSoftmax(nn.Module):
@@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
)
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
# Input validation.
assert set(batch).issuperset({"observation.state", "observation.images"})
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
@@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
param.requires_grad = False
return loss, n_different_codes, n_different_combinations, recon_l1_error
def forward(self, x, **kwargs):
def forward(self, x, **kwargs) -> dict:
# N is the batch size, and T is number of action query tokens, which are process through same GPT
N, T, _ = x.shape
# we calculate N and T side parallely. Thus, the dimensions would be