Move normalization to policy for act and diffusion (#90)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-04-25 11:47:38 +02:00
committed by GitHub
parent c1bcf857c5
commit e760e4cd63
25 changed files with 543 additions and 288 deletions

View File

@@ -8,23 +8,30 @@ class ActionChunkingTransformerConfig:
Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer".
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `state_dim`, `action_dim` and `camera_names`.
Those are: `input_shapes` and 'output_shapes`.
Args:
state_dim: Dimensionality of the observation state space (excluding images).
action_dim: Dimensionality of the action space.
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
camera_names: The (unique) set of names for the cameras.
chunk_size: The size of the action prediction "chunks" in units of environment steps.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
image_normalization_mean: Value to subtract from the input image pixels (inputs are assumed to be in
[0, 1]) for normalization.
image_normalization_std: Value by which to divide the input image pixels (after the mean has been
subtracted).
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
normalize_input_modes: A dictionary with key represents the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two availables
modes are "mean_std" which substracts the mean and divide by the standard
deviation and "min_max" which rescale in a [-1, 1] range.
unnormalize_output_modes: Similar dictionary as `normalize_input_modes`, but to unormalize in original scale.
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
use_pretrained_backbone: Whether the backbone should be initialized with pretrained weights from
torchvision.
@@ -50,21 +57,35 @@ class ActionChunkingTransformerConfig:
is enabled. Loss is then calculated as: `reconstruction_loss + kl_weight * kld_loss`.
"""
# Environment.
state_dim: int = 14
action_dim: int = 14
# Inputs / output structure.
# Input / output structure.
n_obs_steps: int = 1
camera_names: tuple[str] = ("top",)
chunk_size: int = 100
n_action_steps: int = 100
# Vision preprocessing.
image_normalization_mean: tuple[float, float, float] = field(
default_factory=lambda: [0.485, 0.456, 0.406]
input_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"observation.images.top": [3, 480, 640],
"observation.state": [14],
}
)
output_shapes: dict[str, list[str]] = field(
default_factory=lambda: {
"action": [14],
}
)
# Normalization / Unnormalization
normalize_input_modes: dict[str, str] = field(
default_factory=lambda: {
"observation.image": "mean_std",
"observation.state": "mean_std",
}
)
unnormalize_output_modes: dict[str, str] = field(
default_factory=lambda: {
"action": "mean_std",
}
)
image_normalization_std: tuple[float, float, float] = field(default_factory=lambda: [0.229, 0.224, 0.225])
# Architecture.
# Vision backbone.
@@ -117,7 +138,10 @@ class ActionChunkingTransformerConfig:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.camera_names != ["top"]:
raise ValueError(f"For now, `camera_names` can only be ['top']. Got {self.camera_names}.")
if len(set(self.camera_names)) != len(self.camera_names):
raise ValueError(f"`camera_names` should not have any repeated entries. Got {self.camera_names}.")
# Check that there is only one image.
# TODO(alexander-soare): generalize this to multiple images.
if (
sum(k.startswith("observation.images.") for k in self.input_shapes) != 1
or "observation.images.top" not in self.input_shapes
):
raise ValueError('For now, only "observation.images.top" is accepted for an image input.')