This commit is contained in:
Remi Cadene
2024-07-02 21:15:48 +02:00
parent 73f46bac56
commit bbf617fd92
37 changed files with 4562 additions and 28 deletions

View File

@@ -0,0 +1,89 @@
# Using `lerobot` on a real world arm
In this example, we'll be using `lerobot` on a real world arm to:
- record a dataset in the `lerobot` format
- (soon) train a policy on it
- (soon) run the policy in the real-world
## Which robotic arm to use
In this example we're using the [open-source low-cost arm from Alexander Koch](https://github.com/AlexanderKoch-Koch/low_cost_robot) in the specific setup of:
- having 6 servos per arm, i.e. using the elbow-to-wrist extension
- adding two cameras around it, one on top and one in the front
- having a teleoperation arm as well (build the leader and the follower arms in A. Koch repo, both with elbow-to-wrist extensions)
I'm using these cameras (but the setup should not be sensitive to the exact cameras you're using):
- C922 Pro Stream Webcam
- Intel(R) RealSense D455 (using only the RGB input)
In general, this example should be very easily extendable to any type of arm using Dynamixel servos with at least one camera by changing a couple of configuration in the gym env.
## Install the example
Follow these steps:
- install `lerobot`
- install the Dynamixel-sdk: `pip install dynamixel-sdk`
## Usage
### 0 - record examples
Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.
```
DATA_DIR='./data' python record_training_data.py \
--repo-id=thomwolf/blue_red_sort \
--num-episodes=50 \
--num-frames=400
```
TODO:
- various length episodes
- being able to drop episodes
- checking uploading to the hub
### 1 - visualize the dataset
Use the standard dataset visualization script pointing it to the right folder:
```
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
--repo-id thomwolf/blue_red_sort \
--episode-index 0
```
### 2 - Train a policy
From the example directory let's run this command to train a model using ACT
```
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
device=cuda \
hydra.searchpath=[file://./train_config/] \
hydra.run.dir=./outputs/train/blue_red_sort \
dataset_repo_id=thomwolf/blue_red_sort \
env=gym_real_world \
policy=act_real_world \
wandb.enable=false
```
### 3 - Evaluate the policy in the real world
From the example directory let's run this command to evaluate our policy.
The configuration for running the policy is in the checkpoint of the model.
You can override parameters as follow:
```
python run_policy.py \
-p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/
env.episode_length=1000
```
## Convert a hdf5 dataset recorded with the original ACT repo
You can convert a dataset from the raw data format of HDF5 files like in: https://github.com/tonyzhaozh/act with the following command:
```
python ./lerobot/scripts/push_dataset_to_hub.py
```

View File

