Fix ACT temporal ensembling (#319)
This commit is contained in:
@@ -76,12 +76,10 @@ class ACTConfig:
|
||||
documentation in the policy class).
|
||||
latent_dim: The VAE's latent dimension.
|
||||
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
|
||||
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
|
||||
actions for a given time step over multiple policy invocations. Updates are calculated as:
|
||||
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
|
||||
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
|
||||
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
|
||||
require `n_action_steps == 1` (since we need to query the policy every step anyway).
|
||||
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
|
||||
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
|
||||
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
|
||||
more information on how ensembling works, please see `ACTTemporalEnsembler`.
|
||||
dropout: Dropout to use in the transformer layers (see code for details).
|
||||
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
|
||||
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
|
||||
@@ -139,7 +137,8 @@ class ACTConfig:
|
||||
n_vae_encoder_layers: int = 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: float | None = None
|
||||
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
|
||||
temporal_ensemble_coeff: float | None = None
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: float = 0.1
|
||||
@@ -151,7 +150,7 @@ class ACTConfig:
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
)
|
||||
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
|
||||
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
|
||||
raise NotImplementedError(
|
||||
"`n_action_steps` must be 1 when using temporal ensembling. This is "
|
||||
"because the policy needs to be queried every step to compute the ensembled action."
|
||||
|
||||
Reference in New Issue
Block a user