ACT temporal ensembling (#186)
This commit is contained in:
@@ -61,7 +61,7 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
config = ACTConfig()
|
||||
self.config = config
|
||||
self.config: ACTConfig = config
|
||||
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
@@ -81,7 +81,9 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
|
||||
def reset(self):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
if self.config.n_action_steps is not None:
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
self._ensembled_actions = None
|
||||
else:
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@@ -97,6 +99,28 @@ class ACTPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4)
|
||||
|
||||
# If we are doing temporal ensembling, keep track of the exponential moving average (EMA), and return
|
||||
# the first action.
|
||||
if self.config.temporal_ensemble_momentum is not None:
|
||||
actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
if self._ensembled_actions is None:
|
||||
# Initializes `self._ensembled_actions` to the sequence of actions predicted during the first
|
||||
# time step of the episode.
|
||||
self._ensembled_actions = actions.clone()
|
||||
else:
|
||||
# self._ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
# the EMA update for those entries.
|
||||
alpha = self.config.temporal_ensemble_momentum
|
||||
self._ensembled_actions = alpha * self._ensembled_actions + (1 - alpha) * actions[:, :-1]
|
||||
# The last action, which has no prior moving average, needs to get concatenated onto the end.
|
||||
self._ensembled_actions = torch.cat([self._ensembled_actions, actions[:, -1:]], dim=1)
|
||||
# "Consume" the first action.
|
||||
action, self._ensembled_actions = self._ensembled_actions[:, 0], self._ensembled_actions[:, 1:]
|
||||
return action
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.model(batch)[0][:, : self.config.n_action_steps]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user