@@ -0,0 +1,840 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from safetensors.torch import load_file, save_file\n",
"from pprint import pprint"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n",
"converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n",
"\n",
"comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n",
"comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n",
"comparison_config_json_path = comparison_main_path + \"config.json\"\n",
"comparison_config_yaml_path = comparison_main_path + \"config.yaml\""
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"b = load_file(comparison_safetensor_path)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.action_head.bias',\n",
" 'model.action_head.weight',\n",
" 'model.backbone.bn1.bias',\n",
" 'model.backbone.bn1.running_mean',\n",
" 'model.backbone.bn1.running_var',\n",
" 'model.backbone.bn1.weight',\n",
" 'model.backbone.conv1.weight',\n",
" 'model.backbone.layer1.0.bn1.bias',\n",
" 'model.backbone.layer1.0.bn1.running_mean',\n",
" 'model.backbone.layer1.0.bn1.running_var',\n",
" 'model.backbone.layer1.0.bn1.weight',\n",
" 'model.backbone.layer1.0.bn2.bias',\n",
" 'model.backbone.layer1.0.bn2.running_mean',\n",
" 'model.backbone.layer1.0.bn2.running_var',\n",
" 'model.backbone.layer1.0.bn2.weight',\n",
" 'model.backbone.layer1.0.conv1.weight',\n",
" 'model.backbone.layer1.0.conv2.weight',\n",
" 'model.backbone.layer1.1.bn1.bias',\n",
" 'model.backbone.layer1.1.bn1.running_mean',\n",
" 'model.backbone.layer1.1.bn1.running_var',\n",
" 'model.backbone.layer1.1.bn1.weight',\n",
" 'model.backbone.layer1.1.bn2.bias',\n",
" 'model.backbone.layer1.1.bn2.running_mean',\n",
" 'model.backbone.layer1.1.bn2.running_var',\n",
" 'model.backbone.layer1.1.bn2.weight',\n",
" 'model.backbone.layer1.1.conv1.weight',\n",
" 'model.backbone.layer1.1.conv2.weight',\n",
" 'model.backbone.layer2.0.bn1.bias',\n",
" 'model.backbone.layer2.0.bn1.running_mean',\n",
" 'model.backbone.layer2.0.bn1.running_var',\n",
" 'model.backbone.layer2.0.bn1.weight',\n",
" 'model.backbone.layer2.0.bn2.bias',\n",
" 'model.backbone.layer2.0.bn2.running_mean',\n",
" 'model.backbone.layer2.0.bn2.running_var',\n",
" 'model.backbone.layer2.0.bn2.weight',\n",
" 'model.backbone.layer2.0.conv1.weight',\n",
" 'model.backbone.layer2.0.conv2.weight',\n",
" 'model.backbone.layer2.0.downsample.0.weight',\n",
" 'model.backbone.layer2.0.downsample.1.bias',\n",
" 'model.backbone.layer2.0.downsample.1.running_mean',\n",
" 'model.backbone.layer2.0.downsample.1.running_var',\n",
" 'model.backbone.layer2.0.downsample.1.weight',\n",
" 'model.backbone.layer2.1.bn1.bias',\n",
" 'model.backbone.layer2.1.bn1.running_mean',\n",
" 'model.backbone.layer2.1.bn1.running_var',\n",
" 'model.backbone.layer2.1.bn1.weight',\n",
" 'model.backbone.layer2.1.bn2.bias',\n",
" 'model.backbone.layer2.1.bn2.running_mean',\n",
" 'model.backbone.layer2.1.bn2.running_var',\n",
" 'model.backbone.layer2.1.bn2.weight',\n",
" 'model.backbone.layer2.1.conv1.weight',\n",
" 'model.backbone.layer2.1.conv2.weight',\n",
" 'model.backbone.layer3.0.bn1.bias',\n",
" 'model.backbone.layer3.0.bn1.running_mean',\n",
" 'model.backbone.layer3.0.bn1.running_var',\n",
" 'model.backbone.layer3.0.bn1.weight',\n",
" 'model.backbone.layer3.0.bn2.bias',\n",
" 'model.backbone.layer3.0.bn2.running_mean',\n",
" 'model.backbone.layer3.0.bn2.running_var',\n",
" 'model.backbone.layer3.0.bn2.weight',\n",
" 'model.backbone.layer3.0.conv1.weight',\n",
" 'model.backbone.layer3.0.conv2.weight',\n",
" 'model.backbone.layer3.0.downsample.0.weight',\n",
" 'model.backbone.layer3.0.downsample.1.bias',\n",
" 'model.backbone.layer3.0.downsample.1.running_mean',\n",
" 'model.backbone.layer3.0.downsample.1.running_var',\n",
" 'model.backbone.layer3.0.downsample.1.weight',\n",
" 'model.backbone.layer3.1.bn1.bias',\n",
" 'model.backbone.layer3.1.bn1.running_mean',\n",
" 'model.backbone.layer3.1.bn1.running_var',\n",
" 'model.backbone.layer3.1.bn1.weight',\n",
" 'model.backbone.layer3.1.bn2.bias',\n",
" 'model.backbone.layer3.1.bn2.running_mean',\n",
" 'model.backbone.layer3.1.bn2.running_var',\n",
" 'model.backbone.layer3.1.bn2.weight',\n",
" 'model.backbone.layer3.1.conv1.weight',\n",
" 'model.backbone.layer3.1.conv2.weight',\n",
" 'model.backbone.layer4.0.bn1.bias',\n",
" 'model.backbone.layer4.0.bn1.running_mean',\n",
" 'model.backbone.layer4.0.bn1.running_var',\n",
" 'model.backbone.layer4.0.bn1.weight',\n",
" 'model.backbone.layer4.0.bn2.bias',\n",
" 'model.backbone.layer4.0.bn2.running_mean',\n",
" 'model.backbone.layer4.0.bn2.running_var',\n",
" 'model.backbone.layer4.0.bn2.weight',\n",
" 'model.backbone.layer4.0.conv1.weight',\n",
" 'model.backbone.layer4.0.conv2.weight',\n",
" 'model.backbone.layer4.0.downsample.0.weight',\n",
" 'model.backbone.layer4.0.downsample.1.bias',\n",
" 'model.backbone.layer4.0.downsample.1.running_mean',\n",
" 'model.backbone.layer4.0.downsample.1.running_var',\n",
" 'model.backbone.layer4.0.downsample.1.weight',\n",
" 'model.backbone.layer4.1.bn1.bias',\n",
" 'model.backbone.layer4.1.bn1.running_mean',\n",
" 'model.backbone.layer4.1.bn1.running_var',\n",
" 'model.backbone.layer4.1.bn1.weight',\n",
" 'model.backbone.layer4.1.bn2.bias',\n",
" 'model.backbone.layer4.1.bn2.running_mean',\n",
" 'model.backbone.layer4.1.bn2.running_var',\n",
" 'model.backbone.layer4.1.bn2.weight',\n",
" 'model.backbone.layer4.1.conv1.weight',\n",
" 'model.backbone.layer4.1.conv2.weight',\n",
" 'model.decoder.layers.0.linear1.bias',\n",
" 'model.decoder.layers.0.linear1.weight',\n",
" 'model.decoder.layers.0.linear2.bias',\n",
" 'model.decoder.layers.0.linear2.weight',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.decoder.layers.0.norm1.bias',\n",
" 'model.decoder.layers.0.norm1.weight',\n",
" 'model.decoder.layers.0.norm2.bias',\n",
" 'model.decoder.layers.0.norm2.weight',\n",
" 'model.decoder.layers.0.norm3.bias',\n",
" 'model.decoder.layers.0.norm3.weight',\n",
" 'model.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.decoder_pos_embed.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder_img_feat_input_proj.bias',\n",
" 'model.encoder_img_feat_input_proj.weight',\n",
" 'model.encoder_latent_input_proj.bias',\n",
" 'model.encoder_latent_input_proj.weight',\n",
" 'model.encoder_robot_and_latent_pos_embed.weight',\n",
" 'model.encoder_robot_state_input_proj.bias',\n",
" 'model.encoder_robot_state_input_proj.weight',\n",
" 'model.vae_encoder.layers.0.linear1.bias',\n",
" 'model.vae_encoder.layers.0.linear1.weight',\n",
" 'model.vae_encoder.layers.0.linear2.bias',\n",
" 'model.vae_encoder.layers.0.linear2.weight',\n",
" 'model.vae_encoder.layers.0.norm1.bias',\n",
" 'model.vae_encoder.layers.0.norm1.weight',\n",
" 'model.vae_encoder.layers.0.norm2.bias',\n",
" 'model.vae_encoder.layers.0.norm2.weight',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.1.linear1.bias',\n",
" 'model.vae_encoder.layers.1.linear1.weight',\n",
" 'model.vae_encoder.layers.1.linear2.bias',\n",
" 'model.vae_encoder.layers.1.linear2.weight',\n",
" 'model.vae_encoder.layers.1.norm1.bias',\n",
" 'model.vae_encoder.layers.1.norm1.weight',\n",
" 'model.vae_encoder.layers.1.norm2.bias',\n",
" 'model.vae_encoder.layers.1.norm2.weight',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.2.linear1.bias',\n",
" 'model.vae_encoder.layers.2.linear1.weight',\n",
" 'model.vae_encoder.layers.2.linear2.bias',\n",
" 'model.vae_encoder.layers.2.linear2.weight',\n",
" 'model.vae_encoder.layers.2.norm1.bias',\n",
" 'model.vae_encoder.layers.2.norm1.weight',\n",
" 'model.vae_encoder.layers.2.norm2.bias',\n",
" 'model.vae_encoder.layers.2.norm2.weight',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.vae_encoder.layers.3.linear1.bias',\n",
" 'model.vae_encoder.layers.3.linear1.weight',\n",
" 'model.vae_encoder.layers.3.linear2.bias',\n",
" 'model.vae_encoder.layers.3.linear2.weight',\n",
" 'model.vae_encoder.layers.3.norm1.bias',\n",
" 'model.vae_encoder.layers.3.norm1.weight',\n",
" 'model.vae_encoder.layers.3.norm2.bias',\n",
" 'model.vae_encoder.layers.3.norm2.weight',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.vae_encoder_action_input_proj.bias',\n",
" 'model.vae_encoder_action_input_proj.weight',\n",
" 'model.vae_encoder_cls_embed.weight',\n",
" 'model.vae_encoder_latent_output_proj.bias',\n",
" 'model.vae_encoder_latent_output_proj.weight',\n",
" 'model.vae_encoder_pos_enc',\n",
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
" 'normalize_inputs.buffer_observation_images_front.std',\n",
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
" 'normalize_inputs.buffer_observation_images_top.std',\n",
" 'normalize_inputs.buffer_observation_state.mean',\n",
" 'normalize_inputs.buffer_observation_state.std',\n",
" 'normalize_targets.buffer_action.mean',\n",
" 'normalize_targets.buffer_action.std',\n",
" 'unnormalize_outputs.buffer_action.mean',\n",
" 'unnormalize_outputs.buffer_action.std']\n"
]
}
],
"source": [
"dest = list(b.keys())\n",
"pprint(dest)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['model.pos_table',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
" 'model.transformer.encoder.layers.0.linear2.weight',\n",
" 'model.transformer.encoder.layers.0.linear2.bias',\n",
" 'model.transformer.encoder.layers.0.norm1.weight',\n",
" 'model.transformer.encoder.layers.0.norm1.bias',\n",
" 'model.transformer.encoder.layers.0.norm2.weight',\n",
" 'model.transformer.encoder.layers.0.norm2.bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.1.linear1.weight',\n",
" 'model.transformer.encoder.layers.1.linear1.bias',\n",
" 'model.transformer.encoder.layers.1.linear2.weight',\n",
" 'model.transformer.encoder.layers.1.linear2.bias',\n",
" 'model.transformer.encoder.layers.1.norm1.weight',\n",
" 'model.transformer.encoder.layers.1.norm1.bias',\n",
" 'model.transformer.encoder.layers.1.norm2.weight',\n",
" 'model.transformer.encoder.layers.1.norm2.bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.2.linear1.weight',\n",
" 'model.transformer.encoder.layers.2.linear1.bias',\n",
" 'model.transformer.encoder.layers.2.linear2.weight',\n",
" 'model.transformer.encoder.layers.2.linear2.bias',\n",
" 'model.transformer.encoder.layers.2.norm1.weight',\n",
" 'model.transformer.encoder.layers.2.norm1.bias',\n",
" 'model.transformer.encoder.layers.2.norm2.weight',\n",
" 'model.transformer.encoder.layers.2.norm2.bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.encoder.layers.3.linear1.weight',\n",
" 'model.transformer.encoder.layers.3.linear1.bias',\n",
" 'model.transformer.encoder.layers.3.linear2.weight',\n",
" 'model.transformer.encoder.layers.3.linear2.bias',\n",
" 'model.transformer.encoder.layers.3.norm1.weight',\n",
" 'model.transformer.encoder.layers.3.norm1.bias',\n",
" 'model.transformer.encoder.layers.3.norm2.weight',\n",
" 'model.transformer.encoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.0.linear1.weight',\n",
" 'model.transformer.decoder.layers.0.linear1.bias',\n",
" 'model.transformer.decoder.layers.0.linear2.weight',\n",
" 'model.transformer.decoder.layers.0.linear2.bias',\n",
" 'model.transformer.decoder.layers.0.norm1.weight',\n",
" 'model.transformer.decoder.layers.0.norm1.bias',\n",
" 'model.transformer.decoder.layers.0.norm2.weight',\n",
" 'model.transformer.decoder.layers.0.norm2.bias',\n",
" 'model.transformer.decoder.layers.0.norm3.weight',\n",
" 'model.transformer.decoder.layers.0.norm3.bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.1.linear1.weight',\n",
" 'model.transformer.decoder.layers.1.linear1.bias',\n",
" 'model.transformer.decoder.layers.1.linear2.weight',\n",
" 'model.transformer.decoder.layers.1.linear2.bias',\n",
" 'model.transformer.decoder.layers.1.norm1.weight',\n",
" 'model.transformer.decoder.layers.1.norm1.bias',\n",
" 'model.transformer.decoder.layers.1.norm2.weight',\n",
" 'model.transformer.decoder.layers.1.norm2.bias',\n",
" 'model.transformer.decoder.layers.1.norm3.weight',\n",
" 'model.transformer.decoder.layers.1.norm3.bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.2.linear1.weight',\n",
" 'model.transformer.decoder.layers.2.linear1.bias',\n",
" 'model.transformer.decoder.layers.2.linear2.weight',\n",
" 'model.transformer.decoder.layers.2.linear2.bias',\n",
" 'model.transformer.decoder.layers.2.norm1.weight',\n",
" 'model.transformer.decoder.layers.2.norm1.bias',\n",
" 'model.transformer.decoder.layers.2.norm2.weight',\n",
" 'model.transformer.decoder.layers.2.norm2.bias',\n",
" 'model.transformer.decoder.layers.2.norm3.weight',\n",
" 'model.transformer.decoder.layers.2.norm3.bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.3.linear1.weight',\n",
" 'model.transformer.decoder.layers.3.linear1.bias',\n",
" 'model.transformer.decoder.layers.3.linear2.weight',\n",
" 'model.transformer.decoder.layers.3.linear2.bias',\n",
" 'model.transformer.decoder.layers.3.norm1.weight',\n",
" 'model.transformer.decoder.layers.3.norm1.bias',\n",
" 'model.transformer.decoder.layers.3.norm2.weight',\n",
" 'model.transformer.decoder.layers.3.norm2.bias',\n",
" 'model.transformer.decoder.layers.3.norm3.weight',\n",
" 'model.transformer.decoder.layers.3.norm3.bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.4.linear1.weight',\n",
" 'model.transformer.decoder.layers.4.linear1.bias',\n",
" 'model.transformer.decoder.layers.4.linear2.weight',\n",
" 'model.transformer.decoder.layers.4.linear2.bias',\n",
" 'model.transformer.decoder.layers.4.norm1.weight',\n",
" 'model.transformer.decoder.layers.4.norm1.bias',\n",
" 'model.transformer.decoder.layers.4.norm2.weight',\n",
" 'model.transformer.decoder.layers.4.norm2.bias',\n",
" 'model.transformer.decoder.layers.4.norm3.weight',\n",
" 'model.transformer.decoder.layers.4.norm3.bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.5.linear1.weight',\n",
" 'model.transformer.decoder.layers.5.linear1.bias',\n",
" 'model.transformer.decoder.layers.5.linear2.weight',\n",
" 'model.transformer.decoder.layers.5.linear2.bias',\n",
" 'model.transformer.decoder.layers.5.norm1.weight',\n",
" 'model.transformer.decoder.layers.5.norm1.bias',\n",
" 'model.transformer.decoder.layers.5.norm2.weight',\n",
" 'model.transformer.decoder.layers.5.norm2.bias',\n",
" 'model.transformer.decoder.layers.5.norm3.weight',\n",
" 'model.transformer.decoder.layers.5.norm3.bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n",
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n",
" 'model.transformer.decoder.layers.6.linear1.weight',\n",
" 'model.transformer.decoder.layers.6.linear1.bias',\n",
" 'model.transformer.decoder.layers.6.linear2.weight',\n",
" 'model.transformer.decoder.layers.6.linear2.bias',\n",
" 'model.transformer.decoder.layers.6.norm1.weight',\n",
" 'model.transformer.decoder.layers.6.norm1.bias',\n",
" 'model.transformer.decoder.layers.6.norm2.weight',\n",
" 'model.transformer.decoder.layers.6.norm2.bias',\n",
" 'model.transformer.decoder.layers.6.norm3.weight',\n",
" 'model.transformer.decoder.layers.6.norm3.bias',\n",
" 'model.transformer.decoder.norm.weight',\n",
" 'model.transformer.decoder.norm.bias',\n",
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.0.linear1.weight',\n",
" 'model.encoder.layers.0.linear1.bias',\n",
" 'model.encoder.layers.0.linear2.weight',\n",
" 'model.encoder.layers.0.linear2.bias',\n",
" 'model.encoder.layers.0.norm1.weight',\n",
" 'model.encoder.layers.0.norm1.bias',\n",
" 'model.encoder.layers.0.norm2.weight',\n",
" 'model.encoder.layers.0.norm2.bias',\n",
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.1.linear1.weight',\n",
" 'model.encoder.layers.1.linear1.bias',\n",
" 'model.encoder.layers.1.linear2.weight',\n",
" 'model.encoder.layers.1.linear2.bias',\n",
" 'model.encoder.layers.1.norm1.weight',\n",
" 'model.encoder.layers.1.norm1.bias',\n",
" 'model.encoder.layers.1.norm2.weight',\n",
" 'model.encoder.layers.1.norm2.bias',\n",
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.2.linear1.weight',\n",
" 'model.encoder.layers.2.linear1.bias',\n",
" 'model.encoder.layers.2.linear2.weight',\n",
" 'model.encoder.layers.2.linear2.bias',\n",
" 'model.encoder.layers.2.norm1.weight',\n",
" 'model.encoder.layers.2.norm1.bias',\n",
" 'model.encoder.layers.2.norm2.weight',\n",
" 'model.encoder.layers.2.norm2.bias',\n",
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
" 'model.encoder.layers.3.linear1.weight',\n",
" 'model.encoder.layers.3.linear1.bias',\n",
" 'model.encoder.layers.3.linear2.weight',\n",
" 'model.encoder.layers.3.linear2.bias',\n",
" 'model.encoder.layers.3.norm1.weight',\n",
" 'model.encoder.layers.3.norm1.bias',\n",
" 'model.encoder.layers.3.norm2.weight',\n",
" 'model.encoder.layers.3.norm2.bias',\n",
" 'model.action_head.weight',\n",
" 'model.action_head.bias',\n",
" 'model.is_pad_head.weight',\n",
" 'model.is_pad_head.bias',\n",
" 'model.query_embed.weight',\n",
" 'model.input_proj.weight',\n",
" 'model.input_proj.bias',\n",
" 'model.backbones.0.0.body.conv1.weight',\n",
" 'model.backbones.0.0.body.bn1.weight',\n",
" 'model.backbones.0.0.body.bn1.bias',\n",
" 'model.backbones.0.0.body.bn1.running_mean',\n",
" 'model.backbones.0.0.body.bn1.running_var',\n",
" 'model.backbones.0.0.body.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer1.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer2.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer3.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n",
" 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n",
" 'model.backbones.0.0.body.layer4.1.conv2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.weight',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.bias',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n",
" 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n",
" 'model.input_proj_robot_state.weight',\n",
" 'model.input_proj_robot_state.bias',\n",
" 'model.cls_embed.weight',\n",
" 'model.encoder_action_proj.weight',\n",
" 'model.encoder_action_proj.bias',\n",
" 'model.encoder_joint_proj.weight',\n",
" 'model.encoder_joint_proj.bias',\n",
" 'model.latent_proj.weight',\n",
" 'model.latent_proj.bias',\n",
" 'model.latent_out_proj.weight',\n",
" 'model.latent_out_proj.bias',\n",
" 'model.additional_pos_embed.weight']\n"
]
}
],
"source": [
"orig = list(a.keys())\n",
"pprint(orig)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"a = torch.load(original_ckpt_path)\n",
"\n",
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
" 'model.transformer.decoder.layers.2.',\n",
" 'model.transformer.decoder.layers.3.',\n",
" 'model.transformer.decoder.layers.4.',\n",
" 'model.transformer.decoder.layers.5.',\n",
" 'model.transformer.decoder.layers.6.',\n",
" 'model.transformer.decoder.norm.',\n",
" 'model.is_pad_head']\n",
"\n",
"to_remove_in = ['num_batches_tracked',]\n",
"\n",
"conv = {}\n",
"\n",
"keys = list(a.keys())\n",
"for k in keys:\n",
" if any(k.startswith(tr) for tr in to_remove_startswith):\n",
" a.pop(k)\n",
" continue\n",
" if any(tr in k for tr in to_remove_in):\n",
" a.pop(k)\n",
" continue\n",
" if k.startswith('model.transformer.encoder.layers.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
" if k.startswith('model.encoder.layers.'):\n",
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
" if k.startswith('model.action_head.'):\n",
" conv[k] = a.pop(k)\n",
" if k.startswith('model.pos_table'):\n",
" conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
" if k.startswith('model.query_embed.'):\n",
" conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
" if k.startswith('model.input_proj.'):\n",
" conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.input_proj_robot_state.'):\n",
" conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
" if k.startswith('model.backbones.0.0.body.'):\n",
" conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
" if k.startswith('model.cls_embed.'):\n",
" conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
" if k.startswith('model.encoder_action_proj.'):\n",
" conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
" if k.startswith('model.encoder_joint_proj.'):\n",
" conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_proj.'):\n",
" conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
" if k.startswith('model.latent_out_proj.'):\n",
" conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
" if k.startswith('model.additional_pos_embed.'):\n",
" conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict()"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"for k, v in conv.items():\n",
" assert b[k].shape == v.shape\n",
" b[k] = v"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"save_file(b, converted_ckpt_path)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Now also copy the config files\n",
"import shutil\n",
"shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n",
"shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "lerobot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,8 @@
from gymnasium.envs.registration import register
register(
id="gym_real_world/RealEnv-v0",
entry_point="gym_real_world.gym_environment:RealEnv",
max_episode_steps=300,
nondeterministic=True,
)

View File

@@ -0,0 +1,363 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Dynamixel class to control the dynamixel servos
"""
from __future__ import annotations
import enum
import math
import os
from dataclasses import dataclass
import numpy as np
from dynamixel_sdk import * # Uses Dynamixel SDK library
def pos2pwm(pos: np.ndarray) -> np.ndarray:
"""
:param pos: numpy array of joint positions in range [-pi, pi]
:return: numpy array of pwm values in range [0, 4096]
"""
return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)
def pwm2pos(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm values in range [0, 4096]
:return: numpy array of joint positions in range [-pi, pi]
"""
return (pwm / 2048 - 1) * 3.14
def pwm2vel(pwm: np.ndarray) -> np.ndarray:
"""
:param pwm: numpy array of pwm/s joint velocities
:return: numpy array of rad/s joint velocities
"""
return pwm * 3.14 / 2048
def vel2pwm(vel: np.ndarray) -> np.ndarray:
"""
:param vel: numpy array of rad/s joint velocities
:return: numpy array of pwm/s joint velocities
"""
return (vel * 2048 / 3.14).astype(np.int64)
class ReadAttribute(enum.Enum):
TEMPERATURE = 146
VOLTAGE = 145
VELOCITY = 128
POSITION = 132
CURRENT = 126
PWM = 124
HARDWARE_ERROR_STATUS = 70
HOMING_OFFSET = 20
BAUDRATE = 8
class OperatingMode(enum.Enum):
VELOCITY = 1
POSITION = 3
CURRENT_CONTROLLED_POSITION = 5
PWM = 16
UNKNOWN = -1
class Dynamixel:
ADDR_TORQUE_ENABLE = 64
ADDR_GOAL_POSITION = 116
ADDR_VELOCITY_LIMIT = 44
ADDR_GOAL_PWM = 100
OPERATING_MODE_ADDR = 11
POSITION_I = 82
POSITION_P = 84
ADDR_ID = 7
@dataclass
class Config:
def instantiate(self):
return Dynamixel(self)
baudrate: int = 57600
protocol_version: float = 2.0
device_name: str = "" # /dev/tty.usbserial-1120'
dynamixel_id: int = 1
def __init__(self, config: Config):
self.config = config
self.connect()
def connect(self):
if self.config.device_name == "":
for port_name in os.listdir("/dev"):
if "ttyUSB" in port_name or "ttyACM" in port_name:
self.config.device_name = "/dev/" + port_name
print(f"using device {self.config.device_name}")
self.portHandler = PortHandler(self.config.device_name)
# self.portHandler.LA
self.packetHandler = PacketHandler(self.config.protocol_version)
if not self.portHandler.openPort():
raise Exception(f"Failed to open port {self.config.device_name}")
if not self.portHandler.setBaudRate(self.config.baudrate):
raise Exception(f"failed to set baudrate to {self.config.baudrate}")
# self.operating_mode = OperatingMode.UNKNOWN
# self.torque_enabled = False
# self._disable_torque()
self.operating_modes = [None for _ in range(32)]
self.torque_enabled = [None for _ in range(32)]
return True
def disconnect(self):
self.portHandler.closePort()
def set_goal_position(self, motor_id, goal_position):
# if self.operating_modes[motor_id] is not OperatingMode.POSITION:
# self._disable_torque(motor_id)
# self.set_operating_mode(motor_id, OperatingMode.POSITION)
# if not self.torque_enabled[motor_id]:
# self._enable_torque(motor_id)
# self._enable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set position of motor {motor_id} to {goal_position}')
def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
if self.operating_modes[motor_id] is not OperatingMode.PWM:
self._disable_torque(motor_id)
self.set_operating_mode(motor_id, OperatingMode.PWM)
if not self.torque_enabled[motor_id]:
self._enable_torque(motor_id)
# print(f'enabling torque')
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
)
# self._process_response(dxl_comm_result, dxl_error)
# print(f'set pwm of motor {motor_id} to {pwm_value}')
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
else:
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
def read_temperature(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
def read_velocity(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position(self, motor_id: int):
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
if pos > 2**31:
pos -= 2**32
# print(f'read position {pos} for motor {motor_id}')
return pos
def read_position_degrees(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 360
def read_position_radians(self, motor_id: int) -> float:
return (self.read_position(motor_id) / 4096) * 2 * math.pi
def read_current(self, motor_id: int):
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
if current > 2**15:
current -= 2**16
return current
def read_present_pwm(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.PWM, 2)
def read_hardware_error_status(self, motor_id: int):
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
def disconnect(self):
self.portHandler.closePort()
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
"""
sets the id of the dynamixel servo
@param old_id: current id of the servo
@param new_id: new id
@param use_broadcast_id: set ids of all connected dynamixels if True.
If False, change only servo with self.config.id
@return:
"""
if use_broadcast_id:
current_id = 254
else:
current_id = old_id
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, current_id, self.ADDR_ID, new_id
)
self._process_response(dxl_comm_result, dxl_error, old_id)
self.config.id = id
def _enable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = True
def _disable_torque(self, motor_id):
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.torque_enabled[motor_id] = False
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
if dxl_comm_result != COMM_SUCCESS:
raise ConnectionError(
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
)
elif dxl_error != 0:
print(f"dxl error {dxl_error}")
raise ConnectionError(
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
)
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self.operating_modes[motor_id] = operating_mode
def set_pwm_limit(self, motor_id: int, limit: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_velocity_limit(self, motor_id: int, velocity_limit):
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_P(self, motor_id: int, P: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_P, P
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def set_I(self, motor_id: int, I: int):
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
self.portHandler, motor_id, self.POSITION_I, I
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def read_home_offset(self, motor_id: int):
self._disable_torque(motor_id)
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
# ReadAttribute.HOMING_OFFSET.value, home_position)
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
# self._process_response(dxl_comm_result, dxl_error)
self._enable_torque(motor_id)
return home_offset
def set_home_offset(self, motor_id: int, home_position: int):
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
self._enable_torque(motor_id)
def set_baudrate(self, motor_id: int, baudrate):
# translate baudrate into dynamixel baudrate setting id
if baudrate == 57600:
baudrate_id = 1
elif baudrate == 1_000_000:
baudrate_id = 3
elif baudrate == 2_000_000:
baudrate_id = 4
elif baudrate == 3_000_000:
baudrate_id = 5
elif baudrate == 4_000_000:
baudrate_id = 6
else:
raise Exception("baudrate not implemented")
self._disable_torque(motor_id)
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
)
self._process_response(dxl_comm_result, dxl_error, motor_id)
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
try:
if num_bytes == 1:
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 2:
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
self.portHandler, motor_id, attribute.value
)
elif num_bytes == 4:
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
self.portHandler, motor_id, attribute.value
)
except Exception:
if tries == 0:
raise Exception
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
if dxl_comm_result != COMM_SUCCESS:
if tries <= 1:
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
else:
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
time.sleep(0.02)
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
if tries == 0 and dxl_error != 128:
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
else:
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
return value
def set_home_position(self, motor_id: int):
print(f"setting home position for motor {motor_id}")
self.set_home_offset(motor_id, 0)
current_position = self.read_position(motor_id)
print(f"position before {current_position}")
self.set_home_offset(motor_id, -current_position)
# dynamixel.set_home_offset(motor_id, -4096)
# dynamixel.set_home_offset(motor_id, -4294964109)
current_position = self.read_position(motor_id)
# print(f'signed position {current_position - 2** 32}')
print(f"position after {current_position}")
if __name__ == "__main__":
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
motor_id = 1
pos = dynamixel.read_position(motor_id)
for i in range(10):
s = time.monotonic()
pos = dynamixel.read_position(motor_id)
delta = time.monotonic() - s
print(f"read position took {delta}")
print(f"position {pos}")

View File

@@ -0,0 +1,192 @@
import time
from unittest.mock import MagicMock
import cv2
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from .dynamixel import pos2pwm, pwm2pos
from .robot import Robot
FPS = 30
CAMERAS_SHAPES = {
"images.high": (480, 640, 3),
"images.low": (480, 640, 3),
}
CAMERAS_PORTS = {
"images.high": "/dev/video6",
"images.low": "/dev/video0",
}
LEADER_PORT = "/dev/ttyACM1"
FOLLOWER_PORT = "/dev/ttyACM0"
MockRobot = MagicMock()
MockRobot.read_position = MagicMock()
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
MockCamera = MagicMock()
MockCamera.isOpened = MagicMock(return_value=True)
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))
def capture_image(cam, cam_width, cam_height):
# Capture a single frame
_, frame = cam.read()
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# # Define your crop coordinates (top left corner and bottom right corner)
# x1, y1 = 400, 0 # Example starting coordinates (top left of the crop rectangle)
# x2, y2 = 1600, 900 # Example ending coordinates (bottom right of the crop rectangle)
# # Crop the image
# image = image[y1:y2, x1:x2]
# Resize the image
image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)
return image
class RealEnv(gym.Env):
metadata = {}
def __init__(
self,
record: bool = False,
num_joints: int = 6,
cameras_shapes: dict = CAMERAS_SHAPES,
cameras_ports: dict = CAMERAS_PORTS,
follower_port: str = FOLLOWER_PORT,
leader_port: str = LEADER_PORT,
warmup_steps: int = 100,
trigger_torque=70,
fps: int = FPS,
fps_tolerance: float = 0.1,
mock: bool = False,
):
self.num_joints = num_joints
self.cameras_shapes = cameras_shapes
self.cameras_ports = cameras_ports
self.warmup_steps = warmup_steps
assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."
self.follower_port = follower_port
self.leader_port = leader_port
self.record = record
self.fps = fps
self.fps_tolerance = fps_tolerance
# Initialize the robot
self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
if self.record:
self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
self.leader.set_trigger_torque(trigger_torque)
# Initialize the cameras - sorted by camera names
self.cameras = {}
for cn, p in sorted(self.cameras_ports.items()):
self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
if not self.cameras[cn].isOpened():
raise OSError(
f"Cannot open camera port {p} for {cn}."
f" Make sure the camera is connected and the port is correct."
f"Also check you are not spinning several instances of the same environment (eval.batch_size)"
)
# Specify gym action and observation spaces
observation_space = {}
if self.num_joints > 0:
observation_space["agent_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.record:
observation_space["leader_pos"] = spaces.Box(
low=-1000.0,
high=1000.0,
shape=(num_joints,),
dtype=np.float64,
)
if self.cameras_shapes:
for cn, hwc_shape in self.cameras_shapes.items():
# Assumes images are unsigned int8 in [0,255]
observation_space[cn] = spaces.Box(
low=0,
high=255,
# height x width x channels (e.g. 480 x 640 x 3)
shape=hwc_shape,
dtype=np.uint8,
)
self.observation_space = spaces.Dict(observation_space)
self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)
self._observation = {}
self._terminated = False
self.timestamps = []
def _get_obs(self):
qpos = self.follower.read_position()
self._observation["agent_pos"] = pwm2pos(qpos)
for cn, c in self.cameras.items():
self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])
if self.record:
action = self.leader.read_position()
self._observation["leader_pos"] = pwm2pos(action)
def reset(self, seed: int | None = None):
# Reset the robot and sync the leader and follower if we are recording
for _ in range(self.warmup_steps):
self._get_obs()
if self.record:
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
self._terminated = False
info = {}
self.timestamps = []
return self._observation, info
def step(self, action: np.ndarray = None):
if self.timestamps:
# wait the right amount of time to stay at the desired fps
time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))
self.timestamps.append(time.time())
# Get the observation
self._get_obs()
if self.record:
# Teleoperate the leader
self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
else:
# Apply the action to the follower
self.follower.set_goal_pos(pos2pwm(action))
reward = 0
terminated = truncated = self._terminated
info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}
# Check if we are able to keep up with the desired fps
if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
self.fps - self.fps_tolerance
):
print(
f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
f" than min admited fps {(self.fps - self.fps_tolerance):.5f}"
f" at frame {len(self.timestamps)}"
)
info["fps_error"] = True
return self._observation, reward, terminated, truncated, info
def render(self): ...
def close(self):
self.follower._disable_torque()
if self.record:
self.leader._disable_torque()

