Fix ACT temporal ensembling (#319)

This commit is contained in:
Alexander Soare
2024-07-16 10:27:21 +01:00
committed by GitHub
parent 5e54e39795
commit c0101f0948
7 changed files with 173 additions and 31 deletions

View File

@@ -76,12 +76,10 @@ class ACTConfig:
documentation in the policy class).
latent_dim: The VAE's latent dimension.
n_vae_encoder_layers: The number of transformer layers to use for the VAE's encoder.
temporal_ensemble_momentum: Exponential moving average (EMA) momentum parameter (α) for ensembling
actions for a given time step over multiple policy invocations. Updates are calculated as:
x⁻ₙ = αx⁻ₙ₋₁ + (1-α)xₙ. Note that the ACT paper and original ACT code describes a different
parameter here: they refer to a weighting scheme wᵢ = exp(-m⋅i) and set m = 0.01. With our
formulation, this is equivalent to α = exp(-0.01) ≈ 0.99. When this parameter is provided, we
require `n_action_steps == 1` (since we need to query the policy every step anyway).
temporal_ensemble_coeff: Coefficient for the exponential weighting scheme to apply for temporal
ensembling. Defaults to None which means temporal ensembling is not used. `n_action_steps` must be
1 when using this feature, as inference needs to happen at every step to form an ensemble. For
more information on how ensembling works, please see `ACTTemporalEnsembler`.
dropout: Dropout to use in the transformer layers (see code for details).
kl_weight: The weight to use for the KL-divergence component of the loss if the variational objective
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
@@ -139,7 +137,8 @@ class ACTConfig:
n_vae_encoder_layers: int = 4
# Inference.
temporal_ensemble_momentum: float | None = None
# Note: the value used in ACT when temporal ensembling is enabled is 0.01.
temporal_ensemble_coeff: float | None = None
# Training and loss computation.
dropout: float = 0.1
@@ -151,7 +150,7 @@ class ACTConfig:
raise ValueError(
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
)
if self.temporal_ensemble_momentum is not None and self.n_action_steps > 1:
if self.temporal_ensemble_coeff is not None and self.n_action_steps > 1:
raise NotImplementedError(
"`n_action_steps` must be 1 when using temporal ensembling. This is "
"because the policy needs to be queried every step to compute the ensembled action."

View File

@@ -77,12 +77,15 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
if config.temporal_ensemble_coeff is not None:
self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
if self.config.temporal_ensemble_momentum is not None:
self._ensembled_actions = None
if self.config.temporal_ensemble_coeff is not None:
self.temporal_ensembler.reset()
else:
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@@ -100,24 +103,12 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
if len(self.expected_image_keys) > 0:
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
# the first action.
if self.config.temporal_ensemble_momentum is not None:
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
if self.config.temporal_ensemble_coeff is not None:
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
actions = self.unnormalize_outputs({"action": actions})["action"]
if self._ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self._ensembled_actions = actions.clone()
else:
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the EMA update for those entries.
alpha = self.config.temporal_ensemble_momentum
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
# The last action, which has no prior moving average, needs to get concatenated onto the end.
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
# "Consume" the first action.
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
action = self.temporal_ensembler.update(actions)
return action
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
@@ -162,6 +153,97 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
return loss_dict
class ACTTemporalEnsembler:
def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None:
"""Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705.
The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action.
They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the
coefficient works:
- Setting it to 0 uniformly weighs all actions.
- Setting it positive gives more weight to older actions.
- Setting it negative gives more weight to newer actions.
NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This
results in older actions being weighed more highly than newer actions (the experiments documented in
https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be
detrimental: doing so aggressively may diminish the benefits of action chunking).
Here we use an online method for computing the average rather than caching a history of actions in
order to compute the average offline. For a simple 1D sequence it looks something like:
```
import torch
seq = torch.linspace(8, 8.5, 100)
print(seq)
m = 0.01
exp_weights = torch.exp(-m * torch.arange(len(seq)))
print(exp_weights)
# Calculate offline
avg = (exp_weights * seq).sum() / exp_weights.sum()
print("offline", avg)
# Calculate online
for i, item in enumerate(seq):
if i == 0:
avg = item
continue
avg *= exp_weights[:i].sum()
avg += item * exp_weights[i]
avg /= exp_weights[:i+1].sum()
print("online", avg)
```
"""
self.chunk_size = chunk_size
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
self.reset()
def reset(self):
"""Resets the online computation variables."""
self.ensembled_actions = None
# (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence.
self.ensembled_actions_count = None
def update(self, actions: Tensor) -> Tensor:
"""
Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all
time steps, and pop/return the next batch of actions in the sequence.
"""
self.ensemble_weights = self.ensemble_weights.to(device=actions.device)
self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device)
if self.ensembled_actions is None:
# Initializes `self._ensembled_action` to the sequence of actions predicted during the first
# time step of the episode.
self.ensembled_actions = actions.clone()
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
# operations later.
self.ensembled_actions_count = torch.ones(
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
)
else:
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
# the online update for those entries.
self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1]
self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count]
self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count]
self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size)
# The last action, which has no prior online average, needs to get concatenated onto the end.
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
self.ensembled_actions_count = torch.cat(
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
)
# "Consume" the first action.
action, self.ensembled_actions, self.ensembled_actions_count = (
self.ensembled_actions[:, 0],
self.ensembled_actions[:, 1:],
self.ensembled_actions_count[1:],
)
return action
class ACT(nn.Module):
"""Action Chunking Transformer: The underlying neural network for ACTPolicy.

View File

@@ -75,7 +75,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@@ -107,7 +107,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1

View File

@@ -103,7 +103,7 @@ policy:
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
temporal_ensemble_coeff: null
# Training and loss computation.
dropout: 0.1