Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
@@ -144,7 +144,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -169,11 +169,11 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
loss = l1_loss + mean_kld * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
loss = l1_loss
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
||||
@@ -143,7 +143,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
@@ -153,7 +153,8 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
return {"loss": loss}
|
||||
# no output_dict so returning None
|
||||
return loss, None
|
||||
|
||||
|
||||
def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMScheduler:
|
||||
|
||||
@@ -163,12 +163,17 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
# TODO(aliberts, rcadene): split into 'forward' and 'compute_loss'?
|
||||
@abc.abstractmethod
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict:
|
||||
"""Run the batch through the model and compute the loss for training or validation.
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict | None]:
|
||||
"""_summary_
|
||||
|
||||
Returns a dictionary with "loss" and potentially other information. Apart from "loss" which is a Tensor, all
|
||||
other items should be logging-friendly, native Python types.
|
||||
Args:
|
||||
batch (dict[str, Tensor]): _description_
|
||||
|
||||
Returns:
|
||||
tuple[Tensor, dict | None]: The loss and potentially other information. Apart from the loss which
|
||||
is a Tensor, all other items should be logging-friendly, native Python types.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -302,7 +302,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0)
|
||||
return G
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
@@ -495,7 +495,6 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"V_value_loss": v_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
}
|
||||
)
|
||||
@@ -505,7 +504,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
return loss, info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's parameters with an EMA step."""
|
||||
|
||||
@@ -156,7 +156,7 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
@@ -170,16 +170,16 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
loss, n_different_codes, n_different_combinations, recon_l1_error = (
|
||||
self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"])
|
||||
)
|
||||
return {
|
||||
"loss": loss,
|
||||
return loss, {
|
||||
"n_different_codes": n_different_codes,
|
||||
"n_different_combinations": n_different_combinations,
|
||||
"recon_l1_error": recon_l1_error,
|
||||
}
|
||||
# if Residual VQ is already trained, VQ-BeT trains its GPT and bin prediction head / offset prediction head parts.
|
||||
_, loss_dict = self.vqbet(batch, rollout=False)
|
||||
loss = loss_dict.pop("loss")
|
||||
|
||||
return loss_dict
|
||||
return loss, loss_dict
|
||||
|
||||
|
||||
class SpatialSoftmax(nn.Module):
|
||||
@@ -342,7 +342,7 @@ class VQBeTModel(nn.Module):
|
||||
torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]),
|
||||
)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor:
|
||||
def forward(self, batch: dict[str, Tensor], rollout: bool) -> tuple[dict, dict]:
|
||||
# Input validation.
|
||||
assert set(batch).issuperset({"observation.state", "observation.images"})
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -482,7 +482,7 @@ class VQBeTHead(nn.Module):
|
||||
param.requires_grad = False
|
||||
return loss, n_different_codes, n_different_combinations, recon_l1_error
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
def forward(self, x, **kwargs) -> dict:
|
||||
# N is the batch size, and T is number of action query tokens, which are process through same GPT
|
||||
N, T, _ = x.shape
|
||||
# we calculate N and T side parallely. Thus, the dimensions would be
|
||||
|
||||
Reference in New Issue
Block a user