View File

@@ -0,0 +1,173 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Class to control the robot using dynamixel servos.
"""
from enum import Enum, auto
from typing import Union
import numpy as np
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite
from .dynamixel import Dynamixel, OperatingMode, ReadAttribute
class MotorControlType(Enum):
PWM = auto()
POSITION_CONTROL = auto()
DISABLED = auto()
UNKNOWN = auto()
class Robot:
def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
self.servo_ids = servo_ids
self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
self._init_motors()
def _init_motors(self):
self.position_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
)
for id in self.servo_ids:
self.position_reader.addParam(id)
self.velocity_reader = GroupSyncRead(
self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
)
for id in self.servo_ids:
self.velocity_reader.addParam(id)
self.pos_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
)
for id in self.servo_ids:
self.pos_writer.addParam(id, [2048])
self.pwm_writer = GroupSyncWrite(
self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
)
for id in self.servo_ids:
self.pwm_writer.addParam(id, [2048])
# self._disable_torque()
self.motor_control_state = MotorControlType.DISABLED
def read_position(self, tries=2):
"""
Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
:param tries: maximum number of tries to read the position
:return: list of joint positions in range [0, 4096]
"""
result = self.position_reader.txRxPacket()
if result != 0:
if tries > 0:
return self.read_position(tries=tries - 1)
else:
print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
positions = []
for id in self.servo_ids:
position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
if position > 2**31:
position -= 2**32
positions.append(position)
return np.array(positions)
def read_velocity(self):
"""
Reads the joint velocities of the robot.
:return: list of joint velocities,
"""
self.velocity_reader.txRxPacket()
velocties = []
for id in self.servo_ids:
velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
if velocity > 2**31:
velocity -= 2**32
velocties.append(velocity)
return np.array(velocties)
def set_goal_pos(self, action):
"""
:param action: list or numpy array of target joint positions in range [0, 4096[
"""
if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
self._set_position_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
DXL_LOBYTE(DXL_HIWORD(action[i])),
DXL_HIBYTE(DXL_HIWORD(action[i])),
]
self.pos_writer.changeParam(motor_id, data_write)
self.pos_writer.txPacket()
def set_pwm(self, action):
"""
Sets the pwm values for the servos.
:param action: list or numpy array of pwm values in range [0, 885]
"""
if self.motor_control_state is not MotorControlType.PWM:
self._set_pwm_control()
for i, motor_id in enumerate(self.servo_ids):
data_write = [
DXL_LOBYTE(DXL_LOWORD(action[i])),
DXL_HIBYTE(DXL_LOWORD(action[i])),
]
self.pwm_writer.changeParam(motor_id, data_write)
self.pwm_writer.txPacket()
def set_trigger_torque(self, torque: int):
"""
Sets a constant torque torque for the last servo in the chain. This is useful for the trigger of the leader arm
"""
self.dynamixel._enable_torque(self.servo_ids[-1])
self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)
def limit_pwm(self, limit: Union[int, list, np.ndarray]):
"""
Limits the pwm values for the servos in for position control
@param limit: 0 ~ 885
@return:
"""
if isinstance(limit, int):
limits = [
limit,
] * 5
else:
limits = limit
self._disable_torque()
for motor_id, limit in zip(self.servo_ids, limits, strict=False):
self.dynamixel.set_pwm_limit(motor_id, limit)
self._enable_torque()
def _disable_torque(self):
print(f"disabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._disable_torque(motor_id)
def _enable_torque(self):
print(f"enabling torque for servos {self.servo_ids}")
for motor_id in self.servo_ids:
self.dynamixel._enable_torque(motor_id)
def _set_pwm_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
self._enable_torque()
self.motor_control_state = MotorControlType.PWM
def _set_position_control(self):
self._disable_torque()
for motor_id in self.servo_ids:
# TODO(rcadene): redesign
if motor_id == 9:
self.dynamixel.set_operating_mode(9, OperatingMode.CURRENT_CONTROLLED_POSITION)
else:
self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
self._enable_torque()
self.motor_control_state = MotorControlType.POSITION_CONTROL

View File

@@ -0,0 +1,237 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see in examples/real_robot_example/gym_real_world/gym_environment.py).
"""
import argparse
import copy
import os
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value
from omegaconf import OmegaConf
from tqdm import tqdm
from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data
# parse the repo_id name via command line
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
"--fps_tolerance",
type=float,
default=0.5,
help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
"--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
)
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
parser.add_argument("--mock_robot", action="store_true")
args = parser.parse_args()
repo_id = args.repo_id
num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
fps_tolerance = args.fps_tolerance
out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)
# During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images"
# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data"
gym_config = None
if args.config is not None:
gym_config = OmegaConf.load(args.config)
# Create image and video directories
if not os.path.exists(images_dir):
os.makedirs(images_dir, exist_ok=True)
if not os.path.exists(videos_dir):
os.makedirs(videos_dir, exist_ok=True)
if __name__ == "__main__":
# Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
gym_handle = "gym_real_world/RealEnv-v0"
gym_kwargs = {}
if gym_config is not None:
gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
env = gym.make(
gym_handle, disable_env_checker=True, record=True, fps=fps, fps_tolerance=fps_tolerance, mock=True
)
ep_dicts = []
episode_data_index = {"from": [], "to": []}
ep_fps = []
id_from = 0
id_to = 0
os.system('spd-say "gym environment created"')
ep_idx = 0
while ep_idx < num_episodes:
# bring the follower to the leader and start camera
env.reset()
os.system(f'spd-say "go {ep_idx}"')
# init buffers
obs_replay = {k: [] for k in env.observation_space}
drop_episode = False
timestamps = []
for _ in tqdm(range(num_frames)):
# Apply the next action
observation, _, _, _, info = env.step(action=None)
# images_stacked = np.hstack(list(observation['pixels'].values()))
# images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
# cv2.imshow('frame', images_stacked)
if info["fps_error"]:
os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
drop_episode = True
break
# store data
for key in observation:
obs_replay[key].append(copy.deepcopy(observation[key]))
timestamps.append(info["timestamp"])
# if cv2.waitKey(1) & 0xFF == ord('q'):
# break
os.system('spd-say "stop"')
if not drop_episode:
os.system(f'spd-say "saving episode {ep_idx}"')
ep_dict = {}
# store images in png and create the video
for img_key in env.cameras:
save_images_concurrently(
obs_replay[img_key],
images_dir / f"{img_key}_episode_{ep_idx:06d}",
args.num_workers,
)
fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
# store the reference to the video frame
ep_dict[f"observation.{img_key}"] = [
{"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
]
state = torch.tensor(np.array(obs_replay["agent_pos"]))
action = torch.tensor(np.array(obs_replay["leader_pos"]))
next_done = torch.zeros(num_frames, dtype=torch.bool)
next_done[-1] = True
ep_dict["observation.state"] = state
ep_dict["action"] = action
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
ep_dict["timestamp"] = torch.tensor(timestamps)
ep_dict["next.done"] = next_done
ep_fps.append(num_frames / timestamps[-1])
ep_dicts.append(ep_dict)
print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")
episode_data_index["from"].append(id_from)
episode_data_index["to"].append(
id_from + num_frames if args.keep_last else id_from + num_frames - 1
)
id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
id_from = id_to
ep_idx += 1
env.close()
os.system('spd-say "encode video frames"')
for ep_idx in range(num_episodes):
for img_key in env.cameras:
# If necessary, we may want to encode the video
# with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
encode_video_frames(
images_dir / f"{img_key}_episode_{ep_idx:06d}",
videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
ep_fps[ep_idx],
)
os.system('spd-say "concatenate episodes"')
data_dict = concatenate_episodes(
ep_dicts, drop_episodes_last_frame=not args.keep_last
) # Since our fps varies we are sometimes off tolerance for the last frame
features = {}
keys = [key for key in data_dict if "observation.images." in key]
for key in keys:
features[key] = VideoFrame()
features["observation.state"] = Sequence(
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
)
features["action"] = Sequence(
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
)
features["episode_index"] = Value(dtype="int64", id=None)
features["frame_index"] = Value(dtype="int64", id=None)
features["timestamp"] = Value(dtype="float32", id=None)
features["next.done"] = Value(dtype="bool", id=None)
features["index"] = Value(dtype="int64", id=None)
hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
hf_dataset.set_transform(hf_transform_to_torch)
info = {
"fps": sum(ep_fps) / len(ep_fps), # to have a good tolerance in data processing for the slowest video
"video": 1,
}
os.system('spd-say "from preloaded"')
lerobot_dataset = LeRobotDataset.from_preloaded(
repo_id=repo_id,
version=revision,
hf_dataset=hf_dataset,
episode_data_index=episode_data_index,
info=info,
videos_dir=videos_dir,
)
os.system('spd-say "compute stats"')
stats = compute_stats(lerobot_dataset)
os.system('spd-say "save to disk"')
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
hf_dataset.save_to_disk(str(out_data / "train"))
save_meta_data(info, stats, episode_data_index, meta_data_dir)
if args.push_to_hub:
hf_dataset.push_to_hub(repo_id, token=True, revision="main")
hf_dataset.push_to_hub(repo_id, token=True, revision=revision)
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)
push_videos_to_hub(repo_id, videos_dir, revision="main")
push_videos_to_hub(repo_id, videos_dir, revision=revision)

