forked from tangger/lerobot

Comparing hf-papers...thomwolf_2 — 20 commits:

a7c030076f, ddaaa9f279, ef074d7281, 797f79f182, 7ff93e8a51, c304474f6a, e53a03ca53, 86508a167f, ac9bbf2cd7, cc3dcf2b89, 5332fef758, a40d0cbcc7, 3717639f11, 562f09451e, 754944151c, 480ed50d36, ac7d6228ed, a461a71277, 57d3d27c78, 0935e49c8a
README.md (55 lines changed)
@@ -127,13 +127,21 @@ wandb login
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class, which automatically downloads data from the Hugging Face hub.

You can also locally visualize episodes from a dataset by executing our script from the command line:
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
```bash
python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0
```

or from a dataset in a local folder with the root `DATA_DIR` environment variable:
```bash
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
    --repo-id lerobot/pusht \
    --episode-index 0
```

It will open `rerun.io` and display the camera streams, robot states and actions, like this:

https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-fd46b787-b532-47e2-bb6f-fd536a55a7ed.mov?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240505%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240505T172924Z&X-Amz-Expires=300&X-Amz-Signature=d680b26c532eeaf80740f08af3320d22ad0b8a4e4da1bcc4f33142c15b509eda&X-Amz-SignedHeaders=host&actor_id=24889239&key_id=0&repo_id=748713144
@@ -141,6 +149,51 @@ https://github-production-user-asset-6210df.s3.amazonaws.com/4681518/328035972-f

Our script can also visualize datasets stored on a distant server. See `python lerobot/scripts/visualize_dataset.py --help` for more instructions.

### The `LeRobotDataset` format

A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or from a local folder, e.g. with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`, and can be indexed into like any Hugging Face or PyTorch dataset. For instance, `dataset[0]` will retrieve a single sample with the dataset's observations and actions as PyTorch tensors, ready to be fed to a model.

A specificity of `LeRobotDataset` is that we can retrieve several frames for one sample query. By setting `delta_timestamps` to a list of delta timestamps, e.g. `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}`, one can retrieve, for each query, 4 images: one from 1 second before the current time step, two others from 0.5 and 0.2 seconds before, and the final one at the current time step (0 seconds). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
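
As an illustration, here is a minimal sketch of such a multi-frame query on the `pusht` dataset. The import path and the `delta_timestamps` constructor argument are assumed to match example 1; adjust them to your version of `lerobot`:

```python
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# fetch, for every sample, the image 1 s, 0.5 s and 0.2 s before the current step,
# plus the current one
delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}

dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)

item = dataset[0]
# the image feature now stacks 4 frames along a new leading dimension, e.g. (4, c, h, w)
print(item["observation.image"].shape)

# items can be batched with a regular PyTorch DataLoader
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
```
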
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data, which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that covers most types of features and specificities present in reinforcement learning and robotics, in simulation and in the real world, with a focus on cameras and robot states.

Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset, but not the main aspects:

```
dataset attributes:
├ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example:
│ ├ observation.images.cam_high: VideoFrame
│ │   VideoFrame = {'path': path to a mp4 video, 'timestamp': float32 timestamp in the video}
│ ├ observation.state: List of float32: positions of the arm joints (for instance)
│ ... (more observations)
│ ├ action: List of float32
│ ├ episode_index: int64: index of the episode for this sample
│ ├ frame_index: int64: index of the frame for this sample in the episode; starts at 0 for each episode
│ ├ timestamp: float32: timestamp in the episode
│ ├ next.done: bool: indicates the end of an episode; True for the last frame in each episode
│ └ index: int64: general index in the whole dataset
├ episode_data_index: contains 2 tensors with the start and end indices of each episode
│ ├ from: 1D int64 tensor of first frame index for each episode: shape (num episodes,), starts with 0
│ └ to: 1D int64 tensor of last frame index for each episode: shape (num episodes,)
├ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance
│ ├ observation.images.cam_high: {'max': tensor with same number of dimensions (e.g. `(c, 1, 1)` for images, `(c,)` for states), etc.}
│ ...
├ info: a dictionary of metadata on the dataset
│ ├ fps: float - frames per second the dataset is recorded/synchronized to
│ └ video: bool - indicates if frames are encoded in mp4 video files to save space, or stored as png files
├ videos_dir: path to where the mp4 videos or png images are stored/accessed
└ camera_keys: List of string: the keys to access camera features in the item returned by the dataset (e.g. `["observation.images.cam_high", ...]`)
```
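
For reference, a short sketch of how the attributes listed above can be used. Attribute names follow the structure shown; whether `to` is inclusive or exclusive of the last frame may depend on your `lerobot` version:

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("lerobot/aloha_static_coffee")

# metadata from the info dictionary
print(dataset.info["fps"], dataset.info["video"])

# frame range of the first episode, read from episode_data_index
ep = 0
start = dataset.episode_data_index["from"][ep].item()
end = dataset.episode_data_index["to"][ep].item()
print(f"episode {ep} spans dataset indices {start} to {end}")

# camera features available in every item
first_frame = dataset[start]
for key in dataset.camera_keys:  # e.g. ["observation.images.cam_high", ...]
    print(key, first_frame[key].shape)
```
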
A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space, or as png files
- episode_data_index saved using the `safetensors` tensor serialization format
- stats saved using the `safetensors` tensor serialization format
- info saved using JSON

Datasets can be uploaded/downloaded from the Hugging Face hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder, as illustrated in the above section on dataset visualization.

### Evaluate a pretrained policy

Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from the Hugging Face hub, and run an evaluation on its corresponding environment.
examples/real_robot_example/README.md (new file, 89 lines)
@@ -0,0 +1,89 @@
# Using `lerobot` on a real-world arm

In this example, we'll be using `lerobot` on a real-world arm to:
- record a dataset in the `lerobot` format
- (soon) train a policy on it
- (soon) run the policy in the real world

## Which robotic arm to use

In this example we're using the [open-source low-cost arm from Alexander Koch](https://github.com/AlexanderKoch-Koch/low_cost_robot) in the specific setup of:
- having 6 servos per arm, i.e. using the elbow-to-wrist extension
- adding two cameras around it, one on top and one in the front
- having a teleoperation arm as well (build the leader and the follower arms in A. Koch's repo, both with elbow-to-wrist extensions)

I'm using these cameras (but the setup should not be sensitive to the exact cameras you're using):
- C922 Pro Stream Webcam
- Intel(R) RealSense D455 (using only the RGB input)

In general, this example should be easily extendable to any type of arm using Dynamixel servos with at least one camera, by changing a couple of configuration values in the gym env.

## Install the example

Follow these steps:
- install `lerobot`
- install the Dynamixel SDK: `pip install dynamixel-sdk`

## Usage

### 0 - Record examples

Run the `record_training_data.py` example, selecting the duration and number of episodes you want to record, e.g.

```bash
DATA_DIR='./data' python record_training_data.py \
    --repo-id=thomwolf/blue_red_sort \
    --num-episodes=50 \
    --num-frames=400
```

TODO:
- variable-length episodes
- being able to drop episodes
- checking uploading to the hub

### 1 - Visualize the dataset

Use the standard dataset visualization script, pointing it to the right folder:

```bash
DATA_DIR='./data' python ../../lerobot/scripts/visualize_dataset.py \
    --repo-id thomwolf/blue_red_sort \
    --episode-index 0
```

### 2 - Train a policy

From the example directory, run this command to train a model using ACT:

```bash
DATA_DIR='./data' python ../../lerobot/scripts/train.py \
    device=cuda \
    hydra.searchpath=[file://./train_config/] \
    hydra.run.dir=./outputs/train/blue_red_sort \
    dataset_repo_id=thomwolf/blue_red_sort \
    env=gym_real_world \
    policy=act_real_world \
    wandb.enable=false
```

### 3 - Evaluate the policy in the real world

From the example directory, run this command to evaluate our policy.
The configuration for running the policy is in the checkpoint of the model.
You can override parameters as follows:

```bash
python run_policy.py \
    -p ./outputs/train/blue_red_sort/checkpoints/last/pretrained_model/ \
    env.episode_length=1000
```

## Convert an HDF5 dataset recorded with the original ACT repo

You can convert a dataset from the raw HDF5 format used in https://github.com/tonyzhaozh/act with the following command:

```bash
python ./lerobot/scripts/push_dataset_to_hub.py
```
@@ -0,0 +1,840 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from safetensors.torch import load_file, save_file\n",
|
||||
"from pprint import pprint"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"original_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/policy_last.ckpt\"\n",
|
||||
"converted_ckpt_path = \"/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/model.safetensors\"\n",
|
||||
"\n",
|
||||
"comparison_main_path = \"/home/thomwolf/Documents/Github/lerobot/examples/real_robot_example/outputs/train/blue_red_debug_no_masking/checkpoints/last/pretrained_model/\"\n",
|
||||
"comparison_safetensor_path = comparison_main_path + \"model.safetensors\"\n",
|
||||
"comparison_config_json_path = comparison_main_path + \"config.json\"\n",
|
||||
"comparison_config_yaml_path = comparison_main_path + \"config.yaml\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"a = torch.load(original_ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"b = load_file(comparison_safetensor_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['model.action_head.bias',\n",
|
||||
" 'model.action_head.weight',\n",
|
||||
" 'model.backbone.bn1.bias',\n",
|
||||
" 'model.backbone.bn1.running_mean',\n",
|
||||
" 'model.backbone.bn1.running_var',\n",
|
||||
" 'model.backbone.bn1.weight',\n",
|
||||
" 'model.backbone.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer1.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer1.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer1.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer1.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer1.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer1.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer1.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer1.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer1.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer1.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer1.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer1.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer1.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer1.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer1.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer1.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer1.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer1.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer2.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer2.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer2.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer2.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer2.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer2.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer2.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer2.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer2.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer2.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer2.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer2.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer2.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer2.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer2.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer2.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer2.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer2.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer2.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer2.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer3.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer3.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer3.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer3.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer3.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer3.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer3.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer3.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer3.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer3.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer3.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer3.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer3.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer3.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer3.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer3.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer3.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer3.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer3.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer3.1.conv2.weight',\n",
|
||||
" 'model.backbone.layer4.0.bn1.bias',\n",
|
||||
" 'model.backbone.layer4.0.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.bn1.running_var',\n",
|
||||
" 'model.backbone.layer4.0.bn1.weight',\n",
|
||||
" 'model.backbone.layer4.0.bn2.bias',\n",
|
||||
" 'model.backbone.layer4.0.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.bn2.running_var',\n",
|
||||
" 'model.backbone.layer4.0.bn2.weight',\n",
|
||||
" 'model.backbone.layer4.0.conv1.weight',\n",
|
||||
" 'model.backbone.layer4.0.conv2.weight',\n",
|
||||
" 'model.backbone.layer4.0.downsample.0.weight',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.bias',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.running_var',\n",
|
||||
" 'model.backbone.layer4.0.downsample.1.weight',\n",
|
||||
" 'model.backbone.layer4.1.bn1.bias',\n",
|
||||
" 'model.backbone.layer4.1.bn1.running_mean',\n",
|
||||
" 'model.backbone.layer4.1.bn1.running_var',\n",
|
||||
" 'model.backbone.layer4.1.bn1.weight',\n",
|
||||
" 'model.backbone.layer4.1.bn2.bias',\n",
|
||||
" 'model.backbone.layer4.1.bn2.running_mean',\n",
|
||||
" 'model.backbone.layer4.1.bn2.running_var',\n",
|
||||
" 'model.backbone.layer4.1.bn2.weight',\n",
|
||||
" 'model.backbone.layer4.1.conv1.weight',\n",
|
||||
" 'model.backbone.layer4.1.conv2.weight',\n",
|
||||
" 'model.decoder.layers.0.linear1.bias',\n",
|
||||
" 'model.decoder.layers.0.linear1.weight',\n",
|
||||
" 'model.decoder.layers.0.linear2.bias',\n",
|
||||
" 'model.decoder.layers.0.linear2.weight',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.decoder.layers.0.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.decoder.layers.0.norm1.bias',\n",
|
||||
" 'model.decoder.layers.0.norm1.weight',\n",
|
||||
" 'model.decoder.layers.0.norm2.bias',\n",
|
||||
" 'model.decoder.layers.0.norm2.weight',\n",
|
||||
" 'model.decoder.layers.0.norm3.bias',\n",
|
||||
" 'model.decoder.layers.0.norm3.weight',\n",
|
||||
" 'model.decoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.decoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.decoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.decoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.decoder_pos_embed.weight',\n",
|
||||
" 'model.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder_img_feat_input_proj.bias',\n",
|
||||
" 'model.encoder_img_feat_input_proj.weight',\n",
|
||||
" 'model.encoder_latent_input_proj.bias',\n",
|
||||
" 'model.encoder_latent_input_proj.weight',\n",
|
||||
" 'model.encoder_robot_and_latent_pos_embed.weight',\n",
|
||||
" 'model.encoder_robot_state_input_proj.bias',\n",
|
||||
" 'model.encoder_robot_state_input_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.vae_encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.vae_encoder_action_input_proj.bias',\n",
|
||||
" 'model.vae_encoder_action_input_proj.weight',\n",
|
||||
" 'model.vae_encoder_cls_embed.weight',\n",
|
||||
" 'model.vae_encoder_latent_output_proj.bias',\n",
|
||||
" 'model.vae_encoder_latent_output_proj.weight',\n",
|
||||
" 'model.vae_encoder_pos_enc',\n",
|
||||
" 'model.vae_encoder_robot_state_input_proj.bias',\n",
|
||||
" 'model.vae_encoder_robot_state_input_proj.weight',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_front.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_front.std',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_top.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_images_top.std',\n",
|
||||
" 'normalize_inputs.buffer_observation_state.mean',\n",
|
||||
" 'normalize_inputs.buffer_observation_state.std',\n",
|
||||
" 'normalize_targets.buffer_action.mean',\n",
|
||||
" 'normalize_targets.buffer_action.std',\n",
|
||||
" 'unnormalize_outputs.buffer_action.mean',\n",
|
||||
" 'unnormalize_outputs.buffer_action.std']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dest = list(b.keys())\n",
|
||||
"pprint(dest)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['model.pos_table',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.transformer.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.0.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.1.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.2.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.3.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.4.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.5.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.self_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.in_proj_bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.multihead_attn.out_proj.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.linear2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm1.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm1.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm2.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm2.bias',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm3.weight',\n",
|
||||
" 'model.transformer.decoder.layers.6.norm3.bias',\n",
|
||||
" 'model.transformer.decoder.norm.weight',\n",
|
||||
" 'model.transformer.decoder.norm.bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.0.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.0.linear1.weight',\n",
|
||||
" 'model.encoder.layers.0.linear1.bias',\n",
|
||||
" 'model.encoder.layers.0.linear2.weight',\n",
|
||||
" 'model.encoder.layers.0.linear2.bias',\n",
|
||||
" 'model.encoder.layers.0.norm1.weight',\n",
|
||||
" 'model.encoder.layers.0.norm1.bias',\n",
|
||||
" 'model.encoder.layers.0.norm2.weight',\n",
|
||||
" 'model.encoder.layers.0.norm2.bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.1.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.1.linear1.weight',\n",
|
||||
" 'model.encoder.layers.1.linear1.bias',\n",
|
||||
" 'model.encoder.layers.1.linear2.weight',\n",
|
||||
" 'model.encoder.layers.1.linear2.bias',\n",
|
||||
" 'model.encoder.layers.1.norm1.weight',\n",
|
||||
" 'model.encoder.layers.1.norm1.bias',\n",
|
||||
" 'model.encoder.layers.1.norm2.weight',\n",
|
||||
" 'model.encoder.layers.1.norm2.bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.2.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.2.linear1.weight',\n",
|
||||
" 'model.encoder.layers.2.linear1.bias',\n",
|
||||
" 'model.encoder.layers.2.linear2.weight',\n",
|
||||
" 'model.encoder.layers.2.linear2.bias',\n",
|
||||
" 'model.encoder.layers.2.norm1.weight',\n",
|
||||
" 'model.encoder.layers.2.norm1.bias',\n",
|
||||
" 'model.encoder.layers.2.norm2.weight',\n",
|
||||
" 'model.encoder.layers.2.norm2.bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.in_proj_bias',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.weight',\n",
|
||||
" 'model.encoder.layers.3.self_attn.out_proj.bias',\n",
|
||||
" 'model.encoder.layers.3.linear1.weight',\n",
|
||||
" 'model.encoder.layers.3.linear1.bias',\n",
|
||||
" 'model.encoder.layers.3.linear2.weight',\n",
|
||||
" 'model.encoder.layers.3.linear2.bias',\n",
|
||||
" 'model.encoder.layers.3.norm1.weight',\n",
|
||||
" 'model.encoder.layers.3.norm1.bias',\n",
|
||||
" 'model.encoder.layers.3.norm2.weight',\n",
|
||||
" 'model.encoder.layers.3.norm2.bias',\n",
|
||||
" 'model.action_head.weight',\n",
|
||||
" 'model.action_head.bias',\n",
|
||||
" 'model.is_pad_head.weight',\n",
|
||||
" 'model.is_pad_head.bias',\n",
|
||||
" 'model.query_embed.weight',\n",
|
||||
" 'model.input_proj.weight',\n",
|
||||
" 'model.input_proj.bias',\n",
|
||||
" 'model.backbones.0.0.body.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer1.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer2.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer3.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.bn2.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.0.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.0.downsample.1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.conv1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn1.num_batches_tracked',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.conv2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.weight',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.bias',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.running_mean',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.running_var',\n",
|
||||
" 'model.backbones.0.0.body.layer4.1.bn2.num_batches_tracked',\n",
|
||||
" 'model.input_proj_robot_state.weight',\n",
|
||||
" 'model.input_proj_robot_state.bias',\n",
|
||||
" 'model.cls_embed.weight',\n",
|
||||
" 'model.encoder_action_proj.weight',\n",
|
||||
" 'model.encoder_action_proj.bias',\n",
|
||||
" 'model.encoder_joint_proj.weight',\n",
|
||||
" 'model.encoder_joint_proj.bias',\n",
|
||||
" 'model.latent_proj.weight',\n",
|
||||
" 'model.latent_proj.bias',\n",
|
||||
" 'model.latent_out_proj.weight',\n",
|
||||
" 'model.latent_out_proj.bias',\n",
|
||||
" 'model.additional_pos_embed.weight']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"orig = list(a.keys())\n",
|
||||
"pprint(orig)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"a = torch.load(original_ckpt_path)\n",
|
||||
"\n",
|
||||
"to_remove_startswith = ['model.transformer.decoder.layers.1.',\n",
|
||||
" 'model.transformer.decoder.layers.2.',\n",
|
||||
" 'model.transformer.decoder.layers.3.',\n",
|
||||
" 'model.transformer.decoder.layers.4.',\n",
|
||||
" 'model.transformer.decoder.layers.5.',\n",
|
||||
" 'model.transformer.decoder.layers.6.',\n",
|
||||
" 'model.transformer.decoder.norm.',\n",
|
||||
" 'model.is_pad_head']\n",
|
||||
"\n",
|
||||
"to_remove_in = ['num_batches_tracked',]\n",
|
||||
"\n",
|
||||
"conv = {}\n",
|
||||
"\n",
|
||||
"keys = list(a.keys())\n",
|
||||
"for k in keys:\n",
|
||||
" if any(k.startswith(tr) for tr in to_remove_startswith):\n",
|
||||
" a.pop(k)\n",
|
||||
" continue\n",
|
||||
" if any(tr in k for tr in to_remove_in):\n",
|
||||
" a.pop(k)\n",
|
||||
" continue\n",
|
||||
" if k.startswith('model.transformer.encoder.layers.'):\n",
|
||||
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
|
||||
" if k.startswith('model.transformer.decoder.layers.0.'):\n",
|
||||
" conv[k.replace('transformer.', '')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder.layers.'):\n",
|
||||
" conv[k.replace('encoder.', 'vae_encoder.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.action_head.'):\n",
|
||||
" conv[k] = a.pop(k)\n",
|
||||
" if k.startswith('model.pos_table'):\n",
|
||||
" conv[k.replace('pos_table', 'vae_encoder_pos_enc')] = a.pop(k)\n",
|
||||
" if k.startswith('model.query_embed.'):\n",
|
||||
" conv[k.replace('query_embed', 'decoder_pos_embed')] = a.pop(k)\n",
|
||||
" if k.startswith('model.input_proj.'):\n",
|
||||
" conv[k.replace('input_proj.', 'encoder_img_feat_input_proj.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.input_proj_robot_state.'):\n",
|
||||
" conv[k.replace('input_proj_robot_state.', 'encoder_robot_state_input_proj.')] = a.pop(k)\n",
|
||||
" if k.startswith('model.backbones.0.0.body.'):\n",
|
||||
" conv[k.replace('backbones.0.0.body', 'backbone')] = a.pop(k)\n",
|
||||
" if k.startswith('model.cls_embed.'):\n",
|
||||
" conv[k.replace('cls_embed', 'vae_encoder_cls_embed')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder_action_proj.'):\n",
|
||||
" conv[k.replace('encoder_action_proj', 'vae_encoder_action_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.encoder_joint_proj.'):\n",
|
||||
" conv[k.replace('encoder_joint_proj', 'vae_encoder_robot_state_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.latent_proj.'):\n",
|
||||
" conv[k.replace('latent_proj', 'vae_encoder_latent_output_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.latent_out_proj.'):\n",
|
||||
" conv[k.replace('latent_out_proj', 'encoder_latent_input_proj')] = a.pop(k)\n",
|
||||
" if k.startswith('model.additional_pos_embed.'):\n",
|
||||
" conv[k.replace('additional_pos_embed', 'encoder_robot_and_latent_pos_embed')] = a.pop(k)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"OrderedDict()"
|
||||
]
|
||||
},
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for k, v in conv.items():\n",
|
||||
" assert b[k].shape == v.shape\n",
|
||||
" b[k] = v"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"save_file(b, converted_ckpt_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'/home/thomwolf/Documents/Github/ACT/checkpoints/blue_red_sort/config.yaml'"
|
||||
]
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Now also copy the config files\n",
|
||||
"import shutil\n",
|
||||
"shutil.copy(comparison_config_json_path, converted_ckpt_path.replace('model.safetensors', 'config.json'))\n",
|
||||
"shutil.copy(comparison_config_yaml_path, converted_ckpt_path.replace('model.safetensors', 'config.yaml'))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lerobot",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.14"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
examples/real_robot_example/gym_real_world/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from gymnasium.envs.registration import register

register(
    id="gym_real_world/RealEnv-v0",
    entry_point="gym_real_world.gym_environment:RealEnv",
    max_episode_steps=300,
    nondeterministic=True,
)
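
The registration above means the real-world environment can be created through the standard gymnasium API. A minimal sketch, assuming `gym_real_world` is importable from the example directory and the arm, cameras and Dynamixel SDK are available:

```python
import gymnasium as gym

import gym_real_world  # noqa: F401 - importing the package runs the register() call above

# id and max_episode_steps come from the register() call in __init__.py
env = gym.make("gym_real_world/RealEnv-v0")
obs, info = env.reset()
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
env.close()
```
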
examples/real_robot_example/gym_real_world/dynamixel.py (new file, 363 lines)
@@ -0,0 +1,363 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Dynamixel class to control the dynamixel servos
"""

from __future__ import annotations

import enum
import math
import os
from dataclasses import dataclass

import numpy as np
from dynamixel_sdk import *  # Uses Dynamixel SDK library


def pos2pwm(pos: np.ndarray) -> np.ndarray:
    """
    :param pos: numpy array of joint positions in range [-pi, pi]
    :return: numpy array of pwm values in range [0, 4096]
    """
    return ((pos / 3.14 + 1.0) * 2048).astype(np.int64)


def pwm2pos(pwm: np.ndarray) -> np.ndarray:
    """
    :param pwm: numpy array of pwm values in range [0, 4096]
    :return: numpy array of joint positions in range [-pi, pi]
    """
    return (pwm / 2048 - 1) * 3.14


def pwm2vel(pwm: np.ndarray) -> np.ndarray:
    """
    :param pwm: numpy array of pwm/s joint velocities
    :return: numpy array of rad/s joint velocities
    """
    return pwm * 3.14 / 2048


def vel2pwm(vel: np.ndarray) -> np.ndarray:
    """
    :param vel: numpy array of rad/s joint velocities
    :return: numpy array of pwm/s joint velocities
    """
    return (vel * 2048 / 3.14).astype(np.int64)


class ReadAttribute(enum.Enum):
    TEMPERATURE = 146
    VOLTAGE = 145
    VELOCITY = 128
    POSITION = 132
    CURRENT = 126
    PWM = 124
    HARDWARE_ERROR_STATUS = 70
    HOMING_OFFSET = 20
    BAUDRATE = 8


class OperatingMode(enum.Enum):
    VELOCITY = 1
    POSITION = 3
    CURRENT_CONTROLLED_POSITION = 5
    PWM = 16
    UNKNOWN = -1


class Dynamixel:
    ADDR_TORQUE_ENABLE = 64
    ADDR_GOAL_POSITION = 116
    ADDR_VELOCITY_LIMIT = 44
    ADDR_GOAL_PWM = 100
    OPERATING_MODE_ADDR = 11
    POSITION_I = 82
    POSITION_P = 84
    ADDR_ID = 7

    @dataclass
    class Config:
        def instantiate(self):
            return Dynamixel(self)

        baudrate: int = 57600
        protocol_version: float = 2.0
        device_name: str = ""  # /dev/tty.usbserial-1120'
        dynamixel_id: int = 1

    def __init__(self, config: Config):
        self.config = config
        self.connect()

    def connect(self):
        if self.config.device_name == "":
            for port_name in os.listdir("/dev"):
                if "ttyUSB" in port_name or "ttyACM" in port_name:
                    self.config.device_name = "/dev/" + port_name
                    print(f"using device {self.config.device_name}")
        self.portHandler = PortHandler(self.config.device_name)
        # self.portHandler.LA
        self.packetHandler = PacketHandler(self.config.protocol_version)
        if not self.portHandler.openPort():
            raise Exception(f"Failed to open port {self.config.device_name}")

        if not self.portHandler.setBaudRate(self.config.baudrate):
            raise Exception(f"failed to set baudrate to {self.config.baudrate}")

        # self.operating_mode = OperatingMode.UNKNOWN
        # self.torque_enabled = False
        # self._disable_torque()

        self.operating_modes = [None for _ in range(32)]
        self.torque_enabled = [None for _ in range(32)]
        return True

    def disconnect(self):
        self.portHandler.closePort()

    def set_goal_position(self, motor_id, goal_position):
        # if self.operating_modes[motor_id] is not OperatingMode.POSITION:
        #     self._disable_torque(motor_id)
        #     self.set_operating_mode(motor_id, OperatingMode.POSITION)

        # if not self.torque_enabled[motor_id]:
        #     self._enable_torque(motor_id)

        # self._enable_torque(motor_id)
        dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
            self.portHandler, motor_id, self.ADDR_GOAL_POSITION, goal_position
        )
        # self._process_response(dxl_comm_result, dxl_error)
        # print(f'set position of motor {motor_id} to {goal_position}')

    def set_pwm_value(self, motor_id: int, pwm_value, tries=3):
        if self.operating_modes[motor_id] is not OperatingMode.PWM:
            self._disable_torque(motor_id)
            self.set_operating_mode(motor_id, OperatingMode.PWM)

        if not self.torque_enabled[motor_id]:
            self._enable_torque(motor_id)
            # print(f'enabling torque')
        dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
            self.portHandler, motor_id, self.ADDR_GOAL_PWM, pwm_value
        )
        # self._process_response(dxl_comm_result, dxl_error)
        # print(f'set pwm of motor {motor_id} to {pwm_value}')
        if dxl_comm_result != COMM_SUCCESS:
|
||||
if tries <= 1:
|
||||
raise ConnectionError(f"dxl_comm_result: {self.packetHandler.getTxRxResult(dxl_comm_result)}")
|
||||
else:
|
||||
print(f"dynamixel pwm setting failure trying again with {tries - 1} tries")
|
||||
self.set_pwm_value(motor_id, pwm_value, tries=tries - 1)
|
||||
elif dxl_error != 0:
|
||||
print(f"dxl error {dxl_error}")
|
||||
raise ConnectionError(f"dynamixel error: {self.packetHandler.getTxRxResult(dxl_error)}")
|
||||
|
||||
def read_temperature(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.TEMPERATURE, 1)
|
||||
|
||||
def read_velocity(self, motor_id: int):
|
||||
pos = self._read_value(motor_id, ReadAttribute.VELOCITY, 4)
|
||||
if pos > 2**31:
|
||||
pos -= 2**32
|
||||
# print(f'read position {pos} for motor {motor_id}')
|
||||
return pos
|
||||
|
||||
def read_position(self, motor_id: int):
|
||||
pos = self._read_value(motor_id, ReadAttribute.POSITION, 4)
|
||||
if pos > 2**31:
|
||||
pos -= 2**32
|
||||
# print(f'read position {pos} for motor {motor_id}')
|
||||
return pos
|
||||
|
||||
def read_position_degrees(self, motor_id: int) -> float:
|
||||
return (self.read_position(motor_id) / 4096) * 360
|
||||
|
||||
def read_position_radians(self, motor_id: int) -> float:
|
||||
return (self.read_position(motor_id) / 4096) * 2 * math.pi
|
||||
|
||||
def read_current(self, motor_id: int):
|
||||
current = self._read_value(motor_id, ReadAttribute.CURRENT, 2)
|
||||
if current > 2**15:
|
||||
current -= 2**16
|
||||
return current
|
||||
|
||||
def read_present_pwm(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.PWM, 2)
|
||||
|
||||
def read_hardware_error_status(self, motor_id: int):
|
||||
return self._read_value(motor_id, ReadAttribute.HARDWARE_ERROR_STATUS, 1)
|
||||
|
||||
def disconnect(self):
|
||||
self.portHandler.closePort()
|
||||
|
||||
def set_id(self, old_id, new_id, use_broadcast_id: bool = False):
|
||||
"""
|
||||
sets the id of the dynamixel servo
|
||||
@param old_id: current id of the servo
|
||||
@param new_id: new id
|
||||
@param use_broadcast_id: set ids of all connected dynamixels if True.
|
||||
If False, change only servo with self.config.id
|
||||
@return:
|
||||
"""
|
||||
if use_broadcast_id:
|
||||
current_id = 254
|
||||
else:
|
||||
current_id = old_id
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, current_id, self.ADDR_ID, new_id
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, old_id)
|
||||
self.config.id = id
|
||||
|
||||
def _enable_torque(self, motor_id):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 1
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.torque_enabled[motor_id] = True
|
||||
|
||||
def _disable_torque(self, motor_id):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_TORQUE_ENABLE, 0
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.torque_enabled[motor_id] = False
|
||||
|
||||
def _process_response(self, dxl_comm_result: int, dxl_error: int, motor_id: int):
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"dxl_comm_result for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_comm_result)}"
|
||||
)
|
||||
elif dxl_error != 0:
|
||||
print(f"dxl error {dxl_error}")
|
||||
raise ConnectionError(
|
||||
f"dynamixel error for motor {motor_id}: {self.packetHandler.getTxRxResult(dxl_error)}"
|
||||
)
|
||||
|
||||
def set_operating_mode(self, motor_id: int, operating_mode: OperatingMode):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.OPERATING_MODE_ADDR, operating_mode.value
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self.operating_modes[motor_id] = operating_mode
|
||||
|
||||
def set_pwm_limit(self, motor_id: int, limit: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(self.portHandler, motor_id, 36, limit)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_velocity_limit(self, motor_id: int, velocity_limit):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||
self.portHandler, motor_id, self.ADDR_VELOCITY_LIMIT, velocity_limit
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_P(self, motor_id: int, P: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.POSITION_P, P
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def set_I(self, motor_id: int, I: int):
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write2ByteTxRx(
|
||||
self.portHandler, motor_id, self.POSITION_I, I
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def read_home_offset(self, motor_id: int):
|
||||
self._disable_torque(motor_id)
|
||||
# dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(self.portHandler, motor_id,
|
||||
# ReadAttribute.HOMING_OFFSET.value, home_position)
|
||||
home_offset = self._read_value(motor_id, ReadAttribute.HOMING_OFFSET, 4)
|
||||
# self._process_response(dxl_comm_result, dxl_error)
|
||||
self._enable_torque(motor_id)
|
||||
return home_offset
|
||||
|
||||
def set_home_offset(self, motor_id: int, home_position: int):
|
||||
self._disable_torque(motor_id)
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write4ByteTxRx(
|
||||
self.portHandler, motor_id, ReadAttribute.HOMING_OFFSET.value, home_position
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
self._enable_torque(motor_id)
|
||||
|
||||
def set_baudrate(self, motor_id: int, baudrate):
|
||||
# translate baudrate into dynamixel baudrate setting id
|
||||
if baudrate == 57600:
|
||||
baudrate_id = 1
|
||||
elif baudrate == 1_000_000:
|
||||
baudrate_id = 3
|
||||
elif baudrate == 2_000_000:
|
||||
baudrate_id = 4
|
||||
elif baudrate == 3_000_000:
|
||||
baudrate_id = 5
|
||||
elif baudrate == 4_000_000:
|
||||
baudrate_id = 6
|
||||
else:
|
||||
raise Exception("baudrate not implemented")
|
||||
|
||||
self._disable_torque(motor_id)
|
||||
dxl_comm_result, dxl_error = self.packetHandler.write1ByteTxRx(
|
||||
self.portHandler, motor_id, ReadAttribute.BAUDRATE.value, baudrate_id
|
||||
)
|
||||
self._process_response(dxl_comm_result, dxl_error, motor_id)
|
||||
|
||||
def _read_value(self, motor_id, attribute: ReadAttribute, num_bytes: int, tries=10):
|
||||
try:
|
||||
if num_bytes == 1:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read1ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
elif num_bytes == 2:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read2ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
elif num_bytes == 4:
|
||||
value, dxl_comm_result, dxl_error = self.packetHandler.read4ByteTxRx(
|
||||
self.portHandler, motor_id, attribute.value
|
||||
)
|
||||
except Exception:
|
||||
if tries == 0:
|
||||
raise Exception
|
||||
else:
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
if dxl_comm_result != COMM_SUCCESS:
|
||||
if tries <= 1:
|
||||
# print("%s" % self.packetHandler.getTxRxResult(dxl_comm_result))
|
||||
raise ConnectionError(f"dxl_comm_result {dxl_comm_result} for servo {motor_id} value {value}")
|
||||
else:
|
||||
print(f"dynamixel read failure for servo {motor_id} trying again with {tries - 1} tries")
|
||||
time.sleep(0.02)
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
elif dxl_error != 0: # # print("%s" % self.packetHandler.getRxPacketError(dxl_error))
|
||||
# raise ConnectionError(f'dxl_error {dxl_error} binary ' + "{0:b}".format(37))
|
||||
if tries == 0 and dxl_error != 128:
|
||||
raise Exception(f"Failed to read value from motor {motor_id} error is {dxl_error}")
|
||||
else:
|
||||
return self._read_value(motor_id, attribute, num_bytes, tries=tries - 1)
|
||||
return value
|
||||
|
||||
def set_home_position(self, motor_id: int):
|
||||
print(f"setting home position for motor {motor_id}")
|
||||
self.set_home_offset(motor_id, 0)
|
||||
current_position = self.read_position(motor_id)
|
||||
print(f"position before {current_position}")
|
||||
self.set_home_offset(motor_id, -current_position)
|
||||
# dynamixel.set_home_offset(motor_id, -4096)
|
||||
# dynamixel.set_home_offset(motor_id, -4294964109)
|
||||
current_position = self.read_position(motor_id)
|
||||
# print(f'signed position {current_position - 2** 32}')
|
||||
print(f"position after {current_position}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dynamixel = Dynamixel.Config(baudrate=1_000_000, device_name="/dev/tty.usbmodem57380045631").instantiate()
|
||||
motor_id = 1
|
||||
pos = dynamixel.read_position(motor_id)
|
||||
for i in range(10):
|
||||
s = time.monotonic()
|
||||
pos = dynamixel.read_position(motor_id)
|
||||
delta = time.monotonic() - s
|
||||
print(f"read position took {delta}")
|
||||
print(f"position {pos}")
|
||||
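The pwm/position helpers at the top of this file are simple linear maps between radians and raw Dynamixel ticks; here is a small self-contained sketch of the round trip (the input angles are hypothetical):

```python
import numpy as np

# Same conversions as pos2pwm / pwm2pos above, applied to hypothetical joint angles.
angles = np.array([-3.14, 0.0, 1.57])                   # radians
pwm = ((angles / 3.14 + 1.0) * 2048).astype(np.int64)   # -> array([   0, 2048, 3072]) servo ticks
back = (pwm / 2048 - 1) * 3.14                          # -> approximately the original angles
print(pwm, back)
```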
192 examples/real_robot_example/gym_real_world/gym_environment.py Normal file
@@ -0,0 +1,192 @@
import time
from unittest.mock import MagicMock

import cv2
import gymnasium as gym
import numpy as np
from gymnasium import spaces

from .dynamixel import pos2pwm, pwm2pos
from .robot import Robot

FPS = 30

CAMERAS_SHAPES = {
    "images.high": (480, 640, 3),
    "images.low": (480, 640, 3),
}

CAMERAS_PORTS = {
    "images.high": "/dev/video6",
    "images.low": "/dev/video0",
}

LEADER_PORT = "/dev/ttyACM1"
FOLLOWER_PORT = "/dev/ttyACM0"

MockRobot = MagicMock()
MockRobot.read_position = MagicMock()
MockRobot.read_position.return_value = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])

MockCamera = MagicMock()
MockCamera.isOpened = MagicMock(return_value=True)
MockCamera.read = MagicMock(return_value=(True, np.zeros((480, 640, 3), dtype=np.uint8)))


def capture_image(cam, cam_width, cam_height):
    # Capture a single frame
    _, frame = cam.read()
    image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # # Define your crop coordinates (top left corner and bottom right corner)
    # x1, y1 = 400, 0  # Example starting coordinates (top left of the crop rectangle)
    # x2, y2 = 1600, 900  # Example ending coordinates (bottom right of the crop rectangle)
    # # Crop the image
    # image = image[y1:y2, x1:x2]
    # Resize the image
    image = cv2.resize(image, (cam_width, cam_height), interpolation=cv2.INTER_AREA)

    return image


class RealEnv(gym.Env):
    metadata = {}

    def __init__(
        self,
        record: bool = False,
        num_joints: int = 6,
        cameras_shapes: dict = CAMERAS_SHAPES,
        cameras_ports: dict = CAMERAS_PORTS,
        follower_port: str = FOLLOWER_PORT,
        leader_port: str = LEADER_PORT,
        warmup_steps: int = 100,
        trigger_torque=70,
        fps: int = FPS,
        fps_tolerance: float = 0.1,
        mock: bool = False,
    ):
        self.num_joints = num_joints
        self.cameras_shapes = cameras_shapes
        self.cameras_ports = cameras_ports
        self.warmup_steps = warmup_steps
        assert len(self.cameras_shapes) == len(self.cameras_ports), "Number of cameras and shapes must match."

        self.follower_port = follower_port
        self.leader_port = leader_port
        self.record = record
        self.fps = fps
        self.fps_tolerance = fps_tolerance

        # Initialize the robot
        self.follower = Robot(device_name=self.follower_port) if not mock else MockRobot
        if self.record:
            self.leader = Robot(device_name=self.leader_port) if not mock else MockRobot
            self.leader.set_trigger_torque(trigger_torque)

        # Initialize the cameras - sorted by camera names
        self.cameras = {}
        for cn, p in sorted(self.cameras_ports.items()):
            self.cameras[cn] = cv2.VideoCapture(p) if not mock else MockCamera
            if not self.cameras[cn].isOpened():
                raise OSError(
                    f"Cannot open camera port {p} for {cn}."
                    f" Make sure the camera is connected and the port is correct."
                    f" Also check you are not spinning several instances of the same environment (eval.batch_size)"
                )

        # Specify gym action and observation spaces
        observation_space = {}

        if self.num_joints > 0:
            observation_space["agent_pos"] = spaces.Box(
                low=-1000.0,
                high=1000.0,
                shape=(num_joints,),
                dtype=np.float64,
            )
            if self.record:
                observation_space["leader_pos"] = spaces.Box(
                    low=-1000.0,
                    high=1000.0,
                    shape=(num_joints,),
                    dtype=np.float64,
                )

        if self.cameras_shapes:
            for cn, hwc_shape in self.cameras_shapes.items():
                # Assumes images are unsigned int8 in [0,255]
                observation_space[cn] = spaces.Box(
                    low=0,
                    high=255,
                    # height x width x channels (e.g. 480 x 640 x 3)
                    shape=hwc_shape,
                    dtype=np.uint8,
                )

        self.observation_space = spaces.Dict(observation_space)
        self.action_space = spaces.Box(low=-1, high=1, shape=(num_joints,), dtype=np.float32)

        self._observation = {}
        self._terminated = False
        self.timestamps = []

    def _get_obs(self):
        qpos = self.follower.read_position()
        self._observation["agent_pos"] = pwm2pos(qpos)
        for cn, c in self.cameras.items():
            self._observation[cn] = capture_image(c, self.cameras_shapes[cn][1], self.cameras_shapes[cn][0])

        if self.record:
            action = self.leader.read_position()
            self._observation["leader_pos"] = pwm2pos(action)

    def reset(self, seed: int | None = None):
        # Reset the robot and sync the leader and follower if we are recording
        for _ in range(self.warmup_steps):
            self._get_obs()
            if self.record:
                self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
        self._terminated = False
        info = {}
        self.timestamps = []
        return self._observation, info

    def step(self, action: np.ndarray = None):
        if self.timestamps:
            # wait the right amount of time to stay at the desired fps
            time.sleep(max(0, 1 / self.fps - (time.time() - self.timestamps[-1])))

        self.timestamps.append(time.time())

        # Get the observation
        self._get_obs()
        if self.record:
            # Teleoperate the follower from the leader
            self.follower.set_goal_pos(pos2pwm(self._observation["leader_pos"]))
        else:
            # Apply the action to the follower
            self.follower.set_goal_pos(pos2pwm(action))

        reward = 0
        terminated = truncated = self._terminated
        info = {"timestamp": self.timestamps[-1] - self.timestamps[0], "fps_error": False}

        # Check if we are able to keep up with the desired fps
        if len(self.timestamps) > 1 and (self.timestamps[-1] - self.timestamps[-2]) > 1 / (
            self.fps - self.fps_tolerance
        ):
            print(
                f"Error: recording fps {1 / (self.timestamps[-1] - self.timestamps[-2]):.5f} is lower"
                f" than min admitted fps {(self.fps - self.fps_tolerance):.5f}"
                f" at frame {len(self.timestamps)}"
            )
            info["fps_error"] = True

        return self._observation, reward, terminated, truncated, info

    def render(self): ...

    def close(self):
        self.follower._disable_torque()
        if self.record:
            self.leader._disable_torque()
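A minimal smoke test of `RealEnv` (a sketch, assuming the `gym_real_world` package above is importable; `mock=True` swaps the servos and cameras for the `MagicMock` objects defined in the file, so no hardware is needed):

```python
from gym_real_world.gym_environment import RealEnv

env = RealEnv(record=False, mock=True)
obs, info = env.reset()
for _ in range(3):
    action = env.action_space.sample()  # normalized joint targets in [-1, 1]
    obs, reward, terminated, truncated, info = env.step(action)
    print(info["timestamp"], info["fps_error"])
env.close()
```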
168 examples/real_robot_example/gym_real_world/robot.py Normal file
@@ -0,0 +1,168 @@
# ruff: noqa
"""From Alexander Koch low_cost_robot codebase at https://github.com/AlexanderKoch-Koch/low_cost_robot
Class to control the robot using dynamixel servos.
"""

from enum import Enum, auto
from typing import Union

import numpy as np
from dynamixel_sdk import DXL_HIBYTE, DXL_HIWORD, DXL_LOBYTE, DXL_LOWORD, GroupSyncRead, GroupSyncWrite

from .dynamixel import Dynamixel, OperatingMode, ReadAttribute


class MotorControlType(Enum):
    PWM = auto()
    POSITION_CONTROL = auto()
    DISABLED = auto()
    UNKNOWN = auto()


class Robot:
    def __init__(self, device_name: str, baudrate=1_000_000, servo_ids=[1, 2, 3, 4, 5, 6]) -> None:
        self.servo_ids = servo_ids
        self.dynamixel = Dynamixel.Config(baudrate=baudrate, device_name=device_name).instantiate()
        self._init_motors()

    def _init_motors(self):
        self.position_reader = GroupSyncRead(
            self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.POSITION.value, 4
        )
        for id in self.servo_ids:
            self.position_reader.addParam(id)

        self.velocity_reader = GroupSyncRead(
            self.dynamixel.portHandler, self.dynamixel.packetHandler, ReadAttribute.VELOCITY.value, 4
        )
        for id in self.servo_ids:
            self.velocity_reader.addParam(id)

        self.pos_writer = GroupSyncWrite(
            self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_POSITION, 4
        )
        for id in self.servo_ids:
            self.pos_writer.addParam(id, [2048])

        self.pwm_writer = GroupSyncWrite(
            self.dynamixel.portHandler, self.dynamixel.packetHandler, self.dynamixel.ADDR_GOAL_PWM, 2
        )
        for id in self.servo_ids:
            self.pwm_writer.addParam(id, [2048])
        self._disable_torque()
        self.motor_control_state = MotorControlType.DISABLED

    def read_position(self, tries=2):
        """
        Reads the joint positions of the robot. 2048 is the center position. 0 and 4096 are 180 degrees in each direction.
        :param tries: maximum number of tries to read the position
        :return: list of joint positions in range [0, 4096]
        """
        result = self.position_reader.txRxPacket()
        if result != 0:
            if tries > 0:
                return self.read_position(tries=tries - 1)
            else:
                print("failed to read position!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        positions = []
        for id in self.servo_ids:
            position = self.position_reader.getData(id, ReadAttribute.POSITION.value, 4)
            if position > 2**31:
                position -= 2**32
            positions.append(position)
        return np.array(positions)

    def read_velocity(self):
        """
        Reads the joint velocities of the robot.
        :return: list of joint velocities
        """
        self.velocity_reader.txRxPacket()
        velocities = []
        for id in self.servo_ids:
            velocity = self.velocity_reader.getData(id, ReadAttribute.VELOCITY.value, 4)
            if velocity > 2**31:
                velocity -= 2**32
            velocities.append(velocity)
        return np.array(velocities)

    def set_goal_pos(self, action):
        """
        :param action: list or numpy array of target joint positions in range [0, 4096]
        """
        if self.motor_control_state is not MotorControlType.POSITION_CONTROL:
            self._set_position_control()
        for i, motor_id in enumerate(self.servo_ids):
            data_write = [
                DXL_LOBYTE(DXL_LOWORD(action[i])),
                DXL_HIBYTE(DXL_LOWORD(action[i])),
                DXL_LOBYTE(DXL_HIWORD(action[i])),
                DXL_HIBYTE(DXL_HIWORD(action[i])),
            ]
            self.pos_writer.changeParam(motor_id, data_write)

        self.pos_writer.txPacket()

    def set_pwm(self, action):
        """
        Sets the pwm values for the servos.
        :param action: list or numpy array of pwm values in range [0, 885]
        """
        if self.motor_control_state is not MotorControlType.PWM:
            self._set_pwm_control()
        for i, motor_id in enumerate(self.servo_ids):
            data_write = [
                DXL_LOBYTE(DXL_LOWORD(action[i])),
                DXL_HIBYTE(DXL_LOWORD(action[i])),
            ]
            self.pwm_writer.changeParam(motor_id, data_write)

        self.pwm_writer.txPacket()

    def set_trigger_torque(self, torque: int):
        """
        Sets a constant torque for the last servo in the chain. This is useful for the trigger of the leader arm.
        """
        self.dynamixel._enable_torque(self.servo_ids[-1])
        self.dynamixel.set_pwm_value(self.servo_ids[-1], torque)

    def limit_pwm(self, limit: Union[int, list, np.ndarray]):
        """
        Limits the pwm values for the servos used in position control
        @param limit: 0 ~ 885
        @return:
        """
        if isinstance(limit, int):
            limits = [
                limit,
            ] * 5
        else:
            limits = limit
        self._disable_torque()
        for motor_id, limit in zip(self.servo_ids, limits, strict=False):
            self.dynamixel.set_pwm_limit(motor_id, limit)
        self._enable_torque()

    def _disable_torque(self):
        print(f"disabling torque for servos {self.servo_ids}")
        for motor_id in self.servo_ids:
            self.dynamixel._disable_torque(motor_id)

    def _enable_torque(self):
        print(f"enabling torque for servos {self.servo_ids}")
        for motor_id in self.servo_ids:
            self.dynamixel._enable_torque(motor_id)

    def _set_pwm_control(self):
        self._disable_torque()
        for motor_id in self.servo_ids:
            self.dynamixel.set_operating_mode(motor_id, OperatingMode.PWM)
        self._enable_torque()
        self.motor_control_state = MotorControlType.PWM

    def _set_position_control(self):
        self._disable_torque()
        for motor_id in self.servo_ids:
            self.dynamixel.set_operating_mode(motor_id, OperatingMode.POSITION)
        self._enable_torque()
        self.motor_control_state = MotorControlType.POSITION_CONTROL
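A minimal leader/follower teleoperation sketch built on this class (the port names are assumptions matching the defaults used elsewhere in this example; `read_position` already returns raw values in the [0, 4096] range that `set_goal_pos` expects, so no conversion is needed):

```python
import time

from gym_real_world.robot import Robot

follower = Robot(device_name="/dev/ttyACM0")
leader = Robot(device_name="/dev/ttyACM1")
leader.set_trigger_torque(70)  # keep the leader trigger servo lightly loaded

while True:
    follower.set_goal_pos(leader.read_position())  # mirror the leader arm on the follower
    time.sleep(1 / 30)  # roughly 30 Hz
```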
237 examples/real_robot_example/record_training_data.py Normal file
@@ -0,0 +1,237 @@
"""This script demonstrates how to record a LeRobot dataset of training data
using a very simple gym environment (see examples/real_robot_example/gym_real_world/gym_environment.py).
"""

import argparse
import copy
import os
from pathlib import Path

import gym_real_world  # noqa: F401
import gymnasium as gym
import numpy as np
import torch
from datasets import Dataset, Features, Sequence, Value
from omegaconf import OmegaConf
from tqdm import tqdm

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, DATA_DIR, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, save_images_concurrently
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_meta_data_to_hub, push_videos_to_hub, save_meta_data

# parse the repo_id name via command line
parser = argparse.ArgumentParser()
parser.add_argument("--repo-id", type=str, default="thomwolf/blue_red_sort")
parser.add_argument("--num-episodes", type=int, default=2)
parser.add_argument("--num-frames", type=int, default=400)
parser.add_argument("--num-workers", type=int, default=16)
parser.add_argument("--keep-last", action="store_true")
parser.add_argument("--data_dir", type=str, default=None)
parser.add_argument("--push-to-hub", action="store_true")
parser.add_argument("--fps", type=int, default=30, help="Frames per second of the recording.")
parser.add_argument(
    "--fps_tolerance",
    type=float,
    default=0.5,
    help="Tolerance in fps for the recording before dropping episodes.",
)
parser.add_argument(
    "--revision", type=str, default=CODEBASE_VERSION, help="Codebase version used to generate the dataset."
)
parser.add_argument("--gym-config", type=str, default=None, help="Path to the gym config file.")
parser.add_argument("--mock-robot", action="store_true")
args = parser.parse_args()

repo_id = args.repo_id
num_episodes = args.num_episodes
num_frames = args.num_frames
revision = args.revision
fps = args.fps
fps_tolerance = args.fps_tolerance

out_data = DATA_DIR / repo_id if args.data_dir is None else Path(args.data_dir)

# During data collection, frames are stored as png images in `images_dir`
images_dir = out_data / "images"
# After data collection, png images of each episode are encoded into a mp4 file stored in `videos_dir`
videos_dir = out_data / "videos"
meta_data_dir = out_data / "meta_data"

gym_config = None
if args.gym_config is not None:
    gym_config = OmegaConf.load(args.gym_config)

# Create image and video directories
if not os.path.exists(images_dir):
    os.makedirs(images_dir, exist_ok=True)
if not os.path.exists(videos_dir):
    os.makedirs(videos_dir, exist_ok=True)

if __name__ == "__main__":
    # Create the gym environment - check the kwargs in gym_real_world/gym_environment.py
    gym_handle = "gym_real_world/RealEnv-v0"
    gym_kwargs = {}
    if gym_config is not None:
        gym_kwargs = OmegaConf.to_container(gym_config.gym_kwargs)
    env = gym.make(
        gym_handle,
        disable_env_checker=True,
        record=True,
        fps=fps,
        fps_tolerance=fps_tolerance,
        mock=args.mock_robot,
    )

    ep_dicts = []
    episode_data_index = {"from": [], "to": []}
    ep_fps = []
    id_from = 0
    id_to = 0
    os.system('spd-say "gym environment created"')

    ep_idx = 0
    while ep_idx < num_episodes:
        # bring the follower to the leader and start camera
        env.reset()

        os.system(f'spd-say "go {ep_idx}"')
        # init buffers
        obs_replay = {k: [] for k in env.observation_space}

        drop_episode = False
        timestamps = []
        for _ in tqdm(range(num_frames)):
            # Apply the next action
            observation, _, _, _, info = env.step(action=None)
            # images_stacked = np.hstack(list(observation['pixels'].values()))
            # images_stacked = cv2.cvtColor(images_stacked, cv2.COLOR_RGB2BGR)
            # cv2.imshow('frame', images_stacked)

            if info["fps_error"]:
                os.system(f'spd-say "Error fps too low, dropping episode {ep_idx}"')
                drop_episode = True
                break

            # store data
            for key in observation:
                obs_replay[key].append(copy.deepcopy(observation[key]))
            timestamps.append(info["timestamp"])

            # if cv2.waitKey(1) & 0xFF == ord('q'):
            #     break

        os.system('spd-say "stop"')

        if not drop_episode:
            os.system(f'spd-say "saving episode {ep_idx}"')
            ep_dict = {}
            # store images in png and create the video
            for img_key in env.cameras:
                save_images_concurrently(
                    obs_replay[img_key],
                    images_dir / f"{img_key}_episode_{ep_idx:06d}",
                    args.num_workers,
                )
                fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                # store the reference to the video frame
                ep_dict[f"observation.{img_key}"] = [
                    {"path": f"videos/{fname}", "timestamp": tstp} for tstp in timestamps
                ]

            state = torch.tensor(np.array(obs_replay["agent_pos"]))
            action = torch.tensor(np.array(obs_replay["leader_pos"]))
            next_done = torch.zeros(num_frames, dtype=torch.bool)
            next_done[-1] = True

            ep_dict["observation.state"] = state
            ep_dict["action"] = action
            ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64)
            ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
            ep_dict["timestamp"] = torch.tensor(timestamps)
            ep_dict["next.done"] = next_done
            ep_fps.append(num_frames / timestamps[-1])
            ep_dicts.append(ep_dict)
            print(f"Episode {ep_idx} done, fps: {ep_fps[-1]:.2f}")

            episode_data_index["from"].append(id_from)
            episode_data_index["to"].append(
                id_from + num_frames if args.keep_last else id_from + num_frames - 1
            )

            id_to = id_from + num_frames if args.keep_last else id_from + num_frames - 1
            id_from = id_to

            ep_idx += 1

    env.close()

    os.system('spd-say "encode video frames"')
    for ep_idx in range(num_episodes):
        for img_key in env.cameras:
            # If necessary, we may want to encode the video
            # with variable frame rate: https://superuser.com/questions/1661901/encoding-video-from-vfr-still-images
            encode_video_frames(
                images_dir / f"{img_key}_episode_{ep_idx:06d}",
                videos_dir / f"{img_key}_episode_{ep_idx:06d}.mp4",
                ep_fps[ep_idx],
            )

    os.system('spd-say "concatenate episodes"')
    data_dict = concatenate_episodes(
        ep_dicts, drop_episodes_last_frame=not args.keep_last
    )  # Since our fps varies we are sometimes off tolerance for the last frame

    features = {}

    keys = [key for key in data_dict if "observation.images." in key]
    for key in keys:
        features[key] = VideoFrame()

    features["observation.state"] = Sequence(
        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["action"] = Sequence(
        length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
    )
    features["episode_index"] = Value(dtype="int64", id=None)
    features["frame_index"] = Value(dtype="int64", id=None)
    features["timestamp"] = Value(dtype="float32", id=None)
    features["next.done"] = Value(dtype="bool", id=None)
    features["index"] = Value(dtype="int64", id=None)

    hf_dataset = Dataset.from_dict(data_dict, features=Features(features))
    hf_dataset.set_transform(hf_transform_to_torch)

    info = {
        "fps": sum(ep_fps) / len(ep_fps),  # to have a good tolerance in data processing for the slowest video
        "video": 1,
    }

    os.system('spd-say "from preloaded"')
    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        version=revision,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    os.system('spd-say "compute stats"')
    stats = compute_stats(lerobot_dataset)

    os.system('spd-say "save to disk"')
    hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
    hf_dataset.save_to_disk(str(out_data / "train"))

    save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if args.push_to_hub:
        hf_dataset.push_to_hub(repo_id, token=True, revision="main")
        hf_dataset.push_to_hub(repo_id, token=True, revision=revision)

        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision=revision)

        push_videos_to_hub(repo_id, videos_dir, revision="main")
        push_videos_to_hub(repo_id, videos_dir, revision=revision)
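Once a recording run has finished, the resulting folder can be loaded back like any other local `LeRobotDataset` (a sketch; the folder and repo id below follow the defaults assumed in the script and must match your run):

```python
import os

# DATA_DIR is read when the lerobot dataset module is imported, so set it first.
os.environ["DATA_DIR"] = "data"  # hypothetical folder containing <repo_id>/train, meta_data, videos

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset = LeRobotDataset("thomwolf/blue_red_sort")
print(len(dataset), dataset[0].keys())
```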
60 examples/real_robot_example/run_policy.py Normal file
@@ -0,0 +1,60 @@
import argparse
import logging
from pathlib import Path

import gym_real_world  # noqa: F401
import gymnasium as gym  # noqa: F401
from huggingface_hub import snapshot_download
from huggingface_hub.utils._errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError

from lerobot.common.utils.utils import init_logging
from lerobot.scripts.eval import eval

if __name__ == "__main__":
    init_logging()

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
    parser.add_argument(
        "overrides",
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides).",
    )
    args = parser.parse_args()

    try:
        pretrained_policy_path = Path(
            snapshot_download(args.pretrained_policy_name_or_path, revision=args.revision)
        )
    except (HFValidationError, RepositoryNotFoundError) as e:
        if isinstance(e, HFValidationError):
            error_message = (
                "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
            )
        else:
            error_message = (
                "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
            )

        logging.warning(f"{error_message} Treating it as a local directory.")
        pretrained_policy_path = Path(args.pretrained_policy_name_or_path)
        if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
            raise ValueError(
                "The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
                "repo ID, nor is it an existing local directory."
            )

    eval(pretrained_policy_path=pretrained_policy_path, config_overrides=args.overrides)
19 examples/real_robot_example/train_config/env/gym_real_world.yaml vendored Normal file
@@ -0,0 +1,19 @@
# @package _global_

fps: 30

env:
  name: real_world
  task: RealEnv-v0
  state_dim: 6
  action_dim: 6
  fps: ${fps}
  episode_length: 200
  real_world: true
  gym:
    cameras_shapes:
      images.high: [480, 640, 3]
      images.low: [480, 640, 3]
    cameras_ports:
      images.high: /dev/video6
      images.low: /dev/video0
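For reference, the `env.gym` block above ends up as keyword arguments for the environment; a rough Python equivalent (a sketch, assuming the kwargs are forwarded to `gym.make` the way lerobot does for its other environments):

```python
import gymnasium as gym

import gym_real_world  # noqa: F401

env = gym.make(
    "gym_real_world/RealEnv-v0",
    disable_env_checker=True,
    cameras_shapes={"images.high": (480, 640, 3), "images.low": (480, 640, 3)},
    cameras_ports={"images.high": "/dev/video6", "images.low": "/dev/video0"},
    fps=30,
)
```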
19 examples/real_robot_example/train_config/env/gym_real_world_debug.yaml vendored Normal file
@@ -0,0 +1,19 @@
# @package _global_

fps: 30

env:
  name: real_world
  task: RealEnv-v0
  state_dim: 6
  action_dim: 6
  fps: ${fps}
  episode_length: 200
  real_world: true
  gym:
    cameras_shapes:
      images.top: [480, 640, 3]
      images.front: [480, 640, 3]
    cameras_ports:
      images.top: /dev/video6
      images.front: /dev/video0
@@ -0,0 +1,103 @@
# @package _global_

# Use this config to train ACT on real-world data recorded with the gym_real_world environment.
# Compared to `act.yaml`, it contains 2 cameras (i.e. images.high, images.low) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1: this parameter controls how often checkpoints are evaluated
# during training, and setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_real \
#   env=gym_real_world
# ```

seed: 1000
dataset_repo_id: ???

override_dataset_stats:
  observation.images.high:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.low:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 1000
  online_steps: 0
  eval_freq: -1
  save_freq: 1000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 1
  batch_size: 1

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.high: [3, 480, 640]
    observation.images.low: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.high: mean_std
    observation.images.low: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
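The `delta_timestamps.action` expression above is evaluated after `${fps}` and `${policy.chunk_size}` are interpolated; with the values in this config it expands to the following list of offsets (in seconds):

```python
fps = 30
chunk_size = 100
action_deltas = [i / fps for i in range(1, chunk_size + 1)]
print(action_deltas[:3], action_deltas[-1])  # [0.0333..., 0.0666..., 0.1] ... 3.333...
```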
@@ -0,0 +1,103 @@
# @package _global_

# Use this config to train ACT on real-world data recorded with the gym_real_world_debug environment.
# Compared to `act.yaml`, it contains 2 cameras (i.e. images.top, images.front) instead of 1 camera (i.e. top).
# Also, `training.eval_freq` is set to -1: this parameter controls how often checkpoints are evaluated
# during training, and setting it to -1 deactivates evaluation.
# This is because real-world evaluation is done through [dora-lerobot](https://github.com/dora-rs/dora-lerobot).
# Look at its README for more information on how to evaluate a checkpoint in the real-world.
#
# Example of usage for training:
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_real \
#   env=gym_real_world_debug
# ```

seed: 1000
dataset_repo_id: ???

override_dataset_stats:
  observation.images.top:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)
  observation.images.front:
    # stats from imagenet, since we use a pretrained vision model
    mean: [[[0.485]], [[0.456]], [[0.406]]]  # (c,1,1)
    std: [[[0.229]], [[0.224]], [[0.225]]]  # (c,1,1)

training:
  offline_steps: 1000
  online_steps: 0
  eval_freq: -1
  save_freq: 1000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    action: "[i / ${fps} for i in range(1, ${policy.chunk_size} + 1)]"

eval:
  n_episodes: 1
  batch_size: 1

# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  chunk_size: 100 # chunk_size
  n_action_steps: 100

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.top: [3, 480, 640]
    observation.images.front: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.top: mean_std
    observation.images.front: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
@@ -56,7 +56,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
    )

    # A soft check to warn if the environment matches the dataset. Don't check if we are using a real world env (dora).
    if cfg.env.name != "dora":
    if not cfg.env.real_world:
        if isinstance(cfg.dataset_repo_id, str):
            dataset_repo_ids = [cfg.dataset_repo_id]  # single dataset
        else:
@@ -43,9 +43,6 @@ def get_cameras(hdf5_data):


def check_format(raw_dir) -> bool:
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_paths = list(raw_dir.glob("episode_*.hdf5"))
    assert len(hdf5_paths) != 0
    for hdf5_path in hdf5_paths:
@@ -62,17 +59,15 @@ def check_format(raw_dir) -> bool:
        for camera in get_cameras(data):
            assert num_frames == data[f"/observations/images/{camera}"].shape[0]

            if compressed_images:
                assert data[f"/observations/images/{camera}"].ndim == 2
            else:
                assert data[f"/observations/images/{camera}"].ndim == 4
            # ndim 2 when image are compressed and 4 when uncompressed
            assert data[f"/observations/images/{camera}"].ndim in [2, 4]
            if data[f"/observations/images/{camera}"].ndim == 4:
                b, h, w, c = data[f"/observations/images/{camera}"].shape
                assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."


def load_from_raw(raw_dir, out_dir, fps, video, debug):
    # only frames from simulation are uncompressed
    compressed_images = "sim" not in raw_dir.name

    hdf5_files = list(raw_dir.glob("*.hdf5"))
    ep_dicts = []
@@ -99,7 +94,7 @@ def load_from_raw(raw_dir, out_dir, fps, video, debug):
        for camera in get_cameras(ep):
            img_key = f"observation.images.{camera}"

            if compressed_images:
            if ep[f"/observations/images/{camera}"].ndim == 2:
                import cv2

                # load one compressed image after the other in RAM and uncompress
@@ -21,19 +21,24 @@ import PIL
import torch


def concatenate_episodes(ep_dicts):
def concatenate_episodes(ep_dicts, drop_episodes_last_frame=False):
    data_dict = {}

    keys = ep_dicts[0].keys()
    for key in keys:
        if torch.is_tensor(ep_dicts[0][key][0]):
            data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
            if drop_episodes_last_frame:
                data_dict[key] = torch.cat([ep_dict[key][:-1] for ep_dict in ep_dicts])
            else:
                data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
        else:
            if key not in data_dict:
                data_dict[key] = []
            for ep_dict in ep_dicts:
                for x in ep_dict[key]:
                    data_dict[key].append(x)
                if drop_episodes_last_frame:
                    data_dict[key].pop()

    total_frames = data_dict["frame_index"].shape[0]
    data_dict["index"] = torch.arange(0, total_frames, 1)
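A tiny illustration of the `drop_episodes_last_frame` flag added above, using two hypothetical episodes: the last frame of every episode is dropped before concatenation.

```python
import torch

ep_dicts = [
    {"frame_index": torch.tensor([0, 1, 2])},
    {"frame_index": torch.tensor([0, 1])},
]
kept = torch.cat([ep["frame_index"] for ep in ep_dicts])          # tensor([0, 1, 2, 0, 1])
dropped = torch.cat([ep["frame_index"][:-1] for ep in ep_dicts])  # tensor([0, 1, 0])
print(kept, dropped)
```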
@@ -29,10 +29,12 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
    # map to expected inputs for the policy
    return_observations = {}

    if isinstance(observations["pixels"], dict):
    if "pixels" in observations and isinstance(observations["pixels"], dict):
        imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
    else:
    elif "pixels" in observations and isinstance(observations["pixels"], np.ndarray):
        imgs = {"observation.image": observations["pixels"]}
    else:
        imgs = {f"observation.{key}": img for key, img in observations.items() if "images" in key}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
@@ -129,7 +129,9 @@ class ACTConfig:
    # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
    # that means only the first layer is used. Here we match the original implementation by setting this to 1.
    # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
    # As a consequence we also remove the final, unused layer normalization, by default
    n_decoder_layers: int = 1
    decoder_norm: bool = False
    # VAE.
    use_vae: bool = True
    latent_dim: int = 32
@@ -315,8 +315,14 @@ class ACT(nn.Module):
            pos_embed = self.vae_encoder_pos_enc.clone().detach()  # (1, S+2, D)

            # Forward pass through VAE encoder to get the latent PDF parameters.
            cls_joint_is_pad = torch.full((batch_size, 2), False).to(
                batch["observation.state"].device
            )  # False: not a padding
            key_padding_mask = torch.cat([cls_joint_is_pad, batch["action_is_pad"]], axis=1)  # (bs, seq+1)
            cls_token_out = self.vae_encoder(
                vae_encoder_input.permute(1, 0, 2), pos_embed=pos_embed.permute(1, 0, 2)
                vae_encoder_input.permute(1, 0, 2),
                pos_embed=pos_embed.permute(1, 0, 2),
                key_padding_mask=key_padding_mask,
            )[0]  # select the class token, with shape (B, D)
            latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out)
            mu = latent_pdf_params[:, : self.config.latent_dim]
@@ -402,9 +408,11 @@ class ACTEncoder(nn.Module):
        self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(config.n_encoder_layers)])
        self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()

    def forward(self, x: Tensor, pos_embed: Tensor | None = None) -> Tensor:
    def forward(
        self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
    ) -> Tensor:
        for layer in self.layers:
            x = layer(x, pos_embed=pos_embed)
            x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
        x = self.norm(x)
        return x

@@ -427,12 +435,14 @@ class ACTEncoderLayer(nn.Module):
        self.activation = get_activation_fn(config.feedforward_activation)
        self.pre_norm = config.pre_norm

    def forward(self, x, pos_embed: Tensor | None = None) -> Tensor:
    def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor:
        skip = x
        if self.pre_norm:
            x = self.norm1(x)
        q = k = x if pos_embed is None else x + pos_embed
        x = self.self_attn(q, k, value=x)[0]  # select just the output, not the attention weights
        x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask)[
            0
        ]  # select just the output, not the attention weights
        x = skip + self.dropout1(x)
        if self.pre_norm:
            skip = x
@@ -452,7 +462,10 @@ class ACTDecoder(nn.Module):
        """Convenience module for running multiple decoder layers followed by normalization."""
        super().__init__()
        self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)])
        self.norm = nn.LayerNorm(config.dim_model)
        if config.decoder_norm:
            self.norm = nn.LayerNorm(config.dim_model)
        else:
            self.norm = nn.Identity()

    def forward(
        self,
@@ -465,8 +478,7 @@ class ACTDecoder(nn.Module):
            x = layer(
                x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
            )
        if self.norm is not None:
            x = self.norm(x)
        x = self.norm(x)
        return x

@@ -50,6 +50,8 @@ eval:
  batch_size: 1
  # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
  use_async_envs: false
  # Specify the number of episodes to render during evaluation.
  max_episodes_rendered: 10

wandb:
  enable: false
1 lerobot/configs/env/aloha.yaml vendored
@@ -9,6 +9,7 @@ env:
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array

1 lerobot/configs/env/dora_aloha_real.yaml vendored
@@ -9,5 +9,6 @@ env:
  action_dim: 14
  fps: ${fps}
  episode_length: 400
  real_world: true
  gym:
    fps: ${fps}

1 lerobot/configs/env/pusht.yaml vendored
@@ -10,6 +10,7 @@ env:
  action_dim: 2
  fps: ${fps}
  episode_length: 300
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array

1 lerobot/configs/env/xarm.yaml vendored
@@ -10,6 +10,7 @@ env:
  action_dim: 4
  fps: ${fps}
  episode_length: 25
  real_world: false
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array
@@ -44,6 +44,7 @@ https://huggingface.co/lerobot/diffusion_pusht/tree/main.
import argparse
import json
import logging
import os
import threading
import time
from contextlib import nullcontext
@@ -164,7 +165,10 @@ def rollout(
        # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
        # available of none of the envs finished.
        if "final_info" in info:
            successes = [info["is_success"] if info is not None else False for info in info["final_info"]]
            successes = [
                i["is_success"] if (i is not None and "is_success" in i) else False
                for i in info["final_info"]
            ]
        else:
            successes = [False] * env.num_envs
@@ -516,6 +520,7 @@ def eval(
    out_dir = (
        f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
    )
    os.makedirs(out_dir, exist_ok=True)

    if out_dir is None:
        raise NotImplementedError()
@@ -545,7 +550,7 @@ def eval(
        env,
        policy,
        hydra_cfg.eval.n_episodes,
        max_episodes_rendered=10,
        max_episodes_rendered=hydra_cfg.eval.max_episodes_rendered,
        video_dir=Path(out_dir) / "eval",
        start_seed=hydra_cfg.seed,
        enable_progbar=True,
@@ -406,7 +406,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

        step += 1

    eval_env.close()
    if cfg.training.eval_freq > 0:
        eval_env.close()
    logging.info("End of training")

@@ -29,8 +29,8 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
    return text


def _run_script(path):
    subprocess.run([sys.executable, path], check=True)
def _run_script(path, args=None):
    subprocess.run([sys.executable, path] + (args if args is not None else []), check=True)


def _read_file(path):
@@ -126,3 +126,22 @@ def test_examples_basic2_basic3_advanced1():
    # Restore stdout to its original state
    sys.stdout = sys.__stdout__
    assert "Average loss on validation set" in printed_output


def test_real_world_recording():
    path = "examples/real_robot_example/record_training_data.py"
    _run_script(
        path,
        [
            "--data_dir",
            "outputs/examples",
            "--repo-id",
            "real_world_debug",
            "--num-episodes",
            "2",
            "--num-frames",
            "10",
            "--mock-robot",
        ],
    )
    assert Path("outputs/examples/real_world_debug/video/episode_0.mp4").exists()