Add aloha2_real, Add act_real, Fix vae=false, Add support for no state
This commit is contained in:
@@ -26,10 +26,11 @@ class ACTConfig:
|
|||||||
Those are: `input_shapes` and 'output_shapes`.
|
Those are: `input_shapes` and 'output_shapes`.
|
||||||
|
|
||||||
Notes on the inputs and outputs:
|
Notes on the inputs and outputs:
|
||||||
|
- "observation.state" is required as an input key.
|
||||||
- At least one key starting with "observation.image is required as an input.
|
- At least one key starting with "observation.image is required as an input.
|
||||||
- If there are multiple keys beginning with "observation.images." they are treated as multiple camera
|
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||||
views. Right now we only support all images having the same shape.
|
views.
|
||||||
- May optionally work without an "observation.state" key for the proprioceptive robot state.
|
Right now we only support all images having the same shape.
|
||||||
- "action" is required as an output key.
|
- "action" is required as an output key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -200,12 +200,13 @@ class ACT(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
# BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence].
|
||||||
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
# The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]).
|
||||||
self.use_input_state = "observation.state" in config.input_shapes
|
self.has_state = "observation.state" in config.input_shapes
|
||||||
|
self.latent_dim = config.latent_dim
|
||||||
if self.config.use_vae:
|
if self.config.use_vae:
|
||||||
self.vae_encoder = ACTEncoder(config)
|
self.vae_encoder = ACTEncoder(config)
|
||||||
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model)
|
||||||
# Projection layer for joint-space configuration to hidden dimension.
|
# Projection layer for joint-space configuration to hidden dimension.
|
||||||
if self.use_input_state:
|
if self.has_state:
|
||||||
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
self.vae_encoder_robot_state_input_proj = nn.Linear(
|
||||||
config.input_shapes["observation.state"][0], config.dim_model
|
config.input_shapes["observation.state"][0], config.dim_model
|
||||||
)
|
)
|
||||||
@@ -217,9 +218,7 @@ class ACT(nn.Module):
|
|||||||
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2)
|
||||||
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
# Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch
|
||||||
# dimension.
|
# dimension.
|
||||||
num_input_token_encoder = 1 + config.chunk_size
|
num_input_token_encoder = 1 + 1 + config.chunk_size if self.has_state else 1 + config.chunk_size
|
||||||
if self.use_input_state:
|
|
||||||
num_input_token_encoder += 1
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"vae_encoder_pos_enc",
|
"vae_encoder_pos_enc",
|
||||||
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0),
|
||||||
@@ -242,16 +241,16 @@ class ACT(nn.Module):
|
|||||||
|
|
||||||
# Transformer encoder input projections. The tokens will be structured like
|
# Transformer encoder input projections. The tokens will be structured like
|
||||||
# [latent, robot_state, image_feature_map_pixels].
|
# [latent, robot_state, image_feature_map_pixels].
|
||||||
if self.use_input_state:
|
if self.has_state:
|
||||||
self.encoder_robot_state_input_proj = nn.Linear(
|
self.encoder_robot_state_input_proj = nn.Linear(
|
||||||
config.input_shapes["observation.state"][0], config.dim_model
|
config.input_shapes["observation.state"][0], config.dim_model
|
||||||
)
|
)
|
||||||
self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model)
|
self.encoder_latent_input_proj = nn.Linear(self.latent_dim, config.dim_model)
|
||||||
self.encoder_img_feat_input_proj = nn.Conv2d(
|
self.encoder_img_feat_input_proj = nn.Conv2d(
|
||||||
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
backbone_model.fc.in_features, config.dim_model, kernel_size=1
|
||||||
)
|
)
|
||||||
# Transformer encoder positional embeddings.
|
# Transformer encoder positional embeddings.
|
||||||
num_input_token_decoder = 2 if self.use_input_state else 1
|
num_input_token_decoder = 2 if self.has_state else 1
|
||||||
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
self.encoder_robot_and_latent_pos_embed = nn.Embedding(num_input_token_decoder, config.dim_model)
|
||||||
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
|
||||||
|
|
||||||
@@ -299,12 +298,12 @@ class ACT(nn.Module):
|
|||||||
cls_embed = einops.repeat(
|
cls_embed = einops.repeat(
|
||||||
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
|
||||||
) # (B, 1, D)
|
) # (B, 1, D)
|
||||||
if self.use_input_state:
|
if self.has_state:
|
||||||
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
|
||||||
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
|
||||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||||
|
|
||||||
if self.use_input_state:
|
if self.has_state:
|
||||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||||
else:
|
else:
|
||||||
vae_encoder_input = [cls_embed, action_embed]
|
vae_encoder_input = [cls_embed, action_embed]
|
||||||
@@ -329,7 +328,7 @@ class ACT(nn.Module):
|
|||||||
# When not using the VAE encoder, we set the latent to be all zeros.
|
# When not using the VAE encoder, we set the latent to be all zeros.
|
||||||
mu = log_sigma_x2 = None
|
mu = log_sigma_x2 = None
|
||||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer
|
||||||
latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to(
|
latent_sample = torch.zeros([batch_size, self.latent_dim], dtype=torch.float32).to(
|
||||||
batch["observation.state"].device
|
batch["observation.state"].device
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -351,12 +350,12 @@ class ACT(nn.Module):
|
|||||||
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
cam_pos_embed = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||||
|
|
||||||
# Get positional embeddings for robot state and latent.
|
# Get positional embeddings for robot state and latent.
|
||||||
if self.use_input_state:
|
if self.has_state:
|
||||||
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
robot_state_embed = self.encoder_robot_state_input_proj(batch["observation.state"]) # (B, C)
|
||||||
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
latent_embed = self.encoder_latent_input_proj(latent_sample) # (B, C)
|
||||||
|
|
||||||
# Stack encoder input and positional embeddings moving to (S, B, C).
|
# Stack encoder input and positional embeddings moving to (S, B, C).
|
||||||
encoder_in_feats = [latent_embed, robot_state_embed] if self.use_input_state else [latent_embed]
|
encoder_in_feats = [latent_embed, robot_state_embed] if self.has_state else [latent_embed]
|
||||||
encoder_in = torch.cat(
|
encoder_in = torch.cat(
|
||||||
[
|
[
|
||||||
torch.stack(encoder_in_feats, axis=0),
|
torch.stack(encoder_in_feats, axis=0),
|
||||||
|
|||||||
@@ -28,7 +28,10 @@ class DiffusionConfig:
|
|||||||
|
|
||||||
Notes on the inputs and outputs:
|
Notes on the inputs and outputs:
|
||||||
- "observation.state" is required as an input key.
|
- "observation.state" is required as an input key.
|
||||||
- A key starting with "observation.image is required as an input.
|
- At least one key starting with "observation.image is required as an input.
|
||||||
|
- If there are multiple keys beginning with "observation.image" they are treated as multiple camera
|
||||||
|
views.
|
||||||
|
Right now we only support all images having the same shape.
|
||||||
- "action" is required as an output key.
|
- "action" is required as an output key.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
13
lerobot/configs/env/aloha2_real.yaml
vendored
Normal file
13
lerobot/configs/env/aloha2_real.yaml
vendored
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
fps: 30
|
||||||
|
|
||||||
|
env:
|
||||||
|
name: dora
|
||||||
|
task: DoraAloha2-v0
|
||||||
|
state_dim: 14
|
||||||
|
action_dim: 14
|
||||||
|
fps: ${fps}
|
||||||
|
episode_length: 400
|
||||||
|
gym:
|
||||||
|
fps: ${fps}
|
||||||
@@ -1,21 +1,7 @@
|
|||||||
# @package _global_
|
# @package _global_
|
||||||
|
|
||||||
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
|
|
||||||
# Compared to `act.yaml`, it contains 4 cameras (i.e. cam_right_wrist, cam_left_wrist, images,
|
|
||||||
# cam_low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
|
|
||||||
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
|
|
||||||
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
|
|
||||||
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
|
|
||||||
#
|
|
||||||
# Example of usage for training:
|
|
||||||
# ```bash
|
|
||||||
# python lerobot/scripts/train.py \
|
|
||||||
# policy=act_real \
|
|
||||||
# env=dora_aloha_real
|
|
||||||
# ```
|
|
||||||
|
|
||||||
seed: 1000
|
seed: 1000
|
||||||
dataset_repo_id: lerobot/aloha_static_vinh_cup
|
dataset_repo_id: cadene/aloha_v2_static_dora_test
|
||||||
|
|
||||||
override_dataset_stats:
|
override_dataset_stats:
|
||||||
observation.images.cam_right_wrist:
|
observation.images.cam_right_wrist:
|
||||||
@@ -41,7 +27,7 @@ training:
|
|||||||
eval_freq: -1
|
eval_freq: -1
|
||||||
save_freq: 10000
|
save_freq: 10000
|
||||||
log_freq: 100
|
log_freq: 100
|
||||||
save_checkpoint: true
|
save_model: true
|
||||||
|
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
lr: 1e-5
|
lr: 1e-5
|
||||||
|
|||||||
Reference in New Issue
Block a user