Fix ACT temporal ensembling (#319)
@@ -16,6 +16,7 @@
 import inspect
 from pathlib import Path
 
+import einops
 import pytest
 import torch
 from huggingface_hub import PyTorchModelHubMixin
@@ -26,6 +27,7 @@ from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler
 from lerobot.common.policies.factory import (
     _policy_cfg_from_hydra_cfg,
     get_policy_and_config_classes,
@@ -33,7 +35,7 @@ from lerobot.common.policies.factory import (
 )
 from lerobot.common.policies.normalize import Normalize, Unnormalize
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.utils.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config, seeded_context
 from lerobot.scripts.train import make_optimizer_and_scheduler
 from tests.scripts.save_policy_to_safetensors import get_policy_stats
 from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -390,3 +392,62 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
         assert torch.isclose(param_stats[key], saved_param_stats[key], rtol=50, atol=1e-7).all()
     for key in saved_actions:
         assert torch.isclose(actions[key], saved_actions[key], rtol=0.1, atol=1e-7).all()
+
+
+def test_act_temporal_ensembler():
+    """Check that the online method in ACTTemporalEnsembler matches a simple offline calculation."""
+    temporal_ensemble_coeff = 0.01
+    chunk_size = 100
+    episode_length = 101
+    ensembler = ACTTemporalEnsembler(temporal_ensemble_coeff, chunk_size)
+    # A batch of arbitrary sequences of 1D actions we wish to compute the average over. We'll keep the
+    # "action space" in [-1, 1]. Apart from that, there is no real reason for the numbers chosen.
+    with seeded_context(0):
+        # Dimension is (batch, episode_length, chunk_size, action_dim(=1))
+        # Stepping through the episode_length dim is like running inference at each rollout step and getting
+        # a different action chunk.
+        batch_seq = torch.stack(
+            [
+                torch.rand(episode_length, chunk_size) * 0.05 - 0.6,
+                torch.rand(episode_length, chunk_size) * 0.02 - 0.01,
+                torch.rand(episode_length, chunk_size) * 0.2 + 0.3,
+            ],
+            dim=0,
+        ).unsqueeze(-1)  # unsqueeze for action dim
+    batch_size = batch_seq.shape[0]
+    # Exponential weighting (normalized). Unsqueeze once to match the position of the `episode_length`
+    # dimension of `batch_seq`.
+    weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)).unsqueeze(-1)
+
+    # Simulate stepping through a rollout and computing a batch of actions with the model on each step.
+    for i in range(episode_length):
+        # Mock a batch of actions.
+        actions = torch.zeros(size=(batch_size, chunk_size, 1)) + batch_seq[:, i]
+        online_avg = ensembler.update(actions)
+        # Simple offline calculation: avg = Σ(aᵢ*wᵢ) / Σ(wᵢ).
+        # Note: The complicated bit here is the slicing. Think about the (episode_length, chunk_size) grid.
+        # What we want to do is take diagonal slices across it starting from the left.
+        # eg: chunk_size=4, episode_length=6
+        # ┌───────┐
+        # │0 1 2 3│
+        # │1 2 3 4│
+        # │2 3 4 5│
+        # │3 4 5 6│
+        # │4 5 6 7│
+        # │5 6 7 8│
+        # └───────┘
+        chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
+        episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
+        seq_slice = batch_seq[:, episode_step_indices, chunk_indices]
+        offline_avg = (
+            einops.reduce(seq_slice * weights[: i + 1], "b s 1 -> b 1", "sum") / weights[: i + 1].sum()
+        )
+        # Sanity check. The average should be between the extrema.
+        assert torch.all(einops.reduce(seq_slice, "b s 1 -> b 1", "min") <= offline_avg)
+        assert torch.all(offline_avg <= einops.reduce(seq_slice, "b s 1 -> b 1", "max"))
+        # Selected atol=1e-4 keeping in mind actions in [-1, 1] and expecting 0.01% error.
+        assert torch.allclose(online_avg, offline_avg, atol=1e-4)
+
+
+if __name__ == "__main__":
+    test_act_temporal_ensembler()
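
For context on what ensembler.update must do for this test to pass: ACTTemporalEnsembler is expected to keep running weighted sums so that each call returns the exponentially weighted average avg = Σ(aᵢ·wᵢ) / Σ(wᵢ) over every prediction made so far for the current timestep, without revisiting old chunks. Below is a minimal sketch of such a recurrence; the class name OnlineEnsembler and its internals are illustrative assumptions, not lerobot's actual implementation.

import torch

class OnlineEnsembler:
    """Hypothetical sketch of online exponential ensembling (not lerobot's code)."""

    def __init__(self, coeff: float, chunk_size: int):
        # w_i = exp(-coeff * i); w_0 (the largest weight) goes to the oldest prediction.
        self.weights = torch.exp(-coeff * torch.arange(chunk_size))
        self.cum_weights = torch.cumsum(self.weights, dim=0)  # Σ w_0..w_k, for normalization
        self.sums = None  # running Σ aᵢ·wᵢ for each upcoming timestep
        self.counts = None  # number of chunks that have contributed to each slot

    def update(self, actions: torch.Tensor) -> torch.Tensor:
        # actions: (batch, chunk_size, action_dim), predictions for timesteps t..t+chunk_size-1.
        if self.sums is None:
            # First chunk: it is the oldest (and only) prediction for every slot.
            self.sums = actions * self.weights[0]
            self.counts = torch.ones(actions.shape[1], dtype=torch.long)
        else:
            # Shift left: slot 0 of the previous step was consumed as timestep t-1.
            self.sums = torch.cat([self.sums[:, 1:], torch.zeros_like(self.sums[:, :1])], dim=1)
            self.counts = torch.cat([self.counts[1:], torch.zeros(1, dtype=torch.long)])
            # The new chunk is the counts[j]-th prediction for slot j, so it gets weight w_{counts[j]}.
            self.sums = self.sums + actions * self.weights[self.counts].reshape(1, -1, 1)
            self.counts = self.counts + 1
        # Return slot 0, normalized by the sum of the weights that have contributed to it.
        return self.sums[:, 0] / self.cum_weights[self.counts[0] - 1]

Slot 0's running sum accumulates exactly the anti-diagonal entries that the offline seq_slice gathers, so the two averages agree up to floating-point error, which is what the atol=1e-4 check asserts.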
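The trickiest part of the offline check is the slicing. A small standalone example (chunk_size=4 and episode_length=6, matching the grid in the test's comment and chosen purely for illustration) confirms that chunk_indices and episode_step_indices walk one anti-diagonal of that grid: every selected (episode step, chunk offset) pair sums to the current rollout step i, i.e. each selected element is a prediction targeting timestep i.

import torch

chunk_size = 4  # matches the grid drawn in the test's comment
for i in range(6):  # episode_length = 6
    chunk_indices = torch.arange(min(i, chunk_size - 1), -1, -1)
    episode_step_indices = torch.arange(i + 1)[-len(chunk_indices) :]
    # e.g. i=5 -> chunk_indices [3, 2, 1, 0], episode_step_indices [2, 3, 4, 5]
    for s, c in zip(episode_step_indices.tolist(), chunk_indices.tolist()):
        assert s + c == i  # every pair lands on the timestep-i anti-diagonal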