Add real-world support for ACT on Aloha/Aloha2 (#228)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Remi
2024-05-31 15:31:02 +02:00
committed by GitHub
parent 504d2aaf48
commit d585c73f9f
22 changed files with 525 additions and 140 deletions

View File

@@ -45,6 +45,9 @@ import itertools
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
@@ -52,6 +55,7 @@ available_tasks_per_env = {
],
"pusht": ["PushT-v0"],
"xarm": ["XarmLift-v0"],
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
}
available_envs = list(available_tasks_per_env.keys())
@@ -77,6 +81,23 @@ available_datasets_per_env = {
"lerobot/xarm_push_medium_image",
"lerobot/xarm_push_medium_replay_image",
],
"dora_aloha_real": [
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
],
}
available_real_world_datasets = [
@@ -108,16 +129,19 @@ available_datasets = list(
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
)
# lists all available policies from `lerobot/common/policies` by their class attribute: `name`.
available_policies = [
"act",
"diffusion",
"tdmpc",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion"],
"xarm": ["tdmpc"],
"dora_aloha_real": ["act_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]

View File

@@ -55,11 +55,19 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
"strings to load multiple datasets."
)
if isinstance(cfg.dataset_repo_id, str) and cfg.env.name not in cfg.dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({cfg.dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
# A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
if cfg.env.name != "dora":
if isinstance(cfg.dataset_repo_id, str):
dataset_repo_ids = [cfg.dataset_repo_id] # single dataset
else:
dataset_repo_ids = cfg.dataset_repo_id # multiple datasets
for dataset_repo_id in dataset_repo_ids:
if cfg.env.name not in dataset_repo_id:
logging.warning(
f"There might be a mismatch between your training dataset ({dataset_repo_id=}) and your "
f"environment ({cfg.env.name=})."
)
resolve_delta_timestamps(cfg)

View File

@@ -25,6 +25,13 @@ class ACTConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and 'output_shapes`.
Notes on the inputs and outputs:
- At least one key starting with "observation.image is required as an input.
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
views. Right now we only support all images having the same shape.
- May optionally work without an "observation.state" key for the proprioceptive robot state.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
@@ -33,15 +40,15 @@ class ACTConfig:
This should be no greater than the chunk size. For example, if the chunk size size 100, you may
set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the
environment, and throws the other 50 out.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.images.top" refers to an input from the
"top" camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesn't include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesn't include batch dimension or temporal dimension.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a

View File

@@ -200,25 +200,29 @@ class ACT(nn.Module):
self.config = config
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
self.use_input_state = "observation.state" in config.input_shapes
if self.config.use_vae:
self.vae_encoder = ACTEncoder(config)
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
# Projection layer for joint-space configuration to hidden dimension.
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
if self.use_input_state:
self.vae_encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
# Projection layer for action (joint-space target) to hidden dimension.
self.vae_encoder_action_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
config.output_shapes["action"][0], config.dim_model
)
self.latent_dim = config.latent_dim
# Projection layer from the VAE encoder's output to the latent distribution's parameter space.
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, self.latent_dim * 2)
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
# dimension.
num_input_token_encoder = 1 + config.chunk_size
if self.use_input_state:
num_input_token_encoder += 1
self.register_buffer(
"vae_encoder_pos_enc",
create_sinusoidal_pos_embedding(1 + 1 + config.chunk_size, config.dim_model).unsqueeze(0),
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
)
# Backbone for image feature extraction.
@@ -238,15 +242,17 @@ class ACT(nn.Module):
# Transformer encoder input projections. The tokens will be structured like
# [latent, robot_state, image_feature_map_pixels].
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
if self.use_input_state:
self.encoder_robot_state_input_proj = nn.Linear(
config.input_shapes["observation.state"][0], config.dim_model
)
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
self.encoder_img_feat_input_proj = nn.Conv2d(
backbone_model.fc.in_features, config.dim_model, kernel_size=1
)
# Transformer encoder positional embeddings.
self.encoder_robot_and_latent_pos_embed = nn.Embedding(2, config.dim_model)
num_input_token_decoder = 2 if self.use_input_state else 1
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
# Transformer decoder.
@@ -285,7 +291,7 @@ class ACT(nn.Module):
"action" in batch
), "actions must be provided when using the variational objective in training mode."
batch_size = batch["observation.state"].shape[0]
batch_size = batch["observation.images"].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and "action" in batch:
@@ -293,11 +299,16 @@ class ACT(nn.Module):
cls_embed = einops.repeat(
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]).unsqueeze(
1
) # (B, 1, D)
if self.use_input_state:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
vae_encoder_input = torch.cat([cls_embed, robot_state_embed, action_embed], axis=1) # (B, S+2, D)
if self.use_input_state:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
# Prepare fixed positional embedding.
# Note: detach() shouldn't be necessary but leaving it the same as the original code just in case.
@@ -308,16 +319,17 @@ class ACT(nn.Module):
vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
)[0] # select the class token, with shape (B, D)
latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
mu = latent_pdf_params[:, : self.latent_dim]
mu = latent_pdf_params[:, : self.config.latent_dim]
# This is 2log(sigma). Done this way to match the original implementation.
log_sigma_x2 = latent_pdf_params[:, self.latent_dim :]
log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :]
# Sample the latent with the reparameterization trick.
latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu)
else:
# When not using the VAE encoder, we set the latent to be all zeros.
mu = log_sigma_x2 = None
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
batch["observation.state"].device
)
@@ -326,8 +338,10 @@ class ACT(nn.Module):
all_cam_features = []
all_cam_pos_embeds = []
images = batch["observation.images"]
for cam_index in range(images.shape[-4]):
cam_features = self.backbone(images[:, cam_index])["feature_map"]
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
all_cam_features.append(cam_features)
@@ -337,13 +351,15 @@ class ACT(nn.Module):
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
# Get positional embeddings for robot state and latent.
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
if self.use_input_state:
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
# Stack encoder input and positional embeddings moving to (S, B, C).
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
encoder_in = torch.cat(
[
torch.stack([latent_embed, robot_state_embed], axis=0),
torch.stack(encoder_in_feats, axis=0),
einops.rearrange(encoder_in, "b c h w -> (h w) b c"),
]
)
@@ -357,6 +373,7 @@ class ACT(nn.Module):
# Forward pass through the transformer modules.
encoder_out = self.encoder(encoder_in, pos_embed=pos_embed)
# TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer
decoder_in = torch.zeros(
(self.config.chunk_size, batch_size, self.config.dim_model),
dtype=pos_embed.dtype,

View File

@@ -26,21 +26,26 @@ class DiffusionConfig:
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
Those are: `input_shapes` and `output_shapes`.
Notes on the inputs and outputs:
- "observation.state" is required as an input key.
- A key starting with "observation.image is required as an input.
- "action" is required as an output key.
Args:
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
current step and additional steps going back).
horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`.
n_action_steps: The number of action steps to run in the environment for one invocation of the policy.
See `DiffusionPolicy.select_action` for more details.
input_shapes: A dictionary defining the shapes of the input data for the policy.
The key represents the input data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "observation.image" refers to an input from
a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution.
Importantly, shapes doesnt include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy.
The key represents the output data name, and the value is a list indicating the dimensions
of the corresponding data. For example, "action" refers to an output shape of [14], indicating
14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a

View File

@@ -31,6 +31,15 @@ class TDMPCConfig:
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
action repeats in Q-learning or ask your favorite chatbot)
horizon: Horizon for model predictive control.
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
the input data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
include batch dimension or temporal dimension.
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
the output data name, and the value is a list indicating the dimensions of the corresponding data.
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a

View File

@@ -0,0 +1,13 @@
# @package _global_
fps: 30
env:
name: dora
task: DoraAloha-v0
state_dim: 14
action_dim: 14
fps: ${fps}
episode_length: 400
gym:
fps: ${fps}

View File

@@ -0,0 +1,115 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -0,0 +1,111 @@
# @package _global_
# Use `act_real_no_state.yaml` to train on real-world Aloha/Aloha2 datasets when cameras are moving (e.g. wrist cameras)
# Compared to `act_real.yaml`, it is camera only and does not use the state as input which is vector of robot joint positions.
# We validated experimentaly that not using state reaches better success rate. Our hypothesis is that `act_real.yaml` might
# overfits to the state, because the images are more complex to learn from since they are moving.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real_no_state \
# env=dora_aloha_real
# ```
seed: 1000
dataset_repo_id: lerobot/aloha_static_vinh_cup
override_dataset_stats:
observation.images.cam_right_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_left_wrist:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.cam_low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 80000
online_steps: 0
eval_freq: -1
save_freq: 10000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
eval:
n_episodes: 50
batch_size: 50
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.cam_right_wrist: [3, 480, 640]
observation.images.cam_left_wrist: [3, 480, 640]
observation.images.cam_high: [3, 480, 640]
observation.images.cam_low: [3, 480, 640]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.cam_right_wrist: mean_std
observation.images.cam_left_wrist: mean_std
observation.images.cam_high: mean_std
observation.images.cam_low: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0