View File

@@ -0,0 +1,60 @@
import argparse
import logging
from pathlib import Path
import gym_real_world # noqa: F401
import gymnasium as gym # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
try:
pretrained_policy_path = Path(
snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
)
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)

View File

@@ -0,0 +1,103 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# low) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.high:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.low:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.high: [3, 480, 640]
observation.images.low: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.high: mean_std
observation.images.low: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0

View File

@@ -0,0 +1,103 @@
# @package _global_
# Use `act_real.yaml` to train on real-world Aloha/Aloha2 datasets.
# Compared to `act.yaml`, it contains 4 cameras (i.e. right_wrist, left_wrist, images,
# front) instead of 1 camera (i.e. top). Also, `training.eval_freq` is set to -1. This config is used
# to evaluate checkpoints at a certain frequency of training steps. When it is set to -1, it deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
# policy=act_real \
# env=aloha_real
# ```
seed: 1000
dataset_repo_id: ???
override_dataset_stats:
observation.images.top:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
observation.images.front:
# stats from imagenet, since we use a pretrained vision model
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
training:
offline_steps: 1000
online_steps: 0
eval_freq: -1
save_freq: 1000
log_freq: 100
save_checkpoint: true
batch_size: 8
lr: 1e-5
lr_backbone: 1e-5
weight_decay: 1e-4
grad_clip_norm: 10
online_steps_between_rollouts: 1
delta_timestamps:
action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"
eval:
n_episodes: 1
batch_size: 1
# See `configuration_act.py` for more details.
policy:
name: act
# Input / output structure.
n_obs_steps: 1
chunk_size: 100 # chunk_size
n_action_steps: 100
input_shapes:
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
observation.images.top: [3, 480, 640]
observation.images.front: [3, 480, 640]
observation.state: ["${env.state_dim}"]
output_shapes:
action: ["${env.action_dim}"]
# Normalization / Unnormalization
input_normalization_modes:
observation.images.top: mean_std
observation.images.front: mean_std
observation.state: mean_std
output_normalization_modes:
action: mean_std
# Architecture.
# Vision backbone.
vision_backbone: resnet18
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
replace_final_stride_with_dilation: false
# Transformer layers.
pre_norm: false
dim_model: 512
n_heads: 8
dim_feedforward: 3200
feedforward_activation: relu
n_encoder_layers: 4
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
n_decoder_layers: 1
# VAE.
use_vae: true
latent_dim: 32
n_vae_encoder_layers: 4
# Inference.
temporal_ensemble_momentum: null
# Training and loss computation.
dropout: 0.1
kl_weight: 10.0