From df23672bcda471ef51aa14a6fa6368c1a2cd5316 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Tue, 16 Jul 2024 13:20:58 +0000
Subject: [PATCH] WIP: 2024_07_16_vqbet_koch_pick_place_lego_simple_v2

---
 lerobot/configs/policy/vqbet_koch_real.yaml | 102 ++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 lerobot/configs/policy/vqbet_koch_real.yaml

diff --git a/lerobot/configs/policy/vqbet_koch_real.yaml b/lerobot/configs/policy/vqbet_koch_real.yaml
new file mode 100644
index 00000000..42a9956a
--- /dev/null
+++ b/lerobot/configs/policy/vqbet_koch_real.yaml
@@ -0,0 +1,102 @@
+# @package _global_
+
+# Defaults for training the VQ-BeT policy on the lerobot/koch_pick_place_lego dataset.
+
+seed: 100000
+dataset_repo_id: lerobot/koch_pick_place_lego
+
+override_dataset_stats:
+  observation.images.laptop:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+  observation.images.phone:
+    # stats from imagenet, since we use a pretrained vision model
+    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
+    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
+
+training:
+  offline_steps: 80000
+  online_steps: 0
+  eval_freq: -1  # NOTE(review): presumably disables periodic evaluation - confirm against the training script
+  save_freq: 10000
+  save_checkpoint: true
+
+  batch_size: 8
+  grad_clip_norm: 10
+  lr: 1.0e-4
+  lr_scheduler: cosine
+  lr_warmup_steps: 2000
+  adam_betas: [0.95, 0.999]
+  adam_eps: 1.0e-8
+  adam_weight_decay: 1.0e-6
+  online_steps_between_rollouts: 1
+
+  # VQ-BeT specific
+  vqvae_lr: 1.0e-3
+  n_vqvae_training_steps: 20000
+  bet_weight_decay: 2.0e-4  # keep the mantissa dot: YAML 1.1 loaders read a bare `2e-4` as a string
+  bet_learning_rate: 5.5e-5
+  bet_betas: [0.9, 0.999]
+
+  delta_timestamps:
+    observation.images.laptop: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
+    observation.images.phone: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
+    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
+    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"
+
+eval:
+  n_episodes: 50
+  batch_size: 50
+
+policy:
+  name: vqbet
+
+  # Input / output structure.
+  n_obs_steps: 5
+  n_action_pred_token: 7
+  action_chunk_size: 5
+
+  input_shapes:
+    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
+    observation.images.laptop: [3, 480, 640]
+    observation.images.phone: [3, 480, 640]
+    observation.state: ["${env.state_dim}"]
+  output_shapes:
+    action: ["${env.action_dim}"]
+
+  # Normalization / Unnormalization
+  input_normalization_modes:
+    observation.images.laptop: mean_std
+    observation.images.phone: mean_std
+    observation.state: min_max
+  output_normalization_modes:
+    action: min_max
+
+  # Architecture / modeling.
+  # Vision backbone.
+  vision_backbone: resnet18
+  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
+  crop_is_random: false
+  spatial_softmax_num_keypoints: 512
+  use_group_norm: false
+  crop_shape: null
+  # VQ-VAE
+  n_vqvae_training_steps: ${training.n_vqvae_training_steps}  # single source of truth lives under `training`
+  vqvae_n_embed: 16
+  vqvae_embedding_dim: 256
+  vqvae_enc_hidden_dim: 128
+  # VQ-BeT
+  gpt_block_size: 500
+  gpt_input_dim: 512
+  gpt_output_dim: 512
+  gpt_n_layer: 8
+  gpt_n_head: 8
+  gpt_hidden_dim: 512
+  dropout: 0.1
+  mlp_hidden_dim: 1024
+  offset_loss_weight: 10000.0
+  primary_code_loss_weight: 5.0
+  secondary_code_loss_weight: 0.5
+  bet_softmax_temperature: 0.1
+  sequentially_select: false