Compare commits
204 Commits
user/miche
...
test/add_c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79498ab967 | ||
|
|
cb10f97ccc | ||
|
|
2bb73ac431 | ||
|
|
9afc4b771c | ||
|
|
f71e224023 | ||
|
|
889de7c415 | ||
|
|
3539251b18 | ||
|
|
1f210bc8a3 | ||
|
|
d70bc4bde9 | ||
|
|
bdbca09cb2 | ||
|
|
e0b292ab51 | ||
|
|
f960f4d8d4 | ||
|
|
9e57ec7837 | ||
|
|
0a7f51f0da | ||
|
|
4ca92a28e9 | ||
|
|
0464dc91b3 | ||
|
|
d32daebf75 | ||
|
|
27cb0c40bd | ||
|
|
12abc9ca86 | ||
|
|
4005065223 | ||
|
|
443fed216c | ||
|
|
42a87e7211 | ||
|
|
034171a89a | ||
|
|
782dff1163 | ||
|
|
8924ccbbab | ||
|
|
792c3d961d | ||
|
|
e998dddcfa | ||
|
|
99c0938b42 | ||
|
|
716029b1e3 | ||
|
|
3848a8f9aa | ||
|
|
f7672e14c7 | ||
|
|
e393af2d88 | ||
|
|
0dcb2caba8 | ||
|
|
4679725957 | ||
|
|
57319062aa | ||
|
|
078f59bfd1 | ||
|
|
36fcea2002 | ||
|
|
2971bdfed5 | ||
|
|
28cd3a6f3a | ||
|
|
c0570b3003 | ||
|
|
eeb8490016 | ||
|
|
854b78975a | ||
|
|
e55d2ffe50 | ||
|
|
1ebd81552c | ||
|
|
65569ba90e | ||
|
|
79293800f1 | ||
|
|
bc765f9e95 | ||
|
|
201311503f | ||
|
|
8cc0232e73 | ||
|
|
6bfcc18e73 | ||
|
|
e096754d14 | ||
|
|
02803f545d | ||
|
|
8503e8e166 | ||
|
|
d6007c6e7d | ||
|
|
50963fcf13 | ||
|
|
051a52a4ce | ||
|
|
2292b514aa | ||
|
|
1f1a01a798 | ||
|
|
faa476f0d2 | ||
|
|
5130b69ece | ||
|
|
aed85241b7 | ||
|
|
21c3ac42ee | ||
|
|
2d3a5fb2be | ||
|
|
a631e4c11c | ||
|
|
222d6f104e | ||
|
|
7a3b424cd3 | ||
|
|
af295fadb5 | ||
|
|
9644e2b086 | ||
|
|
6ccf083127 | ||
|
|
bb774e7acd | ||
|
|
dcbbeab80b | ||
|
|
b71ac34214 | ||
|
|
c237d1379e | ||
|
|
cf963eb1b0 | ||
|
|
4293b6a4fb | ||
|
|
7a75bb9f61 | ||
|
|
0c1d4cb323 | ||
|
|
c6212d585d | ||
|
|
7c8ab8e2d6 | ||
|
|
1de75c46c0 | ||
|
|
4ad109cff8 | ||
|
|
8994252019 | ||
|
|
9832daf08d | ||
|
|
39d8f45810 | ||
|
|
30fcd3d417 | ||
|
|
039b437ef0 | ||
|
|
7582a0a2b0 | ||
|
|
25388d0947 | ||
|
|
7152bc8aa7 | ||
|
|
5b46dc0b6a | ||
|
|
4273f1f384 | ||
|
|
97194bf7f3 | ||
|
|
0ac026b521 | ||
|
|
ff7cfdaf40 | ||
|
|
57c97762e1 | ||
|
|
a38bb15e79 | ||
|
|
3ceaee999d | ||
|
|
d485dc1313 | ||
|
|
329d103453 | ||
|
|
9f46a3d8f9 | ||
|
|
c9ca9e4316 | ||
|
|
5a57e6f4a7 | ||
|
|
a2f5c34625 | ||
|
|
1f1e1bcfe8 | ||
|
|
e047074825 | ||
|
|
c2e761437d | ||
|
|
fedac994c3 | ||
|
|
7d558d058e | ||
|
|
1d3e1cbdbd | ||
|
|
0ccc957d5c | ||
|
|
a4d487bc1d | ||
|
|
8ca03a7255 | ||
|
|
f2ed2bfb2f | ||
|
|
40675ec76c | ||
|
|
9e34c1d731 | ||
|
|
857f335be9 | ||
|
|
fc4a95f187 | ||
|
|
4fe1880887 | ||
|
|
6fa859fa19 | ||
|
|
2abfa5838d | ||
|
|
3d119c0ccb | ||
|
|
a32081757d | ||
|
|
56c04ffc53 | ||
|
|
715d4557af | ||
|
|
6541982dff | ||
|
|
43bc9404bb | ||
|
|
375499c323 | ||
|
|
17a4447cef | ||
|
|
287dc13d96 | ||
|
|
02a1cf6a4e | ||
|
|
34cd1e47bf | ||
|
|
74d56834af | ||
|
|
dd80dbb4cd | ||
|
|
bc020ee0a4 | ||
|
|
a15767aff1 | ||
|
|
9af0a9bf37 | ||
|
|
e2c8bc6948 | ||
|
|
2c68c6ca40 | ||
|
|
dd1f33e5ed | ||
|
|
2c1bb766ff | ||
|
|
c1c71fb994 | ||
|
|
2d56f35071 | ||
|
|
64ce2669ca | ||
|
|
f527adf7a9 | ||
|
|
6a77189f50 | ||
|
|
e4a6d035f9 | ||
|
|
794f6e00fc | ||
|
|
97494c6a39 | ||
|
|
9358d334c7 | ||
|
|
c85a9253e7 | ||
|
|
8d659a6aa9 | ||
|
|
f6a2396484 | ||
|
|
7a7af82e35 | ||
|
|
7f23972f3f | ||
|
|
3362b665e6 | ||
|
|
eeeccdba53 | ||
|
|
bd5b181dfd | ||
|
|
858678786a | ||
|
|
0f972661e1 | ||
|
|
2e9b144c56 | ||
|
|
fa8ba9e4e2 | ||
|
|
2037cc0219 | ||
|
|
5006da72ff | ||
|
|
ad0bacbfe4 | ||
|
|
e33ca2c980 | ||
|
|
f0505e81cc | ||
|
|
1f7ddc1d76 | ||
|
|
ce63cfdb25 | ||
|
|
d6f1359e69 | ||
|
|
2357d4aceb | ||
|
|
d6ccdc222c | ||
|
|
9bd0788131 | ||
|
|
1ae62c28f7 | ||
|
|
baf6e66c3d | ||
|
|
a065bd61ae | ||
|
|
5dc3c74e64 | ||
|
|
4214b01703 | ||
|
|
b974e5541f | ||
|
|
fd64dc84ae | ||
|
|
06988b2135 | ||
|
|
7ed7570b17 | ||
|
|
e2d13ba7e4 | ||
|
|
f6c1049474 | ||
|
|
2b24feb604 | ||
|
|
a13e49073c | ||
|
|
2c7e0f17b6 | ||
|
|
418866007e | ||
|
|
5ab418dbeb | ||
|
|
95f61ee9d4 | ||
|
|
ac89c8d226 | ||
|
|
d75d904e43 | ||
|
|
ea4d8d990c | ||
|
|
c93cbb8311 | ||
|
|
c0137e89b9 | ||
|
|
3111ba78ad | ||
|
|
3d3a176940 | ||
|
|
212c6095a2 | ||
|
|
48469ec674 | ||
|
|
c7dfd32b43 | ||
|
|
731fb6ebaf | ||
|
|
13e124302f | ||
|
|
59bdd29106 | ||
|
|
124829104b | ||
|
|
21cd2940a9 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
.dev
|
||||
# Logging
|
||||
logs
|
||||
tmp
|
||||
@@ -26,7 +26,6 @@ outputs
|
||||
|
||||
# VS Code
|
||||
.vscode
|
||||
.devcontainer
|
||||
|
||||
# HPC
|
||||
nautilus/*.yaml
|
||||
|
||||
@@ -46,9 +46,9 @@ repos:
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.5
|
||||
rev: v0.11.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
@@ -57,7 +57,7 @@ repos:
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.3
|
||||
rev: v8.24.2
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
|
||||
@@ -103,20 +103,13 @@ When using `miniconda`, install `ffmpeg` in your environment:
|
||||
conda install ffmpeg -c conda-forge
|
||||
```
|
||||
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can:
|
||||
> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using:
|
||||
> ```bash
|
||||
> conda install ffmpeg=7.1.1 -c conda-forge
|
||||
> ```
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
> **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run:
|
||||
`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras:
|
||||
- [aloha](https://github.com/huggingface/gym-aloha)
|
||||
|
||||
@@ -32,11 +32,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import PIL
|
||||
import torch
|
||||
from skimage.metrics import (
|
||||
mean_squared_error,
|
||||
peak_signal_noise_ratio,
|
||||
structural_similarity,
|
||||
)
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
@@ -98,11 +94,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
|
||||
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path,
|
||||
save_dir: Path,
|
||||
frames: torch.Tensor,
|
||||
timestamps: list[float],
|
||||
fps: int,
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
@@ -112,10 +104,7 @@ def save_decoded_frames(
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(
|
||||
imgs_dir / f"frame_{idx:06d}.png",
|
||||
save_dir / f"frame_{idx:06d}_original.png",
|
||||
)
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
@@ -131,11 +120,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
imgs_dataset = hf_dataset.select_columns(img_keys[0])
|
||||
|
||||
for i, item in enumerate(
|
||||
tqdm(
|
||||
imgs_dataset,
|
||||
desc=f"saving {dataset.repo_id} first episode images",
|
||||
leave=False,
|
||||
)
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
@@ -290,9 +275,7 @@ def benchmark_encoding_decoding(
|
||||
random.seed(seed)
|
||||
benchmark_table = []
|
||||
for timestamps_mode in tqdm(
|
||||
decoding_cfg["timestamps_modes"],
|
||||
desc="decodings (timestamps_modes)",
|
||||
leave=False,
|
||||
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
|
||||
):
|
||||
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
|
||||
benchmark_row = benchmark_decoding(
|
||||
|
||||
@@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher portaudio19-dev libgeos-dev \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install ffmpeg build dependencies. See:
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
|
||||
|
||||
#### Generate protobuf files
|
||||
|
||||
```bash
|
||||
python -m grpc_tools.protoc \
|
||||
-I lerobot/scripts/server \
|
||||
--python_out=lerobot/scripts/server \
|
||||
--grpc_python_out=lerobot/scripts/server \
|
||||
lerobot/scripts/server/hilserl.proto
|
||||
```
|
||||
@@ -32,10 +32,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
|
||||
@@ -22,10 +22,7 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import dataset_to_policy_features
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
@@ -80,24 +77,7 @@ def main():
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to supervise the policy.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# We can then instantiate the dataset with these delta_timestamps configuration.
|
||||
|
||||
@@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly
|
||||
|
||||
## The training script
|
||||
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following:
|
||||
LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following:
|
||||
|
||||
- Initialize/load a configuration for the following steps using.
|
||||
- Instantiates a dataset.
|
||||
@@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig
|
||||
def train(cfg: TrainPipelineConfig):
|
||||
```
|
||||
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option)
|
||||
|
||||
When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.)
|
||||
|
||||
@@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a
|
||||
|
||||
## Specifying values from the CLI
|
||||
|
||||
Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--dataset.repo_id=lerobot/pusht \
|
||||
@@ -60,10 +60,10 @@ python lerobot/scripts/train.py \
|
||||
|
||||
Let's break this down:
|
||||
- To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`.
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py)
|
||||
- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies)
|
||||
- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py)
|
||||
|
||||
Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
@@ -74,7 +74,7 @@ python lerobot/scripts/train.py \
|
||||
> Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`.
|
||||
|
||||
We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task.
|
||||
Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
|
||||
@@ -83,7 +83,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
--id 1
|
||||
```
|
||||
|
||||
Then unplug your first motor and plug the second motor and set its ID to 2.
|
||||
@@ -93,7 +93,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand dynamixel \
|
||||
--model xl330-m288 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
--id 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6.
|
||||
@@ -830,6 +830,11 @@ It contains:
|
||||
- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously.
|
||||
|
||||
Troubleshooting:
|
||||
- On Linux, if you encounter any issue during video encoding with `ffmpeg: unknown encoder libsvtav1`, you can:
|
||||
- install with conda-forge by running `conda install -c conda-forge ffmpeg` (it should be compiled with `libsvtav1`),
|
||||
> **NOTE:** This usually installs `ffmpeg 7.X` for your platform (check the version installed with `ffmpeg -encoders | grep libsvtav1`). If it isn't `ffmpeg 7.X` or lacks `libsvtav1` support, you can explicitly install `ffmpeg 7.X` using: `conda install ffmpeg=7.1.1 -c conda-forge`
|
||||
- or, install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1),
|
||||
- and, make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux).
|
||||
|
||||
At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running:
|
||||
|
||||
@@ -26,10 +26,7 @@ import math
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
LeRobotDatasetMetadata,
|
||||
)
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
|
||||
@@ -54,24 +51,7 @@ def main():
|
||||
# Load the previous action (-0.1), the next action to be executed (0.0),
|
||||
# and 14 future actions with a 0.1 seconds spacing. All these actions will be
|
||||
# used to calculate the loss.
|
||||
"action": [
|
||||
-0.1,
|
||||
0.0,
|
||||
0.1,
|
||||
0.2,
|
||||
0.3,
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
0.7,
|
||||
0.8,
|
||||
0.9,
|
||||
1.0,
|
||||
1.1,
|
||||
1.2,
|
||||
1.3,
|
||||
1.4,
|
||||
],
|
||||
"action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
|
||||
}
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
|
||||
4
lerobot/common/cameras/__init__.py
Normal file
4
lerobot/common/cameras/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig
|
||||
|
||||
__all__ = ["Camera", "CameraConfig"]
|
||||
25
lerobot/common/cameras/camera.py
Normal file
25
lerobot/common/cameras/camera.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import abc
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Camera(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def connect(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_read(self) -> np.ndarray:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
11
lerobot/common/cameras/configs.py
Normal file
11
lerobot/common/cameras/configs.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
4
lerobot/common/cameras/intel/__init__.py
Normal file
4
lerobot/common/cameras/intel/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera_realsense import RealSenseCamera
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
__all__ = ["RealSenseCamera", "RealSenseCameraConfig"]
|
||||
@@ -31,14 +31,15 @@ from threading import Thread
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.robot_utils import (
|
||||
busy_wait,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_realsense import RealSenseCameraConfig
|
||||
|
||||
SERIAL_NUMBER_INDEX = 1
|
||||
|
||||
|
||||
@@ -108,13 +109,11 @@ def save_images_from_cameras(
|
||||
cameras = []
|
||||
for cam_sn in serial_numbers:
|
||||
print(f"{cam_sn=}")
|
||||
config = IntelRealSenseCameraConfig(
|
||||
serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock
|
||||
)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
config = RealSenseCameraConfig(serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
f"RealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -166,11 +165,11 @@ def save_images_from_cameras(
|
||||
camera.disconnect()
|
||||
|
||||
|
||||
class IntelRealSenseCamera:
|
||||
class RealSenseCamera(Camera):
|
||||
"""
|
||||
The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
The RealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras:
|
||||
- is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux,
|
||||
- can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(),
|
||||
- can also be instantiated with the camera's name — if it's unique — using RealSenseCamera.init_from_name(),
|
||||
- depth map can be returned.
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
|
||||
@@ -178,15 +177,15 @@ class IntelRealSenseCamera:
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
|
||||
```
|
||||
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
When an RealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
of the given camera will be used.
|
||||
|
||||
Example of instantiating with a serial number:
|
||||
```python
|
||||
from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig
|
||||
from lerobot.common.robot_devices.cameras.configs import RealSenseCameraConfig
|
||||
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image = camera.read()
|
||||
# when done using the camera, consider disconnecting
|
||||
@@ -195,21 +194,21 @@ class IntelRealSenseCamera:
|
||||
|
||||
Example of instantiating with a name if it's unique:
|
||||
```
|
||||
config = IntelRealSenseCameraConfig(name="Intel RealSense D405")
|
||||
config = RealSenseCameraConfig(name="Intel RealSense D405")
|
||||
```
|
||||
|
||||
Example of changing default fps, width, height and color_mode:
|
||||
```python
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr")
|
||||
# Note: might error out upon `camera.connect()` if these settings are not compatible with the camera
|
||||
```
|
||||
|
||||
Example of returning depth:
|
||||
```python
|
||||
config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
||||
camera = IntelRealSenseCamera(config)
|
||||
config = RealSenseCameraConfig(serial_number=128422271347, use_depth=True)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.connect()
|
||||
color_image, depth_map = camera.read()
|
||||
```
|
||||
@@ -217,7 +216,7 @@ class IntelRealSenseCamera:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: IntelRealSenseCameraConfig,
|
||||
config: RealSenseCameraConfig,
|
||||
):
|
||||
self.config = config
|
||||
if config.name is not None:
|
||||
@@ -282,9 +281,7 @@ class IntelRealSenseCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is already connected."
|
||||
)
|
||||
raise DeviceAlreadyConnectedError(f"RealSenseCamera({self.serial_number}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
@@ -297,11 +294,7 @@ class IntelRealSenseCamera:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(
|
||||
rs.stream.color,
|
||||
self.capture_width,
|
||||
self.capture_height,
|
||||
rs.format.rgb8,
|
||||
self.fps,
|
||||
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
@@ -309,11 +302,7 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth,
|
||||
self.capture_width,
|
||||
self.capture_height,
|
||||
rs.format.z16,
|
||||
self.fps,
|
||||
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||
)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
@@ -338,7 +327,7 @@ class IntelRealSenseCamera:
|
||||
"To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(f"Can't access RealSenseCamera({self.serial_number}).")
|
||||
|
||||
color_stream = profile.get_stream(rs.stream.color)
|
||||
color_profile = color_stream.as_video_stream_profile()
|
||||
@@ -350,15 +339,15 @@ class IntelRealSenseCamera:
|
||||
if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3):
|
||||
# Using `OSError` since it's a broad that encompasses issues related to device communication
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
f"Can't set {self.fps=} for RealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.capture_width is not None and self.capture_width != actual_width:
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
f"Can't set {self.capture_width=} for RealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.capture_height is not None and self.capture_height != actual_height:
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
f"Can't set {self.capture_height=} for RealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
@@ -378,8 +367,8 @@ class IntelRealSenseCamera:
|
||||
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
@@ -394,7 +383,7 @@ class IntelRealSenseCamera:
|
||||
color_frame = frame.get_color_frame()
|
||||
|
||||
if not color_frame:
|
||||
raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(f"Can't capture color image from RealSenseCamera({self.serial_number}).")
|
||||
|
||||
color_image = np.asanyarray(color_frame.get_data())
|
||||
|
||||
@@ -426,7 +415,7 @@ class IntelRealSenseCamera:
|
||||
if self.use_depth:
|
||||
depth_frame = frame.get_depth_frame()
|
||||
if not depth_frame:
|
||||
raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).")
|
||||
raise OSError(f"Can't capture depth image from RealSenseCamera({self.serial_number}).")
|
||||
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
@@ -453,8 +442,8 @@ class IntelRealSenseCamera:
|
||||
def async_read(self):
|
||||
"""Access the latest color image"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is None:
|
||||
@@ -480,8 +469,8 @@ class IntelRealSenseCamera:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
raise DeviceNotConnectedError(
|
||||
f"RealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is not None and self.thread.is_alive():
|
||||
@@ -503,14 +492,14 @@ class IntelRealSenseCamera:
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset."
|
||||
description="Save a few frames using `RealSenseCamera` for all cameras connected to the computer, or a selected subset."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--serial-numbers",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.",
|
||||
help="List of serial numbers used to instantiate the `RealSenseCamera`. If not provided, find and use all available camera indices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
@@ -520,13 +509,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
type=str,
|
||||
default=640,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
type=str,
|
||||
default=480,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
@@ -12,67 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass
|
||||
class CameraConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
class OpenCVCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(0, 30, 640, 480)
|
||||
OpenCVCameraConfig(0, 60, 640, 480)
|
||||
OpenCVCameraConfig(0, 90, 640, 480)
|
||||
OpenCVCameraConfig(0, 30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
camera_index: int
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
from ..configs import CameraConfig
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("intelrealsense")
|
||||
@dataclass
|
||||
class IntelRealSenseCameraConfig(CameraConfig):
|
||||
class RealSenseCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 60, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 90, 640, 480)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 1280, 720)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
|
||||
IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 60, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 90, 640, 480)
|
||||
RealSenseCameraConfig(128422271347, 30, 1280, 720)
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True)
|
||||
RealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90)
|
||||
```
|
||||
"""
|
||||
|
||||
4
lerobot/common/cameras/opencv/__init__.py
Normal file
4
lerobot/common/cameras/opencv/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .camera_opencv import OpenCVCamera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
__all__ = ["OpenCVCamera", "OpenCVCameraConfig"]
|
||||
@@ -24,19 +24,20 @@ import shutil
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.robot_utils import (
|
||||
busy_wait,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
from ..camera import Camera
|
||||
from .configuration_opencv import OpenCVCameraConfig
|
||||
|
||||
# The maximum opencv device index depends on your operating system. For instance,
|
||||
# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case
|
||||
# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23.
|
||||
@@ -45,12 +46,12 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
possible_ports = [str(port) for port in Path("/dev").glob("video*")]
|
||||
ports = _find_cameras(possible_ports, mock=mock)
|
||||
ports = _find_cameras(possible_ports)
|
||||
for port in ports:
|
||||
cameras.append(
|
||||
{
|
||||
@@ -64,7 +65,7 @@ def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX
|
||||
f"scanning all indices from 0 to {MAX_OPENCV_INDEX}"
|
||||
)
|
||||
possible_indices = range(max_index_search_range)
|
||||
indices = _find_cameras(possible_indices, mock=mock)
|
||||
indices = _find_cameras(possible_indices)
|
||||
for index in indices:
|
||||
cameras.append(
|
||||
{
|
||||
@@ -76,14 +77,7 @@ def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX
|
||||
return cameras
|
||||
|
||||
|
||||
def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
def _find_cameras(possible_camera_ids: list[int | str], raise_when_empty=False) -> list[int | str]:
|
||||
camera_ids = []
|
||||
for camera_idx in possible_camera_ids:
|
||||
camera = cv2.VideoCapture(camera_idx)
|
||||
@@ -127,20 +121,19 @@ def save_images_from_cameras(
|
||||
width=None,
|
||||
height=None,
|
||||
record_time_s=2,
|
||||
mock=False,
|
||||
):
|
||||
"""
|
||||
Initializes all the cameras and saves images to the directory. Useful to visually identify the camera
|
||||
associated to a given camera index.
|
||||
"""
|
||||
if camera_ids is None or len(camera_ids) == 0:
|
||||
camera_infos = find_cameras(mock=mock)
|
||||
camera_infos = find_cameras()
|
||||
camera_ids = [cam["index"] for cam in camera_infos]
|
||||
|
||||
print("Connecting cameras")
|
||||
cameras = []
|
||||
for cam_idx in camera_ids:
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock)
|
||||
config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height)
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
@@ -190,7 +183,7 @@ def save_images_from_cameras(
|
||||
print(f"Images have been saved to {images_dir}")
|
||||
|
||||
|
||||
class OpenCVCamera:
|
||||
class OpenCVCamera(Camera):
|
||||
"""
|
||||
The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate
|
||||
with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html).
|
||||
@@ -259,7 +252,6 @@ class OpenCVCamera:
|
||||
self.fps = config.fps
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
|
||||
self.camera = None
|
||||
self.is_connected = False
|
||||
@@ -268,11 +260,6 @@ class OpenCVCamera:
|
||||
self.color_image = None
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -283,16 +270,11 @@ class OpenCVCamera:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
raise DeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
# Use 1 thread to avoid blocking the main thread. Especially useful during data collection
|
||||
# when other threads are used to save the images.
|
||||
cv2.setNumThreads(1)
|
||||
|
||||
backend = (
|
||||
cv2.CAP_V4L2
|
||||
@@ -305,17 +287,11 @@ class OpenCVCamera:
|
||||
)
|
||||
|
||||
camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index
|
||||
# First create a temporary camera trying to access `camera_index`,
|
||||
# and verify it is a valid camera by calling `isOpened`.
|
||||
tmp_camera = cv2.VideoCapture(camera_idx, backend)
|
||||
is_camera_open = tmp_camera.isOpened()
|
||||
# Release camera to make it accessible for `find_camera_indices`
|
||||
tmp_camera.release()
|
||||
del tmp_camera
|
||||
|
||||
# If the camera doesn't work, display the camera indices corresponding to
|
||||
# valid cameras.
|
||||
if not is_camera_open:
|
||||
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||
|
||||
if not self.camera.isOpened():
|
||||
self.camera.release() # Release the failed attempt
|
||||
# Verify that the provided `camera_index` is valid before printing the traceback
|
||||
cameras_info = find_cameras()
|
||||
available_cam_ids = [cam["index"] for cam in cameras_info]
|
||||
@@ -327,11 +303,6 @@ class OpenCVCamera:
|
||||
|
||||
raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
|
||||
|
||||
# Secondly, create the camera that will be used downstream.
|
||||
# Note: For some unknown reason, calling `isOpened` blocks the camera which then
|
||||
# needs to be re-created.
|
||||
self.camera = cv2.VideoCapture(camera_idx, backend)
|
||||
|
||||
if self.fps is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
if self.capture_width is not None:
|
||||
@@ -375,7 +346,7 @@ class OpenCVCamera:
|
||||
If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
@@ -397,11 +368,6 @@ class OpenCVCamera:
|
||||
# However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks,
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
@@ -432,13 +398,13 @@ class OpenCVCamera:
|
||||
|
||||
def async_read(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
if self.thread is None:
|
||||
self.stop_event = threading.Event()
|
||||
self.thread = Thread(target=self.read_loop, args=())
|
||||
self.thread = threading.Thread(target=self.read_loop, args=())
|
||||
self.thread.daemon = True
|
||||
self.thread.start()
|
||||
|
||||
@@ -454,7 +420,7 @@ class OpenCVCamera:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first."
|
||||
)
|
||||
|
||||
@@ -492,13 +458,13 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set the width for all cameras. If not provided, use the default width of each camera.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
type=str,
|
||||
default=None,
|
||||
help="Set the height for all cameras. If not provided, use the default height of each camera.",
|
||||
)
|
||||
37
lerobot/common/cameras/opencv/configuration_opencv.py
Normal file
37
lerobot/common/cameras/opencv/configuration_opencv.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("opencv")
|
||||
@dataclass
|
||||
class OpenCVCameraConfig(CameraConfig):
|
||||
"""
|
||||
Example of tested options for Intel Real Sense D405:
|
||||
|
||||
```python
|
||||
OpenCVCameraConfig(0, 30, 640, 480)
|
||||
OpenCVCameraConfig(0, 60, 640, 480)
|
||||
OpenCVCameraConfig(0, 90, 640, 480)
|
||||
OpenCVCameraConfig(0, 30, 1280, 720)
|
||||
```
|
||||
"""
|
||||
|
||||
camera_index: int
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
21
lerobot/common/cameras/utils.py
Normal file
21
lerobot/common/cameras/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from .camera import Camera
|
||||
from .configs import CameraConfig
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
|
||||
cameras = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
if cfg.type == "opencv":
|
||||
from .opencv import OpenCVCamera
|
||||
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from .intel.camera_realsense import RealSenseCamera
|
||||
|
||||
cameras[key] = RealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
@@ -17,12 +17,15 @@ from pathlib import Path
|
||||
|
||||
from huggingface_hub.constants import HF_HOME
|
||||
|
||||
OBS_ENV = "observation.environment_state"
|
||||
OBS_ROBOT = "observation.state"
|
||||
OBS_ENV_STATE = "observation.environment_state"
|
||||
OBS_STATE = "observation.state"
|
||||
OBS_IMAGE = "observation.image"
|
||||
OBS_IMAGES = "observation.images"
|
||||
ACTION = "action"
|
||||
|
||||
ROBOTS = "robots"
|
||||
TELEOPERATORS = "teleoperators"
|
||||
|
||||
# files & directories
|
||||
CHECKPOINTS_DIR = "checkpoints"
|
||||
LAST_CHECKPOINT_LINK = "last"
|
||||
@@ -34,12 +37,16 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||
SCHEDULER_STATE = "scheduler_state.json"
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
if "LEROBOT_HOME" in os.environ:
|
||||
raise ValueError(
|
||||
f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n"
|
||||
"'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead."
|
||||
)
|
||||
|
||||
# cache dir
|
||||
default_cache_path = Path(HF_HOME) / "lerobot"
|
||||
HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser()
|
||||
|
||||
# calibration dir
|
||||
default_calibration_path = HF_LEROBOT_HOME / ".calibration"
|
||||
HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser()
|
||||
|
||||
@@ -19,10 +19,7 @@ from lerobot.common.datasets.utils import load_image_as_numpy
|
||||
|
||||
|
||||
def estimate_num_samples(
|
||||
dataset_len: int,
|
||||
min_num_samples: int = 100,
|
||||
max_num_samples: int = 10_000,
|
||||
power: float = 0.75,
|
||||
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
|
||||
) -> int:
|
||||
"""Heuristic to estimate the number of samples based on dataset size.
|
||||
The power controls the sample growth relative to dataset size.
|
||||
@@ -126,9 +123,7 @@ def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
|
||||
raise ValueError(f"Shape of '{k}' must be (3,1,1), but is {v.shape} instead.")
|
||||
|
||||
|
||||
def aggregate_feature_stats(
|
||||
stats_ft_list: list[dict[str, dict]],
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregates stats for a single feature."""
|
||||
means = np.stack([s["mean"] for s in stats_ft_list])
|
||||
variances = np.stack([s["std"] ** 2 for s in stats_ft_list])
|
||||
@@ -157,9 +152,7 @@ def aggregate_feature_stats(
|
||||
}
|
||||
|
||||
|
||||
def aggregate_stats(
|
||||
stats_list: list[dict[str, dict]],
|
||||
) -> dict[str, dict[str, np.ndarray]]:
|
||||
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
|
||||
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.
|
||||
|
||||
The final stats will have the union of all data keys from each of the stats dicts.
|
||||
|
||||
@@ -72,7 +72,7 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robots.utils import Robot
|
||||
|
||||
CODEBASE_VERSION = "v2.1"
|
||||
|
||||
@@ -318,7 +318,7 @@ class LeRobotDatasetMetadata:
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
@@ -821,9 +821,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"],
|
||||
image_key=key,
|
||||
frame_index=frame_index,
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -869,10 +867,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key, ft in self.features.items():
|
||||
# index, episode_index, task_index are already processed above, and image and video
|
||||
# are processed separately by storing image path and frame info as meta data
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in [
|
||||
"image",
|
||||
"video",
|
||||
]:
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
||||
|
||||
|
||||
@@ -154,32 +154,14 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()},
|
||||
# Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied
|
||||
# with real data rather than the dummy initialization.
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {
|
||||
"dtype": np.dtype("?"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {
|
||||
"dtype": np.dtype("int64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {
|
||||
"dtype": np.dtype("float64"),
|
||||
"shape": (buffer_capacity,),
|
||||
},
|
||||
OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)},
|
||||
OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)},
|
||||
}
|
||||
for k, v in data_spec.items():
|
||||
complete_data_spec[k] = {
|
||||
"dtype": v["dtype"],
|
||||
"shape": (buffer_capacity, *v["shape"]),
|
||||
}
|
||||
complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])}
|
||||
return complete_data_spec
|
||||
|
||||
def add_data(self, data: dict[str, np.ndarray]):
|
||||
|
||||
@@ -77,9 +77,7 @@ def check_repo_id(repo_id: str) -> None:
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(
|
||||
hf_dataset: datasets.Dataset,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
|
||||
@@ -43,10 +43,7 @@ class EpisodeAwareSampler:
|
||||
):
|
||||
if episode_indices_to_use is None or episode_idx in episode_indices_to_use:
|
||||
indices.extend(
|
||||
range(
|
||||
start_index.item() + drop_n_first_frames,
|
||||
end_index.item() - drop_n_last_frames,
|
||||
)
|
||||
range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames)
|
||||
)
|
||||
|
||||
self.indices = indices
|
||||
|
||||
@@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import (
|
||||
BackwardCompatibilityError,
|
||||
ForwardCompatibilityError,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robots.utils import Robot
|
||||
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
|
||||
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
|
||||
|
||||
@@ -225,10 +225,7 @@ def load_episodes(local_dir: Path) -> dict:
|
||||
def write_episode_stats(episode_index: int, episode_stats: dict, local_dir: Path):
|
||||
# We wrap episode_stats in a dictionary since `episode_stats["episode_index"]`
|
||||
# is a dictionary of stats and not an integer.
|
||||
episode_stats = {
|
||||
"episode_index": episode_index,
|
||||
"stats": serialize_dict(episode_stats),
|
||||
}
|
||||
episode_stats = {"episode_index": episode_index, "stats": serialize_dict(episode_stats)}
|
||||
append_jsonlines(episode_stats, local_dir / EPISODES_STATS_PATH)
|
||||
|
||||
|
||||
@@ -412,7 +409,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
|
||||
names = ft["names"]
|
||||
# Backward compatibility for "channel" which is an error introduced in LeRobotDataset v2.0 for ported datasets.
|
||||
if names is not None and names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
@@ -543,10 +540,7 @@ def check_timestamps_sync(
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
|
||||
@@ -27,7 +27,7 @@ from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset
|
||||
from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig
|
||||
from lerobot.common.robots.aloha.configuration_aloha import AlohaRobotConfig
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
|
||||
@@ -118,10 +118,7 @@ DATASETS = {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {
|
||||
"single_task": "Pick up the candy and unwrap it.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -170,22 +167,13 @@ DATASETS = {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {
|
||||
"single_task": "Slide open the ziploc bag.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_scripted": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
@@ -206,19 +194,10 @@ DATASETS = {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht_image": {
|
||||
"single_task": "Push the T-shaped block onto the T-shaped target.",
|
||||
**PUSHT_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {
|
||||
"single_task": "Put the object into the bin.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
@@ -228,31 +207,13 @@ DATASETS = {
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_replay_image": {
|
||||
"single_task": "Pick up the cube and lift it.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_replay_image": {
|
||||
"single_task": "Push the cube onto the target.",
|
||||
**XARM_INFO,
|
||||
},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
|
||||
@@ -141,8 +141,8 @@ from lerobot.common.datasets.video_utils import (
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import make_robot_config
|
||||
from lerobot.common.robots import RobotConfig
|
||||
from lerobot.common.robots.utils import make_robot_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
@@ -379,12 +379,7 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(
|
||||
["git", "rm", "--cached", *files],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
)
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
@@ -407,17 +402,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
[
|
||||
"git",
|
||||
"clone",
|
||||
"--branch",
|
||||
branch,
|
||||
"--single-branch",
|
||||
"--depth",
|
||||
"1",
|
||||
repo_url,
|
||||
str(work_dir),
|
||||
],
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
@@ -425,11 +410,7 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"],
|
||||
cwd=work_dir,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
@@ -443,11 +424,7 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
revision=branch,
|
||||
allow_patterns=video_files,
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
@@ -474,11 +451,7 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
revision=v1,
|
||||
local_dir=v1x_dir,
|
||||
ignore_patterns="videos*/",
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
@@ -536,21 +509,12 @@ def convert_dataset(
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF,
|
||||
repo_type="dataset",
|
||||
local_dir=local_dir,
|
||||
filename=".gitattributes",
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
)
|
||||
).absolute()
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id,
|
||||
video_keys,
|
||||
total_episodes,
|
||||
total_chunks,
|
||||
Path(tmp_video_dir),
|
||||
clean_gitattr,
|
||||
branch,
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
@@ -579,11 +543,7 @@ def convert_dataset(
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": tasks_by_episodes[ep_idx],
|
||||
"length": episode_lengths[ep_idx],
|
||||
}
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
@@ -612,12 +572,7 @@ def convert_dataset(
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta_data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -37,16 +37,8 @@ import logging
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_STATS_PATH,
|
||||
STATS_PATH,
|
||||
load_stats,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.common.datasets.v21.convert_stats import (
|
||||
check_aggregate_stats,
|
||||
convert_stats,
|
||||
)
|
||||
from lerobot.common.datasets.utils import EPISODES_STATS_PATH, STATS_PATH, load_stats, write_info
|
||||
from lerobot.common.datasets.v21.convert_stats import check_aggregate_stats, convert_stats
|
||||
|
||||
V20 = "v2.0"
|
||||
V21 = "v2.1"
|
||||
@@ -87,16 +79,10 @@ def convert_dataset(
|
||||
|
||||
hub_api = HfApi()
|
||||
if hub_api.file_exists(
|
||||
repo_id=dataset.repo_id,
|
||||
filename=STATS_PATH,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
repo_id=dataset.repo_id, filename=STATS_PATH, revision=branch, repo_type="dataset"
|
||||
):
|
||||
hub_api.delete_file(
|
||||
path_in_repo=STATS_PATH,
|
||||
repo_id=dataset.repo_id,
|
||||
revision=branch,
|
||||
repo_type="dataset",
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
@@ -17,11 +17,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.compute_stats import (
|
||||
aggregate_stats,
|
||||
get_feature_stats,
|
||||
sample_indices,
|
||||
)
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import write_episode_stats
|
||||
|
||||
@@ -99,9 +95,5 @@ def check_aggregate_stats(
|
||||
if key in reference_stats and stat in reference_stats[key]:
|
||||
err_msg = f"feature='{key}' stats='{stat}'"
|
||||
np.testing.assert_allclose(
|
||||
val,
|
||||
reference_stats[key][stat],
|
||||
rtol=rtol,
|
||||
atol=atol,
|
||||
err_msg=err_msg,
|
||||
val, reference_stats[key][stat], rtol=rtol, atol=atol, err_msg=err_msg
|
||||
)
|
||||
|
||||
@@ -14,12 +14,10 @@
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT
|
||||
from lerobot.common.robot_devices.robots.configs import RobotConfig
|
||||
from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
|
||||
|
||||
@@ -55,7 +53,7 @@ class AlohaEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"agent_pos": OBS_STATE,
|
||||
"top": f"{OBS_IMAGE}.top",
|
||||
"pixels/top": f"{OBS_IMAGES}.top",
|
||||
}
|
||||
@@ -96,8 +94,8 @@ class PushtEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"environment_state": OBS_ENV,
|
||||
"agent_pos": OBS_STATE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
@@ -138,7 +136,7 @@ class XarmEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"agent_pos": OBS_ROBOT,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels": OBS_IMAGE,
|
||||
}
|
||||
)
|
||||
@@ -156,135 +154,3 @@ class XarmEnv(EnvConfig):
|
||||
"visualization_height": self.visualization_height,
|
||||
"max_episode_steps": self.episode_length,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoRecordConfig:
|
||||
"""Configuration for video recording in ManiSkill environments."""
|
||||
|
||||
enabled: bool = False
|
||||
record_dir: str = "videos"
|
||||
trajectory_name: str = "trajectory"
|
||||
|
||||
|
||||
@dataclass
|
||||
class WrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
joint_masking_action_space: list[bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EEActionSpaceConfig:
|
||||
"""Configuration parameters for end-effector action space."""
|
||||
|
||||
x_step_size: float
|
||||
y_step_size: float
|
||||
z_step_size: float
|
||||
bounds: Dict[str, Any] # Contains 'min' and 'max' keys with position bounds
|
||||
use_gamepad: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvWrapperConfig:
|
||||
"""Configuration for environment wrappers."""
|
||||
|
||||
display_cameras: bool = False
|
||||
use_relative_joint_positions: bool = True
|
||||
add_joint_velocity_to_observation: bool = False
|
||||
add_ee_pose_to_observation: bool = False
|
||||
crop_params_dict: Optional[Dict[str, Tuple[int, int, int, int]]] = None
|
||||
resize_size: Optional[Tuple[int, int]] = None
|
||||
control_time_s: float = 20.0
|
||||
fixed_reset_joint_positions: Optional[Any] = None
|
||||
reset_time_s: float = 5.0
|
||||
joint_masking_action_space: Optional[Any] = None
|
||||
ee_action_space_params: Optional[EEActionSpaceConfig] = None
|
||||
use_gripper: bool = False
|
||||
gripper_quantization_threshold: float | None = 0.8
|
||||
gripper_penalty: float = 0.0
|
||||
gripper_penalty_in_reward: bool = False
|
||||
open_gripper_on_reset: bool = False
|
||||
|
||||
|
||||
@EnvConfig.register_subclass(name="gym_manipulator")
|
||||
@dataclass
|
||||
class HILSerlRobotEnvConfig(EnvConfig):
|
||||
"""Configuration for the HILSerlRobotEnv environment."""
|
||||
|
||||
robot: Optional[RobotConfig] = None
|
||||
wrapper: Optional[EnvWrapperConfig] = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
repo_id: Optional[str] = None
|
||||
dataset_root: Optional[str] = None
|
||||
task: str = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
push_to_hub: bool = True
|
||||
pretrained_policy_name_or_path: Optional[str] = None
|
||||
reward_classifier: dict[str, str | None] = field(
|
||||
default_factory=lambda: {
|
||||
"pretrained_path": None,
|
||||
"config_path": None,
|
||||
}
|
||||
)
|
||||
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
|
||||
@EnvConfig.register_subclass("maniskill_push")
|
||||
@dataclass
|
||||
class ManiskillEnvConfig(EnvConfig):
|
||||
"""Configuration for the ManiSkill environment."""
|
||||
|
||||
name: str = "maniskill/pushcube"
|
||||
task: str = "PushCube-v1"
|
||||
image_size: int = 64
|
||||
control_mode: str = "pd_ee_delta_pose"
|
||||
state_dim: int = 25
|
||||
action_dim: int = 7
|
||||
fps: int = 200
|
||||
episode_length: int = 50
|
||||
obs_type: str = "rgb"
|
||||
render_mode: str = "rgb_array"
|
||||
render_size: int = 64
|
||||
device: str = "cuda"
|
||||
robot: str = "so100" # This is a hack to make the robot config work
|
||||
video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig)
|
||||
wrapper: WrapperConfig = field(default_factory=WrapperConfig)
|
||||
mock_gripper: bool = False
|
||||
features: dict[str, PolicyFeature] = field(
|
||||
default_factory=lambda: {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
"observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 64, 64)),
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(25,)),
|
||||
}
|
||||
)
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"action": ACTION,
|
||||
"observation.image": OBS_IMAGE,
|
||||
"observation.state": OBS_ROBOT,
|
||||
}
|
||||
)
|
||||
reward_classifier: dict[str, str | None] = field(
|
||||
default_factory=lambda: {
|
||||
"pretrained_path": None,
|
||||
"config_path": None,
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {
|
||||
"obs_type": self.obs_type,
|
||||
"render_mode": self.render_mode,
|
||||
"max_episode_steps": self.episode_length,
|
||||
"control_mode": self.control_mode,
|
||||
"sensor_configs": {"width": self.image_size, "height": self.image_size},
|
||||
"num_envs": 1,
|
||||
}
|
||||
|
||||
@@ -37,35 +37,29 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: You have to merge all tensors from agent key and extra key
|
||||
# You don't keep sensor param key in the observation
|
||||
# And you keep sensor data rgb
|
||||
for key, img in observations.items():
|
||||
if "images" not in key:
|
||||
continue
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
else:
|
||||
imgs = {"observation.image": observations["pixels"]}
|
||||
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
if not torch.is_tensor(img):
|
||||
for imgkey, img in imgs.items():
|
||||
# TODO(aliberts, rcadene): use transforms.ToTensor()?
|
||||
img = torch.from_numpy(img)
|
||||
|
||||
if img.ndim == 3:
|
||||
img = img.unsqueeze(0)
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are channel last
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
return_observations[key] = img
|
||||
# obs state agent qpos and qvel
|
||||
# image
|
||||
return_observations[imgkey] = img
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
@@ -74,8 +68,7 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
# requirement for "agent_pos"
|
||||
# return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return_observations["observation.state"] = observations["observation.state"].float()
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
@@ -93,7 +86,7 @@ def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]:
|
||||
else:
|
||||
feature = ft
|
||||
|
||||
policy_key = env_cfg.features_map.get(key, key)
|
||||
policy_key = env_cfg.features_map[key]
|
||||
policy_features[policy_key] = feature
|
||||
|
||||
return policy_features
|
||||
@@ -108,9 +101,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("once", UserWarning) # Apply filter only in this function
|
||||
|
||||
if not (
|
||||
hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")
|
||||
):
|
||||
if not (hasattr(env.envs[0], "task_description") and hasattr(env.envs[0], "task")):
|
||||
warnings.warn(
|
||||
"The environment does not have 'task_description' and 'task'. Some policies require these features.",
|
||||
UserWarning,
|
||||
@@ -124,9 +115,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None:
|
||||
)
|
||||
|
||||
|
||||
def add_envs_task(
|
||||
env: gym.vector.VectorEnv, observation: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Adds task feature to the observation dict with respect to the first environment attribute."""
|
||||
if hasattr(env.envs[0], "task_description"):
|
||||
observation["task"] = env.call("task_description")
|
||||
|
||||
17
lerobot/common/errors.py
Normal file
17
lerobot/common/errors.py
Normal file
@@ -0,0 +1,17 @@
|
||||
class DeviceNotConnectedError(ConnectionError):
|
||||
"""Exception raised when the device is not connected."""
|
||||
|
||||
def __init__(self, message="This device is not connected. Try calling `connect()` first."):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeviceAlreadyConnectedError(ConnectionError):
|
||||
"""Exception raised when the device is already connected."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message="This device is already connected. Try not calling `connect()` twice.",
|
||||
):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
1
lerobot/common/motors/__init__.py
Normal file
1
lerobot/common/motors/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus
|
||||
3
lerobot/common/motors/dynamixel/__init__.py
Normal file
3
lerobot/common/motors/dynamixel/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
|
||||
from .dynamixel_calibration import run_arm_calibration
|
||||
from .tables import *
|
||||
206
lerobot/common/motors/dynamixel/dynamixel.py
Normal file
206
lerobot/common/motors/dynamixel/dynamixel.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO(aliberts): Should we implement FastSyncRead/Write?
|
||||
# https://github.com/ROBOTIS-GIT/DynamixelSDK/pull/643
|
||||
# https://github.com/ROBOTIS-GIT/DynamixelSDK/releases/tag/3.8.2
|
||||
# https://emanual.robotis.com/docs/en/dxl/protocol2/#fast-sync-read-0x8a
|
||||
# -> Need to check compatibility across models
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
|
||||
from lerobot.common.utils.encoding_utils import decode_twos_complement, encode_twos_complement
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
||||
from .tables import (
|
||||
AVAILABLE_BAUDRATES,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
MODEL_CONTROL_TABLE,
|
||||
MODEL_ENCODING_TABLE,
|
||||
MODEL_NUMBER_TABLE,
|
||||
MODEL_RESOLUTION,
|
||||
)
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUDRATE = 1_000_000
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperatingMode(Enum):
|
||||
# DYNAMIXEL only controls current(torque) regardless of speed and position. This mode is ideal for a
|
||||
# gripper or a system that only uses current(torque) control or a system that has additional
|
||||
# velocity/position controllers.
|
||||
CURRENT = 0
|
||||
|
||||
# This mode controls velocity. This mode is identical to the Wheel Mode(endless) from existing DYNAMIXEL.
|
||||
# This mode is ideal for wheel-type robots.
|
||||
VELOCITY = 1
|
||||
|
||||
# This mode controls position. This mode is identical to the Joint Mode from existing DYNAMIXEL. Operating
|
||||
# position range is limited by the Max Position Limit(48) and the Min Position Limit(52). This mode is
|
||||
# ideal for articulated robots that each joint rotates less than 360 degrees.
|
||||
POSITION = 3
|
||||
|
||||
# This mode controls position. This mode is identical to the Multi-turn Position Control from existing
|
||||
# DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or
|
||||
# conveyer systems or a system that requires an additional reduction gear. Note that Max Position
|
||||
# Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode.
|
||||
EXTENDED_POSITION = 4
|
||||
|
||||
# This mode controls both position and current(torque). Up to 512 turns are supported (-256[rev] ~
|
||||
# 256[rev]). This mode is ideal for a system that requires both position and current control such as
|
||||
# articulated robots or grippers.
|
||||
CURRENT_POSITION = 5
|
||||
|
||||
# This mode directly controls PWM output. (Voltage Control Mode)
|
||||
PWM = 16
|
||||
|
||||
|
||||
class DriveMode(Enum):
|
||||
NON_INVERTED = 0
|
||||
INVERTED = 1
|
||||
|
||||
|
||||
class TorqueMode(Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
if length == 1:
|
||||
data = [value]
|
||||
elif length == 2:
|
||||
data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)]
|
||||
elif length == 4:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
class DynamixelMotorsBus(MotorsBus):
|
||||
"""
|
||||
The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with
|
||||
the motors. For more info, see the Dynamixel SDK Documentation:
|
||||
https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20
|
||||
"""
|
||||
|
||||
available_baudrates = deepcopy(AVAILABLE_BAUDRATES)
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
||||
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
|
||||
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
|
||||
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
super().__init__(port, motors, calibration)
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
self.port_handler = dxl.PortHandler(self.port)
|
||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||
self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
self._comm_success = dxl.COMM_SUCCESS
|
||||
self._no_error = 0x00
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||
pass
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
# By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||
for motor in self.motors:
|
||||
self.write("Return_Delay_Time", motor, 0)
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
|
||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
for id_ in ids_values:
|
||||
model = self._id_to_model(id_)
|
||||
encoding_table = self.model_encoding_table.get(model)
|
||||
if encoding_table and data_name in encoding_table:
|
||||
n_bytes = encoding_table[data_name]
|
||||
ids_values[id_] = encode_twos_complement(ids_values[id_], n_bytes)
|
||||
|
||||
return ids_values
|
||||
|
||||
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
for id_ in ids_values:
|
||||
model = self._id_to_model(id_)
|
||||
encoding_table = self.model_encoding_table.get(model)
|
||||
if encoding_table and data_name in encoding_table:
|
||||
n_bytes = encoding_table[data_name]
|
||||
ids_values[id_] = decode_twos_complement(ids_values[id_], n_bytes)
|
||||
|
||||
return ids_values
|
||||
|
||||
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
||||
"""
|
||||
On Dynamixel Motors:
|
||||
Present_Position = Actual_Position + Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
half_turn_homings[motor] = int(max_res / 2) - pos
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||
return _split_into_byte_chunks(value, length)
|
||||
|
||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||
for n_try in range(1 + num_retry):
|
||||
data_list, comm = self.packet_handler.broadcastPing(self.port_handler)
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||
logger.debug(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
return
|
||||
|
||||
return {id_: data[0] for id_, data in data_list.items()}
|
||||
@@ -17,12 +17,9 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.motors.dynamixel import (
|
||||
CalibrationMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
from ..motors_bus import MotorNormMode, MotorsBus
|
||||
from .dynamixel import TorqueMode
|
||||
from .tables import MODEL_RESOLUTION
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
@@ -49,6 +46,17 @@ def apply_drive_mode(position, drive_mode):
|
||||
return position
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
"""
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
|
||||
|
||||
def compute_nearest_rounded_position(position, models):
|
||||
delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models)
|
||||
nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn
|
||||
@@ -89,11 +97,11 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
|
||||
# correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
|
||||
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
|
||||
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.models)
|
||||
|
||||
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
|
||||
zero_pos = arm.read("Present_Position")
|
||||
zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models)
|
||||
zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.models)
|
||||
homing_offset = zero_target_pos - zero_nearest_pos
|
||||
|
||||
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
|
||||
@@ -107,7 +115,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
@@ -116,7 +124,7 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models)
|
||||
rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.models)
|
||||
homing_offset = rotated_target_pos - rotated_nearest_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
@@ -125,13 +133,13 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
print()
|
||||
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names)
|
||||
calib_mode = [MotorNormMode.DEGREE.name] * len(arm.names)
|
||||
|
||||
# TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml?
|
||||
if robot_type in ["aloha"] and "gripper" in arm.motor_names:
|
||||
if robot_type in ["aloha"] and "gripper" in arm.names:
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
calib_idx = arm.motor_names.index("gripper")
|
||||
calib_mode[calib_idx] = CalibrationMode.LINEAR.name
|
||||
calib_idx = arm.names.index("gripper")
|
||||
calib_mode[calib_idx] = MotorNormMode.LINEAR.name
|
||||
|
||||
calib_data = {
|
||||
"homing_offset": homing_offset.tolist(),
|
||||
@@ -139,6 +147,6 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type
|
||||
"start_pos": zero_pos.tolist(),
|
||||
"end_pos": rotated_pos.tolist(),
|
||||
"calib_mode": calib_mode,
|
||||
"motor_names": arm.motor_names,
|
||||
"motor_names": arm.names,
|
||||
}
|
||||
return calib_data
|
||||
162
lerobot/common/motors/dynamixel/tables.py
Normal file
162
lerobot/common/motors/dynamixel/tables.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# {data_name: (address, size_byte)}
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table
|
||||
X_SERIES_CONTROL_TABLE = {
|
||||
"Model_Number": (0, 2),
|
||||
"Model_Information": (2, 4),
|
||||
"Firmware_Version": (6, 1),
|
||||
"ID": (7, 1),
|
||||
"Baud_Rate": (8, 1),
|
||||
"Return_Delay_Time": (9, 1),
|
||||
"Drive_Mode": (10, 1),
|
||||
"Operating_Mode": (11, 1),
|
||||
"Secondary_ID": (12, 1),
|
||||
"Protocol_Type": (13, 1),
|
||||
"Homing_Offset": (20, 4),
|
||||
"Moving_Threshold": (24, 4),
|
||||
"Temperature_Limit": (31, 1),
|
||||
"Max_Voltage_Limit": (32, 2),
|
||||
"Min_Voltage_Limit": (34, 2),
|
||||
"PWM_Limit": (36, 2),
|
||||
"Current_Limit": (38, 2),
|
||||
"Acceleration_Limit": (40, 4),
|
||||
"Velocity_Limit": (44, 4),
|
||||
"Max_Position_Limit": (48, 4),
|
||||
"Min_Position_Limit": (52, 4),
|
||||
"Shutdown": (63, 1),
|
||||
"Torque_Enable": (64, 1),
|
||||
"LED": (65, 1),
|
||||
"Status_Return_Level": (68, 1),
|
||||
"Registered_Instruction": (69, 1),
|
||||
"Hardware_Error_Status": (70, 1),
|
||||
"Velocity_I_Gain": (76, 2),
|
||||
"Velocity_P_Gain": (78, 2),
|
||||
"Position_D_Gain": (80, 2),
|
||||
"Position_I_Gain": (82, 2),
|
||||
"Position_P_Gain": (84, 2),
|
||||
"Feedforward_2nd_Gain": (88, 2),
|
||||
"Feedforward_1st_Gain": (90, 2),
|
||||
"Bus_Watchdog": (98, 1),
|
||||
"Goal_PWM": (100, 2),
|
||||
"Goal_Current": (102, 2),
|
||||
"Goal_Velocity": (104, 4),
|
||||
"Profile_Acceleration": (108, 4),
|
||||
"Profile_Velocity": (112, 4),
|
||||
"Goal_Position": (116, 4),
|
||||
"Realtime_Tick": (120, 2),
|
||||
"Moving": (122, 1),
|
||||
"Moving_Status": (123, 1),
|
||||
"Present_PWM": (124, 2),
|
||||
"Present_Current": (126, 2),
|
||||
"Present_Velocity": (128, 4),
|
||||
"Present_Position": (132, 4),
|
||||
"Velocity_Trajectory": (136, 4),
|
||||
"Position_Trajectory": (140, 4),
|
||||
"Present_Input_Voltage": (144, 2),
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8
|
||||
X_SERIES_BAUDRATE_TABLE = {
|
||||
0: 9_600,
|
||||
1: 57_600,
|
||||
2: 115_200,
|
||||
3: 1_000_000,
|
||||
4: 2_000_000,
|
||||
5: 3_000_000,
|
||||
6: 4_000_000,
|
||||
}
|
||||
|
||||
# {data_name: size_byte}
|
||||
X_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": X_SERIES_CONTROL_TABLE["Homing_Offset"][1],
|
||||
"Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1],
|
||||
"Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1],
|
||||
"Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1],
|
||||
"Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1],
|
||||
"Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1],
|
||||
"Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1],
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
"x_series": X_SERIES_ENCODINGS_TABLE,
|
||||
"xl330-m077": X_SERIES_ENCODINGS_TABLE,
|
||||
"xl330-m288": X_SERIES_ENCODINGS_TABLE,
|
||||
"xl430-w250": X_SERIES_ENCODINGS_TABLE,
|
||||
"xm430-w350": X_SERIES_ENCODINGS_TABLE,
|
||||
"xm540-w270": X_SERIES_ENCODINGS_TABLE,
|
||||
"xc430-w150": X_SERIES_ENCODINGS_TABLE,
|
||||
}
|
||||
|
||||
# {model: model_resolution}
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#specifications
|
||||
MODEL_RESOLUTION = {
|
||||
"x_series": 4096,
|
||||
"xl330-m077": 4096,
|
||||
"xl330-m288": 4096,
|
||||
"xl430-w250": 4096,
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
"xc430-w150": 4096,
|
||||
}
|
||||
|
||||
# {model: model_number}
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table-of-eeprom-area
|
||||
MODEL_NUMBER_TABLE = {
|
||||
"xl330-m077": 1190,
|
||||
"xl330-m288": 1200,
|
||||
"xl430-w250": 1060,
|
||||
"xm430-w350": 1020,
|
||||
"xm540-w270": 1120,
|
||||
"xc430-w150": 1070,
|
||||
}
|
||||
|
||||
# {model: available_operating_modes}
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#operating-mode11
|
||||
MODEL_OPERATING_MODES = {
|
||||
"xl330-m077": [0, 1, 3, 4, 5, 16],
|
||||
"xl330-m288": [0, 1, 3, 4, 5, 16],
|
||||
"xl430-w250": [1, 3, 4, 16],
|
||||
"xm430-w350": [0, 1, 3, 4, 5, 16],
|
||||
"xm540-w270": [0, 1, 3, 4, 5, 16],
|
||||
"xc430-w150": [1, 3, 4, 16],
|
||||
}
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"x_series": X_SERIES_CONTROL_TABLE,
|
||||
"xl330-m077": X_SERIES_CONTROL_TABLE,
|
||||
"xl330-m288": X_SERIES_CONTROL_TABLE,
|
||||
"xl430-w250": X_SERIES_CONTROL_TABLE,
|
||||
"xm430-w350": X_SERIES_CONTROL_TABLE,
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
"xc430-w150": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"x_series": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
AVAILABLE_BAUDRATES = [
|
||||
9_600,
|
||||
19_200,
|
||||
38_400,
|
||||
57_600,
|
||||
115_200,
|
||||
230_400,
|
||||
460_800,
|
||||
500_000,
|
||||
576_000,
|
||||
921_600,
|
||||
1_000_000,
|
||||
1_152_000,
|
||||
2_000_000,
|
||||
2_500_000,
|
||||
3_000_000,
|
||||
3_500_000,
|
||||
4_000_000,
|
||||
]
|
||||
2
lerobot/common/motors/feetech/__init__.py
Normal file
2
lerobot/common/motors/feetech/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
|
||||
from .tables import *
|
||||
367
lerobot/common/motors/feetech/feetech.py
Normal file
367
lerobot/common/motors/feetech/feetech.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude
|
||||
|
||||
from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value
|
||||
from .tables import (
|
||||
FIRMWARE_MAJOR_VERSION,
|
||||
FIRMWARE_MINOR_VERSION,
|
||||
MODEL_BAUDRATE_TABLE,
|
||||
MODEL_CONTROL_TABLE,
|
||||
MODEL_ENCODING_TABLE,
|
||||
MODEL_NUMBER,
|
||||
MODEL_NUMBER_TABLE,
|
||||
MODEL_PROTOCOL,
|
||||
MODEL_RESOLUTION,
|
||||
SCAN_BAUDRATES,
|
||||
)
|
||||
|
||||
DEFAULT_PROTOCOL_VERSION = 0
|
||||
BAUDRATE = 1_000_000
|
||||
DEFAULT_TIMEOUT_MS = 1000
|
||||
|
||||
NORMALIZED_DATA = ["Goal_Position", "Present_Position"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OperatingMode(Enum):
|
||||
# position servo mode
|
||||
POSITION = 0
|
||||
# The motor is in constant speed mode, which is controlled by parameter 0x2e, and the highest bit 15 is
|
||||
# the direction bit
|
||||
VELOCITY = 1
|
||||
# PWM open-loop speed regulation mode, with parameter 0x2c running time parameter control, bit11 as
|
||||
# direction bit
|
||||
PWM = 2
|
||||
# In step servo mode, the number of step progress is represented by parameter 0x2a, and the highest bit 15
|
||||
# is the direction bit
|
||||
STEP = 3
|
||||
|
||||
|
||||
class DriveMode(Enum):
|
||||
NON_INVERTED = 0
|
||||
INVERTED = 1
|
||||
|
||||
|
||||
class TorqueMode(Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
def _split_into_byte_chunks(value: int, length: int) -> list[int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
if length == 1:
|
||||
data = [value]
|
||||
elif length == 2:
|
||||
data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)]
|
||||
elif length == 4:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
|
||||
]
|
||||
return data
|
||||
|
||||
|
||||
def patch_setPacketTimeout(self, packet_length): # noqa: N802
|
||||
"""
|
||||
HACK: This patches the PortHandler behavior to set the correct packet timeouts.
|
||||
|
||||
It fixes https://gitee.com/ftservo/SCServoSDK/issues/IBY2S6
|
||||
The bug is fixed on the official Feetech SDK repo (https://gitee.com/ftservo/FTServo_Python)
|
||||
but because that version is not published on PyPI, we rely on the (unofficial) on that is, which needs
|
||||
patching.
|
||||
"""
|
||||
self.packet_start_time = self.getCurrentTime()
|
||||
self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50
|
||||
|
||||
|
||||
class FeetechMotorsBus(MotorsBus):
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the
|
||||
python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk.
|
||||
"""
|
||||
|
||||
available_baudrates = deepcopy(SCAN_BAUDRATES)
|
||||
default_timeout = DEFAULT_TIMEOUT_MS
|
||||
model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE)
|
||||
model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
model_encoding_table = deepcopy(MODEL_ENCODING_TABLE)
|
||||
model_number_table = deepcopy(MODEL_NUMBER_TABLE)
|
||||
model_resolution_table = deepcopy(MODEL_RESOLUTION)
|
||||
normalized_data = deepcopy(NORMALIZED_DATA)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
protocol_version: int = DEFAULT_PROTOCOL_VERSION,
|
||||
):
|
||||
super().__init__(port, motors, calibration)
|
||||
self.protocol_version = protocol_version
|
||||
self._assert_same_protocol()
|
||||
import scservo_sdk as scs
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
# HACK: monkeypatch
|
||||
self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__(
|
||||
self.port_handler, scs.PortHandler
|
||||
)
|
||||
self.packet_handler = scs.PacketHandler(protocol_version)
|
||||
self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0)
|
||||
self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0)
|
||||
self._comm_success = scs.COMM_SUCCESS
|
||||
self._no_error = 0x00
|
||||
|
||||
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||
raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}")
|
||||
|
||||
def _assert_same_protocol(self) -> None:
|
||||
if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models):
|
||||
raise RuntimeError("Some motors use an incompatible protocol.")
|
||||
|
||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||
if instruction_name == "sync_read" and self.protocol_version == 1:
|
||||
raise NotImplementedError(
|
||||
"'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead."
|
||||
)
|
||||
if instruction_name == "broadcast_ping" and self.protocol_version == 1:
|
||||
raise NotImplementedError(
|
||||
"'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead."
|
||||
)
|
||||
|
||||
def _assert_same_firmware(self) -> None:
|
||||
firmware_versions = self._read_firmware_version(self.ids)
|
||||
if len(set(firmware_versions.values())) != 1:
|
||||
raise RuntimeError(
|
||||
"Some Motors use different firmware versions. Update their firmware first using Feetech's software. "
|
||||
"Visit https://www.feetechrc.com/software."
|
||||
)
|
||||
|
||||
def _handshake(self) -> None:
|
||||
self._assert_motors_exist()
|
||||
self._assert_same_firmware()
|
||||
|
||||
def configure_motors(self) -> None:
|
||||
for motor in self.motors:
|
||||
# By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on
|
||||
# the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0).
|
||||
self.write("Return_Delay_Time", motor, 0)
|
||||
# Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors.
|
||||
# Note: this address is not in the official STS3215 Memory Table
|
||||
self.write("Maximum_Acceleration", motor, 254)
|
||||
self.write("Acceleration", motor, 254)
|
||||
|
||||
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
||||
"""
|
||||
On Feetech Motors:
|
||||
Present_Position = Actual_Position - Homing_Offset
|
||||
"""
|
||||
half_turn_homings = {}
|
||||
for motor, pos in positions.items():
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
half_turn_homings[motor] = pos - int(max_res / 2)
|
||||
|
||||
return half_turn_homings
|
||||
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.DISABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", name, 0, num_retry=num_retry)
|
||||
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
for name in self._get_motors_list(motors):
|
||||
self.write("Torque_Enable", name, TorqueMode.ENABLED.value, num_retry=num_retry)
|
||||
self.write("Lock", name, 1, num_retry=num_retry)
|
||||
|
||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
for id_ in ids_values:
|
||||
model = self._id_to_model(id_)
|
||||
encoding_table = self.model_encoding_table.get(model)
|
||||
if encoding_table and data_name in encoding_table:
|
||||
sign_bit = encoding_table[data_name]
|
||||
ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit)
|
||||
|
||||
return ids_values
|
||||
|
||||
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
for id_ in ids_values:
|
||||
model = self._id_to_model(id_)
|
||||
encoding_table = self.model_encoding_table.get(model)
|
||||
if encoding_table and data_name in encoding_table:
|
||||
sign_bit = encoding_table[data_name]
|
||||
ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit)
|
||||
|
||||
return ids_values
|
||||
|
||||
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||
return _split_into_byte_chunks(value, length)
|
||||
|
||||
def _broadcast_ping_p1(
|
||||
self, known_motors_only: bool = True, n_motors: int | None = None, num_retry: int = 0
|
||||
) -> dict[int, int]:
|
||||
if known_motors_only:
|
||||
ids = self.ids
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
ids = range(scs.MAX_ID + 1)
|
||||
|
||||
ids_models = {}
|
||||
motors_found = 0
|
||||
for id_ in ids:
|
||||
model_number = self.ping(id_, num_retry)
|
||||
if model_number is not None:
|
||||
ids_models[id_] = model_number
|
||||
motors_found += 1
|
||||
if motors_found >= n_motors:
|
||||
break
|
||||
|
||||
return ids_models
|
||||
|
||||
def _broadcast_ping_p0(self) -> tuple[dict[int, int], int]:
|
||||
import scservo_sdk as scs
|
||||
|
||||
data_list = {}
|
||||
|
||||
status_length = 6
|
||||
|
||||
rx_length = 0
|
||||
wait_length = status_length * scs.MAX_ID
|
||||
|
||||
txpacket = [0] * 6
|
||||
|
||||
tx_time_per_byte = (1000.0 / self.port_handler.getBaudRate()) * 10.0
|
||||
|
||||
txpacket[scs.PKT_ID] = scs.BROADCAST_ID
|
||||
txpacket[scs.PKT_LENGTH] = 2
|
||||
txpacket[scs.PKT_INSTRUCTION] = scs.INST_PING
|
||||
|
||||
result = self.packet_handler.txPacket(self.port_handler, txpacket)
|
||||
if result != scs.COMM_SUCCESS:
|
||||
self.port_handler.is_using = False
|
||||
return data_list, result
|
||||
|
||||
# set rx timeout
|
||||
self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0)
|
||||
|
||||
rxpacket = []
|
||||
while True:
|
||||
rxpacket += self.port_handler.readPort(wait_length - rx_length)
|
||||
rx_length = len(rxpacket)
|
||||
|
||||
if self.port_handler.isPacketTimeout(): # or rx_length >= wait_length
|
||||
break
|
||||
|
||||
self.port_handler.is_using = False
|
||||
|
||||
if rx_length == 0:
|
||||
return data_list, scs.COMM_RX_TIMEOUT
|
||||
|
||||
while True:
|
||||
if rx_length < status_length:
|
||||
return data_list, scs.COMM_RX_CORRUPT
|
||||
|
||||
# find packet header
|
||||
for idx in range(0, (rx_length - 1)):
|
||||
if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF):
|
||||
break
|
||||
|
||||
if idx == 0: # found at the beginning of the packet
|
||||
# calculate checksum
|
||||
checksum = 0
|
||||
for idx in range(2, status_length - 1): # except header & checksum
|
||||
checksum += rxpacket[idx]
|
||||
|
||||
checksum = ~checksum & 0xFF
|
||||
if rxpacket[status_length - 1] == checksum:
|
||||
result = scs.COMM_SUCCESS
|
||||
data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR]
|
||||
|
||||
del rxpacket[0:status_length]
|
||||
rx_length = rx_length - status_length
|
||||
|
||||
if rx_length == 0:
|
||||
return data_list, result
|
||||
else:
|
||||
result = scs.COMM_RX_CORRUPT
|
||||
# remove header (0xFF 0xFF)
|
||||
del rxpacket[0:2]
|
||||
rx_length = rx_length - 2
|
||||
else:
|
||||
# remove unnecessary packets
|
||||
del rxpacket[0:idx]
|
||||
rx_length = rx_length - idx
|
||||
|
||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||
self._assert_protocol_is_compatible("broadcast_ping")
|
||||
for n_try in range(1 + num_retry):
|
||||
ids_status, comm = self._broadcast_ping_p0()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})")
|
||||
logger.debug(self.packet_handler.getTxRxResult(comm))
|
||||
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
return
|
||||
|
||||
ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)}
|
||||
if ids_errors:
|
||||
display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()}
|
||||
logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}")
|
||||
|
||||
return self._read_model_number(list(ids_status), raise_on_error)
|
||||
|
||||
def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]:
|
||||
firmware_versions = {}
|
||||
for id_ in motor_ids:
|
||||
firm_ver_major, comm, error = self._read(
|
||||
*FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
firm_ver_minor, comm, error = self._read(
|
||||
*FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error
|
||||
)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}"
|
||||
|
||||
return firmware_versions
|
||||
|
||||
def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]:
|
||||
model_numbers = {}
|
||||
for id_ in motor_ids:
|
||||
model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error)
|
||||
if not self._is_comm_success(comm) or self._is_error(error):
|
||||
return
|
||||
|
||||
model_numbers[id_] = model_nb
|
||||
|
||||
return model_numbers
|
||||
202
lerobot/common/motors/feetech/tables.py
Normal file
202
lerobot/common/motors/feetech/tables.py
Normal file
@@ -0,0 +1,202 @@
|
||||
FIRMWARE_MAJOR_VERSION = (0, 1)
|
||||
FIRMWARE_MINOR_VERSION = (1, 1)
|
||||
MODEL_NUMBER = (3, 2)
|
||||
|
||||
# See this link for STS3215 Memory Table:
|
||||
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
||||
# data_name: (address, size_byte)
|
||||
STS_SMS_SERIES_CONTROL_TABLE = {
|
||||
# EPROM
|
||||
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
"Model_Number": MODEL_NUMBER, # read-only
|
||||
"ID": (5, 1),
|
||||
"Baud_Rate": (6, 1),
|
||||
"Return_Delay_Time": (7, 1),
|
||||
"Response_Status_Level": (8, 1),
|
||||
"Min_Position_Limit": (9, 2),
|
||||
"Max_Position_Limit": (11, 2),
|
||||
"Max_Temperature_Limit": (13, 1),
|
||||
"Max_Voltage_Limit": (14, 1),
|
||||
"Min_Voltage_Limit": (15, 1),
|
||||
"Max_Torque_Limit": (16, 2),
|
||||
"Phase": (18, 1),
|
||||
"Unloading_Condition": (19, 1),
|
||||
"LED_Alarm_Condition": (20, 1),
|
||||
"P_Coefficient": (21, 1),
|
||||
"D_Coefficient": (22, 1),
|
||||
"I_Coefficient": (23, 1),
|
||||
"Minimum_Startup_Force": (24, 2),
|
||||
"CW_Dead_Zone": (26, 1),
|
||||
"CCW_Dead_Zone": (27, 1),
|
||||
"Protection_Current": (28, 2),
|
||||
"Angular_Resolution": (30, 1),
|
||||
"Homing_Offset": (31, 2),
|
||||
"Operating_Mode": (33, 1),
|
||||
"Protective_Torque": (34, 1),
|
||||
"Protection_Time": (35, 1),
|
||||
"Overload_Torque": (36, 1),
|
||||
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
|
||||
"Over_Current_Protection_Time": (38, 1),
|
||||
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
|
||||
# SRAM
|
||||
"Torque_Enable": (40, 1),
|
||||
"Acceleration": (41, 1),
|
||||
"Goal_Position": (42, 2),
|
||||
"Goal_Time": (44, 2),
|
||||
"Goal_Speed": (46, 2),
|
||||
"Torque_Limit": (48, 2),
|
||||
"Lock": (55, 1),
|
||||
"Present_Position": (56, 2), # read-only
|
||||
"Present_Speed": (58, 2), # read-only
|
||||
"Present_Load": (60, 2), # read-only
|
||||
"Present_Voltage": (62, 1), # read-only
|
||||
"Present_Temperature": (63, 1), # read-only
|
||||
"Status": (65, 1), # read-only
|
||||
"Moving": (66, 1), # read-only
|
||||
"Present_Current": (69, 2), # read-only
|
||||
# Not in the Memory Table
|
||||
"Maximum_Acceleration": (85, 2),
|
||||
}
|
||||
|
||||
SCS_SERIES_CONTROL_TABLE = {
|
||||
# EPROM
|
||||
"Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only
|
||||
"Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only
|
||||
"Model_Number": MODEL_NUMBER, # read-only
|
||||
"ID": (5, 1),
|
||||
"Baud_Rate": (6, 1),
|
||||
"Return_Delay": (7, 1),
|
||||
"Response_Status_Level": (8, 1),
|
||||
"Min_Position_Limit": (9, 2),
|
||||
"Max_Position_Limit": (11, 2),
|
||||
"Max_Temperature_Limit": (13, 1),
|
||||
"Max_Voltage_Limit": (14, 1),
|
||||
"Min_Voltage_Limit": (15, 1),
|
||||
"Max_Torque_Limit": (16, 2),
|
||||
"Phase": (18, 1),
|
||||
"Unloading_Condition": (19, 1),
|
||||
"LED_Alarm_Condition": (20, 1),
|
||||
"P_Coefficient": (21, 1),
|
||||
"D_Coefficient": (22, 1),
|
||||
"I_Coefficient": (23, 1),
|
||||
"Minimum_Startup_Force": (24, 2),
|
||||
"CW_Dead_Zone": (26, 1),
|
||||
"CCW_Dead_Zone": (27, 1),
|
||||
"Protective_Torque": (37, 1),
|
||||
"Protection_Time": (38, 1),
|
||||
# SRAM
|
||||
"Torque_Enable": (40, 1),
|
||||
"Acceleration": (41, 1),
|
||||
"Goal_Position": (42, 2),
|
||||
"Running_Time": (44, 2),
|
||||
"Goal_Speed": (46, 2),
|
||||
"Lock": (48, 1),
|
||||
"Present_Position": (56, 2), # read-only
|
||||
"Present_Speed": (58, 2), # read-only
|
||||
"Present_Load": (60, 2), # read-only
|
||||
"Present_Voltage": (62, 1), # read-only
|
||||
"Present_Temperature": (63, 1), # read-only
|
||||
"Sync_Write_Flag": (64, 1), # read-only
|
||||
"Status": (65, 1), # read-only
|
||||
"Moving": (66, 1), # read-only
|
||||
}
|
||||
|
||||
STS_SMS_SERIES_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
3: 128_000,
|
||||
4: 115_200,
|
||||
5: 57_600,
|
||||
6: 38_400,
|
||||
7: 19_200,
|
||||
}
|
||||
|
||||
SCS_SERIES_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
3: 128_000,
|
||||
4: 115_200,
|
||||
5: 57_600,
|
||||
6: 38_400,
|
||||
7: 19_200,
|
||||
}
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"sts_series": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
||||
"sms_series": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"sts3215": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"sts3250": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
"scs0009": SCS_SERIES_CONTROL_TABLE,
|
||||
"sm8512bl": STS_SMS_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"sts_series": 4096,
|
||||
"sms_series": 4096,
|
||||
"scs_series": 1024,
|
||||
"sts3215": 4096,
|
||||
"sts3250": 4096,
|
||||
"sm8512bl": 65536,
|
||||
"scs0009": 1024,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"sts_series": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"sms_series": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
|
||||
"sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"sts3215": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"sts3250": STS_SMS_SERIES_BAUDRATE_TABLE,
|
||||
"scs0009": SCS_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
# Sign-Magnitude encoding bits
|
||||
STS_SMS_SERIES_ENCODINGS_TABLE = {
|
||||
"Homing_Offset": 11,
|
||||
"Goal_Speed": 15,
|
||||
}
|
||||
|
||||
MODEL_ENCODING_TABLE = {
|
||||
"sts_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"sms_series": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"scs_series": {},
|
||||
"sts3215": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"sts3250": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE,
|
||||
"scs0009": {},
|
||||
}
|
||||
|
||||
SCAN_BAUDRATES = [
|
||||
4_800,
|
||||
9_600,
|
||||
14_400,
|
||||
19_200,
|
||||
38_400,
|
||||
57_600,
|
||||
115_200,
|
||||
128_000,
|
||||
250_000,
|
||||
500_000,
|
||||
1_000_000,
|
||||
]
|
||||
|
||||
MODEL_NUMBER_TABLE = {
|
||||
"sts3215": 777,
|
||||
"sts3250": 2825,
|
||||
"sm8512bl": 11272,
|
||||
"scs0009": 1284,
|
||||
}
|
||||
|
||||
MODEL_PROTOCOL = {
|
||||
"sts_series": 0,
|
||||
"sms_series": 0,
|
||||
"scs_series": 1,
|
||||
"sts3215": 0,
|
||||
"sts3250": 0,
|
||||
"sm8512bl": 0,
|
||||
"scs0009": 1,
|
||||
}
|
||||
987
lerobot/common/motors/motors_bus.py
Normal file
987
lerobot/common/motors/motors_bus.py
Normal file
@@ -0,0 +1,987 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# ruff: noqa: N802
|
||||
# This noqa is for the Protocols classes: PortHandler, PacketHandler GroupSyncRead/Write
|
||||
# TODO(aliberts): Add block noqa when feature below is available
|
||||
# https://github.com/astral-sh/ruff/issues/3711
|
||||
|
||||
import abc
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pprint import pformat
|
||||
from typing import Protocol, TypeAlias
|
||||
|
||||
import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
Value: TypeAlias = int | float
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]:
|
||||
ctrl_table = model_ctrl_table.get(model)
|
||||
if ctrl_table is None:
|
||||
raise KeyError(f"Control table for {model=} not found.")
|
||||
return ctrl_table
|
||||
|
||||
|
||||
def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]:
|
||||
ctrl_table = get_ctrl_table(model_ctrl_table, model)
|
||||
addr_bytes = ctrl_table.get(data_name)
|
||||
if addr_bytes is None:
|
||||
raise KeyError(f"Address for '{data_name}' not found in {model} control table.")
|
||||
return addr_bytes
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None:
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = get_address(model_ctrl_table, model, data_name)
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}'"
|
||||
f"({list(zip(motor_models, all_addr, strict=False))})."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}'"
|
||||
f"({list(zip(motor_models, all_bytes, strict=False))})."
|
||||
)
|
||||
|
||||
|
||||
class MotorNormMode(Enum):
|
||||
DEGREE = 0
|
||||
RANGE_0_100 = 1
|
||||
RANGE_M100_100 = 2
|
||||
VELOCITY = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorCalibration:
|
||||
id: int
|
||||
drive_mode: int
|
||||
homing_offset: int
|
||||
range_min: int
|
||||
range_max: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Motor:
|
||||
id: int
|
||||
model: str
|
||||
norm_mode: MotorNormMode
|
||||
|
||||
|
||||
class JointOutOfRangeError(Exception):
|
||||
def __init__(self, message="Joint is out of range"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class PortHandler(Protocol):
|
||||
def __init__(self, port_name):
|
||||
self.is_open: bool
|
||||
self.baudrate: int
|
||||
self.packet_start_time: float
|
||||
self.packet_timeout: float
|
||||
self.tx_time_per_byte: float
|
||||
self.is_using: bool
|
||||
self.port_name: str
|
||||
self.ser: serial.Serial
|
||||
|
||||
def openPort(self): ...
|
||||
def closePort(self): ...
|
||||
def clearPort(self): ...
|
||||
def setPortName(self, port_name): ...
|
||||
def getPortName(self): ...
|
||||
def setBaudRate(self, baudrate): ...
|
||||
def getBaudRate(self): ...
|
||||
def getBytesAvailable(self): ...
|
||||
def readPort(self, length): ...
|
||||
def writePort(self, packet): ...
|
||||
def setPacketTimeout(self, packet_length): ...
|
||||
def setPacketTimeoutMillis(self, msec): ...
|
||||
def isPacketTimeout(self): ...
|
||||
def getCurrentTime(self): ...
|
||||
def getTimeSinceStart(self): ...
|
||||
def setupPort(self, cflag_baud): ...
|
||||
def getCFlagBaud(self, baudrate): ...
|
||||
|
||||
|
||||
class PacketHandler(Protocol):
|
||||
def getTxRxResult(self, result): ...
|
||||
def getRxPacketError(self, error): ...
|
||||
def txPacket(self, port, txpacket): ...
|
||||
def rxPacket(self, port): ...
|
||||
def txRxPacket(self, port, txpacket): ...
|
||||
def ping(self, port, id): ...
|
||||
def action(self, port, id): ...
|
||||
def readTx(self, port, id, address, length): ...
|
||||
def readRx(self, port, id, length): ...
|
||||
def readTxRx(self, port, id, address, length): ...
|
||||
def read1ByteTx(self, port, id, address): ...
|
||||
def read1ByteRx(self, port, id): ...
|
||||
def read1ByteTxRx(self, port, id, address): ...
|
||||
def read2ByteTx(self, port, id, address): ...
|
||||
def read2ByteRx(self, port, id): ...
|
||||
def read2ByteTxRx(self, port, id, address): ...
|
||||
def read4ByteTx(self, port, id, address): ...
|
||||
def read4ByteRx(self, port, id): ...
|
||||
def read4ByteTxRx(self, port, id, address): ...
|
||||
def writeTxOnly(self, port, id, address, length, data): ...
|
||||
def writeTxRx(self, port, id, address, length, data): ...
|
||||
def write1ByteTxOnly(self, port, id, address, data): ...
|
||||
def write1ByteTxRx(self, port, id, address, data): ...
|
||||
def write2ByteTxOnly(self, port, id, address, data): ...
|
||||
def write2ByteTxRx(self, port, id, address, data): ...
|
||||
def write4ByteTxOnly(self, port, id, address, data): ...
|
||||
def write4ByteTxRx(self, port, id, address, data): ...
|
||||
def regWriteTxOnly(self, port, id, address, length, data): ...
|
||||
def regWriteTxRx(self, port, id, address, length, data): ...
|
||||
def syncReadTx(self, port, start_address, data_length, param, param_length): ...
|
||||
def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ...
|
||||
|
||||
|
||||
class GroupSyncRead(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.last_result: bool
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id): ...
|
||||
def removeParam(self, id): ...
|
||||
def clearParam(self): ...
|
||||
def txPacket(self): ...
|
||||
def rxPacket(self): ...
|
||||
def txRxPacket(self): ...
|
||||
def isAvailable(self, id, address, data_length): ...
|
||||
def getData(self, id, address, data_length): ...
|
||||
|
||||
|
||||
class GroupSyncWrite(Protocol):
|
||||
def __init__(self, port, ph, start_address, data_length):
|
||||
self.port: str
|
||||
self.ph: PortHandler
|
||||
self.start_address: int
|
||||
self.data_length: int
|
||||
self.is_param_changed: bool
|
||||
self.param: list
|
||||
self.data_dict: dict
|
||||
|
||||
def makeParam(self): ...
|
||||
def addParam(self, id, data): ...
|
||||
def removeParam(self, id): ...
|
||||
def changeParam(self, id, data): ...
|
||||
def clearParam(self): ...
|
||||
def txPacket(self): ...
|
||||
|
||||
|
||||
class MotorsBus(abc.ABC):
|
||||
"""The main LeRobot class for implementing motors buses.
|
||||
|
||||
There are currently two implementations of this abstract class:
|
||||
- DynamixelMotorsBus
|
||||
- FeetechMotorsBus
|
||||
|
||||
Note: This class may evolve in the future should we add support for other manufacturers SDKs.
|
||||
|
||||
A MotorsBus allows to efficiently read and write to the attached motors.
|
||||
It represents several motors daisy-chained together and connected through a serial port.
|
||||
|
||||
A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
>>> Finding all available ports for the MotorsBus.
|
||||
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
>>> Remove the usb cable from your MotorsBus and press Enter when done.
|
||||
>>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example of usage for 1 Feetech sts3215 motor connected to the bus:
|
||||
```python
|
||||
motors_bus = FeetechMotorsBus(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={"gripper": (6, "sts3215")},
|
||||
)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# Move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# When done, properly disconnect the port using
|
||||
motors_bus.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
available_baudrates: list[int]
|
||||
default_timeout: int
|
||||
model_baudrate_table: dict[str, dict]
|
||||
model_ctrl_table: dict[str, dict]
|
||||
model_encoding_table: dict[str, dict]
|
||||
model_number_table: dict[str, int]
|
||||
model_resolution_table: dict[str, int]
|
||||
normalized_data: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: str,
|
||||
motors: dict[str, Motor],
|
||||
calibration: dict[str, MotorCalibration] | None = None,
|
||||
):
|
||||
self.port = port
|
||||
self.motors = motors
|
||||
self.calibration = calibration if calibration else {}
|
||||
|
||||
self.port_handler: PortHandler
|
||||
self.packet_handler: PacketHandler
|
||||
self.sync_reader: GroupSyncRead
|
||||
self.sync_writer: GroupSyncWrite
|
||||
self._comm_success: int
|
||||
self._no_error: int
|
||||
|
||||
self._id_to_model_dict = {m.id: m.model for m in self.motors.values()}
|
||||
self._id_to_name_dict = {m.id: name for name, m in self.motors.items()}
|
||||
self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()}
|
||||
|
||||
self._validate_motors()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.motors)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Port: '{self.port}',\n"
|
||||
f" Motors: \n{pformat(self.motors, indent=8, sort_dicts=False)},\n"
|
||||
")',\n"
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _has_different_ctrl_tables(self) -> bool:
|
||||
if len(self.models) < 2:
|
||||
return False
|
||||
|
||||
first_table = self.model_ctrl_table[self.models[0]]
|
||||
return any(
|
||||
DeepDiff(first_table, get_ctrl_table(self.model_ctrl_table, model)) for model in self.models[1:]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def names(self) -> list[str]:
|
||||
return list(self.motors)
|
||||
|
||||
@cached_property
|
||||
def models(self) -> list[str]:
|
||||
return [m.model for m in self.motors.values()]
|
||||
|
||||
@cached_property
|
||||
def ids(self) -> list[int]:
|
||||
return [m.id for m in self.motors.values()]
|
||||
|
||||
def _model_nb_to_model(self, motor_nb: int) -> str:
|
||||
return self._model_nb_to_model_dict[motor_nb]
|
||||
|
||||
def _id_to_model(self, motor_id: int) -> str:
|
||||
return self._id_to_model_dict[motor_id]
|
||||
|
||||
def _id_to_name(self, motor_id: int) -> str:
|
||||
return self._id_to_name_dict[motor_id]
|
||||
|
||||
def _get_motor_id(self, motor: NameOrID) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].id
|
||||
elif isinstance(motor, int):
|
||||
return motor
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motor_model(self, motor: NameOrID) -> int:
|
||||
if isinstance(motor, str):
|
||||
return self.motors[motor].model
|
||||
elif isinstance(motor, int):
|
||||
return self._id_to_model_dict[motor]
|
||||
else:
|
||||
raise TypeError(f"'{motor}' should be int, str.")
|
||||
|
||||
def _get_motors_list(self, motors: str | list[str] | None) -> list[str]:
|
||||
if motors is None:
|
||||
return self.names
|
||||
elif isinstance(motors, str):
|
||||
return [motors]
|
||||
elif isinstance(motors, list):
|
||||
return motors.copy()
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int, float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||
else:
|
||||
raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}")
|
||||
|
||||
def _validate_motors(self) -> None:
|
||||
if len(self.ids) != len(set(self.ids)):
|
||||
raise ValueError(f"Some motors have the same id!\n{self}")
|
||||
|
||||
# Ensure ctrl table available for all models
|
||||
for model in self.models:
|
||||
get_ctrl_table(self.model_ctrl_table, model)
|
||||
|
||||
def _is_comm_success(self, comm: int) -> bool:
|
||||
return comm == self._comm_success
|
||||
|
||||
def _is_error(self, error: int) -> bool:
|
||||
return error != self._no_error
|
||||
|
||||
def _assert_motors_exist(self) -> None:
|
||||
# TODO(aliberts): collect all wrong ids/models and display them at once
|
||||
found_models = {}
|
||||
for id_ in self.ids:
|
||||
model_nb = self.ping(id_)
|
||||
if model_nb is not None:
|
||||
found_models[id_] = model_nb
|
||||
expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()}
|
||||
if set(found_models) != set(self.ids):
|
||||
raise RuntimeError(
|
||||
f"{self.__class__.__name__} is supposed to have these motors: ({{id: model_nb}})"
|
||||
f"\n{pformat(expected_models, indent=4, sort_dicts=False)}\n"
|
||||
f"But it found these motors on port '{self.port}':"
|
||||
f"\n{pformat(found_models, indent=4, sort_dicts=False)}\n"
|
||||
)
|
||||
|
||||
for id_, model in expected_models.items():
|
||||
if found_models[id_] != model:
|
||||
raise RuntimeError(
|
||||
f"Motor '{self._id_to_name(id_)}' (id={id_}) is supposed to be of model_number={model} "
|
||||
f"('{self._id_to_model(id_)}') but a model_number={found_models[id_]} "
|
||||
"was found instead for that id."
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _assert_protocol_is_compatible(self, instruction_name: str) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.port_handler.is_open
|
||||
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
elif handshake:
|
||||
self._handshake()
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
||||
) from e
|
||||
|
||||
self.set_timeout()
|
||||
logger.debug(f"{self.__class__.__name__} connected.")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _handshake(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]:
|
||||
bus = cls(port, {}, *args, **kwargs)
|
||||
try:
|
||||
bus.port_handler.openPort()
|
||||
except (FileNotFoundError, OSError, serial.SerialException) as e:
|
||||
raise ConnectionError(
|
||||
f"Could not connect to port '{port}'. Make sure you are using the correct port."
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py`\n"
|
||||
) from e
|
||||
baudrate_ids = {}
|
||||
for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"):
|
||||
bus.set_baudrate(baudrate)
|
||||
ids_models = bus.broadcast_ping()
|
||||
if ids_models:
|
||||
tqdm.write(f"Motors found for {baudrate=}: {pformat(ids_models, indent=4)}")
|
||||
baudrate_ids[baudrate] = list(ids_models)
|
||||
|
||||
return baudrate_ids
|
||||
|
||||
@abc.abstractmethod
|
||||
def configure_motors(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def torque_disabled(self):
|
||||
self.disable_torque()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.enable_torque()
|
||||
|
||||
def set_timeout(self, timeout_ms: int | None = None):
|
||||
timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout
|
||||
self.port_handler.setPacketTimeoutMillis(timeout_ms)
|
||||
|
||||
def get_baudrate(self) -> int:
|
||||
return self.port_handler.getBaudRate()
|
||||
|
||||
def set_baudrate(self, baudrate: int) -> None:
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.calibration == self.read_calibration()
|
||||
|
||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||
offsets = self.sync_read("Homing_Offset", normalize=False)
|
||||
mins = self.sync_read("Min_Position_Limit", normalize=False)
|
||||
maxes = self.sync_read("Max_Position_Limit", normalize=False)
|
||||
|
||||
try:
|
||||
drive_modes = self.sync_read("Drive_Mode", normalize=False)
|
||||
except KeyError:
|
||||
drive_modes = dict.fromkeys(self.names, 0)
|
||||
|
||||
calibration = {}
|
||||
for name, motor in self.motors.items():
|
||||
calibration[name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_modes[name],
|
||||
homing_offset=offsets[name],
|
||||
range_min=mins[name],
|
||||
range_max=maxes[name],
|
||||
)
|
||||
|
||||
return calibration
|
||||
|
||||
def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None:
|
||||
for motor, calibration in calibration_dict.items():
|
||||
self.write("Homing_Offset", motor, calibration.homing_offset)
|
||||
self.write("Min_Position_Limit", motor, calibration.range_min)
|
||||
self.write("Max_Position_Limit", motor, calibration.range_max)
|
||||
|
||||
self.calibration = calibration_dict
|
||||
|
||||
def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None:
|
||||
if motors is None:
|
||||
motors = self.names
|
||||
elif isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
for motor in motors:
|
||||
model = self._get_motor_model(motor)
|
||||
max_res = self.model_resolution_table[model] - 1
|
||||
self.write("Homing_Offset", motor, 0, normalize=False)
|
||||
self.write("Min_Position_Limit", motor, 0, normalize=False)
|
||||
self.write("Max_Position_Limit", motor, max_res, normalize=False)
|
||||
|
||||
self.calibration = {}
|
||||
|
||||
def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]:
|
||||
"""
|
||||
This assumes motors present positions are roughly in the middle of their desired range
|
||||
|
||||
Step 1: Set homing and min max to 0
|
||||
|
||||
Step 2: Read Present_Position which will be Actual_Position since
|
||||
Present_Position = Actual_Position ± Homing_Offset (1)
|
||||
and Homing_Offset = 0 from step 1
|
||||
|
||||
Step 3: We want to set the Homing_Offset such that the current Present_Position to be half range of 1
|
||||
revolution. For instance, if 1 revolution corresponds to 4095 (4096 steps), this means we want the
|
||||
current Present_Position to be 2047.
|
||||
|
||||
In that example:
|
||||
Present_Position = 2047 (2)
|
||||
Actual_Position = X (read in step 2)
|
||||
from (1) and (2):
|
||||
=> Homing_Offset = ±(X - 2048)
|
||||
"""
|
||||
if motors is None:
|
||||
motors = self.names
|
||||
elif isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
else:
|
||||
raise TypeError(motors)
|
||||
|
||||
self.reset_calibration(motors)
|
||||
actual_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
homing_offsets = self._get_half_turn_homings(actual_positions)
|
||||
for motor, offset in homing_offsets.items():
|
||||
self.write("Homing_Offset", motor, offset)
|
||||
|
||||
return homing_offsets
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]:
|
||||
pass
|
||||
|
||||
def record_ranges_of_motion(
|
||||
self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True
|
||||
) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]:
|
||||
"""
|
||||
This assumes that the homing offsets have been set such that all possible values in the range of
|
||||
motion are positive and that the zero is not crossed. To that end, `set_half_turn_homings` should
|
||||
typically be called prior to this.
|
||||
"""
|
||||
if motors is None:
|
||||
motors = self.names
|
||||
elif isinstance(motors, (str, int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
start_positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = start_positions.copy()
|
||||
maxes = start_positions.copy()
|
||||
while True:
|
||||
positions = self.sync_read("Present_Position", motors, normalize=False)
|
||||
mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()}
|
||||
maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()}
|
||||
|
||||
if display_values:
|
||||
print("\n-------------------------------------------")
|
||||
print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}")
|
||||
for name in motors:
|
||||
print(f"{name:<15} | {mins[name]:>6} | {positions[name]:>6} | {maxes[name]:>6}")
|
||||
|
||||
if enter_pressed():
|
||||
break
|
||||
|
||||
if display_values:
|
||||
# Move cursor up to overwrite the previous output
|
||||
move_cursor_up(len(motors) + 3)
|
||||
|
||||
return mins, maxes
|
||||
|
||||
def _normalize(self, data_name: str, ids_values: dict[int, int]) -> dict[int, float]:
|
||||
if not self.calibration:
|
||||
raise RuntimeError(f"{self} has no calibration registered.")
|
||||
|
||||
normalized_values = {}
|
||||
for id_, val in ids_values.items():
|
||||
name = self._id_to_name(id_)
|
||||
min_ = self.calibration[name].range_min
|
||||
max_ = self.calibration[name].range_max
|
||||
bounded_val = min(max_, max(min_, val))
|
||||
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
|
||||
normalized_values[id_] = (((bounded_val - min_) / (max_ - min_)) * 200) - 100
|
||||
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
|
||||
normalized_values[id_] = ((bounded_val - min_) / (max_ - min_)) * 100
|
||||
else:
|
||||
# TODO(alibers): velocity and degree modes
|
||||
raise NotImplementedError
|
||||
|
||||
return normalized_values
|
||||
|
||||
def _unnormalize(self, data_name: str, ids_values: dict[int, float]) -> dict[int, int]:
|
||||
if not self.calibration:
|
||||
raise RuntimeError(f"{self} has no calibration registered.")
|
||||
|
||||
unnormalized_values = {}
|
||||
for id_, val in ids_values.items():
|
||||
name = self._id_to_name(id_)
|
||||
min_ = self.calibration[name].range_min
|
||||
max_ = self.calibration[name].range_max
|
||||
if self.motors[name].norm_mode is MotorNormMode.RANGE_M100_100:
|
||||
bounded_val = min(100.0, max(-100.0, val))
|
||||
unnormalized_values[id_] = int(((bounded_val + 100) / 200) * (max_ - min_) + min_)
|
||||
elif self.motors[name].norm_mode is MotorNormMode.RANGE_0_100:
|
||||
bounded_val = min(100.0, max(0.0, val))
|
||||
unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_)
|
||||
else:
|
||||
# TODO(alibers): velocity and degree modes
|
||||
raise NotImplementedError
|
||||
|
||||
return unnormalized_values
|
||||
|
||||
@abc.abstractmethod
|
||||
def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]:
|
||||
pass
|
||||
|
||||
def _serialize_data(self, value: int, length: int) -> list[int]:
|
||||
"""
|
||||
Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication
|
||||
protocol. Depending on the protocol, split values can be in big-endian or little-endian order.
|
||||
|
||||
Supported data length for both Feetech and Dynamixel:
|
||||
- 1 (for values 0 to 255)
|
||||
- 2 (for values 0 to 65,535)
|
||||
- 4 (for values 0 to 4,294,967,295)
|
||||
"""
|
||||
if value < 0:
|
||||
raise ValueError(f"Negative values are not allowed: {value}")
|
||||
|
||||
max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length)
|
||||
if max_value is None:
|
||||
raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].")
|
||||
|
||||
if value > max_value:
|
||||
raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).")
|
||||
|
||||
return self._split_into_byte_chunks(value, length)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _split_into_byte_chunks(self, value: int, length: int) -> list[int]:
|
||||
"""Convert an integer into a list of byte-sized integers."""
|
||||
pass
|
||||
|
||||
def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None:
|
||||
id_ = self._get_motor_id(motor)
|
||||
for n_try in range(1 + num_retry):
|
||||
model_number, comm, error = self.packet_handler.ping(self.port_handler, id_)
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(f"ping failed for {id_=}: {n_try=} got {comm=} {error=}")
|
||||
|
||||
if not self._is_comm_success(comm):
|
||||
if raise_on_error:
|
||||
raise ConnectionError(self.packet_handler.getTxRxResult(comm))
|
||||
else:
|
||||
return
|
||||
if self._is_error(error):
|
||||
if raise_on_error:
|
||||
raise RuntimeError(self.packet_handler.getRxPacketError(error))
|
||||
else:
|
||||
return
|
||||
|
||||
return model_number
|
||||
|
||||
@abc.abstractmethod
|
||||
def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None:
|
||||
pass
|
||||
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
motor: str,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> Value:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries."
|
||||
value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
id_value = self._decode_sign(data_name, {id_: value})
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
id_value = self._normalize(data_name, id_value)
|
||||
|
||||
return id_value[id_]
|
||||
|
||||
def _read(
|
||||
self,
|
||||
address: int,
|
||||
length: int,
|
||||
motor_id: int,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
if length == 1:
|
||||
read_fn = self.packet_handler.read1ByteTxRx
|
||||
elif length == 2:
|
||||
read_fn = self.packet_handler.read2ByteTxRx
|
||||
elif length == 4:
|
||||
read_fn = self.packet_handler.read4ByteTxRx
|
||||
else:
|
||||
raise ValueError(length)
|
||||
|
||||
for n_try in range(1 + num_retry):
|
||||
value, comm, error = read_fn(self.port_handler, motor_id, address)
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(
|
||||
f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): "
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
elif self._is_error(error) and raise_on_error:
|
||||
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||
|
||||
return value, comm, error
|
||||
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
value = self._unnormalize(data_name, {id_: value})[id_]
|
||||
|
||||
value = self._encode_sign(data_name, {id_: value})[id_]
|
||||
|
||||
err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries."
|
||||
self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _write(
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
motor_id: int,
|
||||
value: int,
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[int, int]:
|
||||
data = self._serialize_data(value, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data)
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(
|
||||
f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): "
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
elif self._is_error(error) and raise_on_error:
|
||||
raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}")
|
||||
|
||||
return comm, error
|
||||
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
motors: str | list[str] | None = None,
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> dict[str, Value]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
names = self._get_motors_list(motors)
|
||||
ids = [self.motors[name].id for name in names]
|
||||
models = [self.motors[name].model for name in names]
|
||||
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries."
|
||||
ids_values, _ = self._sync_read(
|
||||
addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg
|
||||
)
|
||||
|
||||
ids_values = self._decode_sign(data_name, ids_values)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._normalize(data_name, ids_values)
|
||||
|
||||
return {self._id_to_name(id_): value for id_, value in ids_values.items()}
|
||||
|
||||
def _sync_read(
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
motor_ids: list[int],
|
||||
*,
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> tuple[dict[int, int], int]:
|
||||
self._setup_sync_reader(motor_ids, addr, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.sync_reader.txRxPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(
|
||||
f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): "
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
|
||||
values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids}
|
||||
return values, comm
|
||||
|
||||
def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None:
|
||||
self.sync_reader.clearParam()
|
||||
self.sync_reader.start_address = addr
|
||||
self.sync_reader.data_length = length
|
||||
for id_ in motor_ids:
|
||||
self.sync_reader.addParam(id_)
|
||||
|
||||
# TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be.
|
||||
# Would have to handle the logic of checking if a packet has been sent previously though but doable.
|
||||
# This could be at the cost of increase latency between the moment the data is produced by the motors and
|
||||
# the moment it is used by a policy.
|
||||
# def _async_read(self, motor_ids: list[int], address: int, length: int):
|
||||
# if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...:
|
||||
# self._setup_sync_reader(motor_ids, address, length)
|
||||
# else:
|
||||
# self.sync_reader.rxPacket()
|
||||
# self.sync_reader.txPacket()
|
||||
|
||||
# for id_ in motor_ids:
|
||||
# value = self.sync_reader.getData(id_, address, length)
|
||||
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
values: Value | dict[str, Value],
|
||||
*,
|
||||
normalize: bool = True,
|
||||
num_retry: int = 0,
|
||||
) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
if self._has_different_ctrl_tables:
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
|
||||
model = next(iter(models))
|
||||
addr, length = get_address(self.model_ctrl_table, model, data_name)
|
||||
|
||||
if normalize and data_name in self.normalized_data:
|
||||
ids_values = self._unnormalize(data_name, ids_values)
|
||||
|
||||
ids_values = self._encode_sign(data_name, ids_values)
|
||||
|
||||
err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries."
|
||||
self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg)
|
||||
|
||||
def _sync_write(
|
||||
self,
|
||||
addr: int,
|
||||
length: int,
|
||||
ids_values: dict[int, int],
|
||||
num_retry: int = 0,
|
||||
raise_on_error: bool = True,
|
||||
err_msg: str = "",
|
||||
) -> int:
|
||||
self._setup_sync_writer(ids_values, addr, length)
|
||||
for n_try in range(1 + num_retry):
|
||||
comm = self.sync_writer.txPacket()
|
||||
if self._is_comm_success(comm):
|
||||
break
|
||||
logger.debug(
|
||||
f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): "
|
||||
+ self.packet_handler.getTxRxResult(comm)
|
||||
)
|
||||
|
||||
if not self._is_comm_success(comm) and raise_on_error:
|
||||
raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}")
|
||||
|
||||
return comm
|
||||
|
||||
def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None:
|
||||
self.sync_writer.clearParam()
|
||||
self.sync_writer.start_address = addr
|
||||
self.sync_writer.data_length = length
|
||||
for id_, value in ids_values.items():
|
||||
data = self._serialize_data(value, length)
|
||||
self.sync_writer.addParam(id_, data)
|
||||
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
self.port_handler.clearPort()
|
||||
self.port_handler.is_using = False
|
||||
self.disable_torque(num_retry=5)
|
||||
|
||||
self.port_handler.closePort()
|
||||
logger.debug(f"{self.__class__.__name__} disconnected.")
|
||||
@@ -12,37 +12,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
)
|
||||
from .configs import MotorsBusConfig
|
||||
from .motors_bus import MotorsBus
|
||||
|
||||
|
||||
class MotorsBus(Protocol):
|
||||
def motor_names(self): ...
|
||||
def set_calibration(self): ...
|
||||
def apply_calibration(self): ...
|
||||
def revert_calibration(self): ...
|
||||
def read(self): ...
|
||||
def write(self): ...
|
||||
|
||||
|
||||
def make_motors_buses_from_configs(
|
||||
motors_bus_configs: dict[str, MotorsBusConfig],
|
||||
) -> list[MotorsBus]:
|
||||
def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]:
|
||||
motors_buses = {}
|
||||
|
||||
for key, cfg in motors_bus_configs.items():
|
||||
if cfg.type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
motors_buses[key] = DynamixelMotorsBus(cfg)
|
||||
|
||||
elif cfg.type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus
|
||||
|
||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||
|
||||
@@ -54,13 +38,16 @@ def make_motors_buses_from_configs(
|
||||
|
||||
def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
||||
if motor_type == "dynamixel":
|
||||
from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus
|
||||
from .configs import DynamixelMotorsBusConfig
|
||||
from .dynamixel import DynamixelMotorsBus
|
||||
|
||||
config = DynamixelMotorsBusConfig(**kwargs)
|
||||
return DynamixelMotorsBus(config)
|
||||
|
||||
elif motor_type == "feetech":
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
from feetech import FeetechMotorsBus
|
||||
|
||||
from .configs import FeetechMotorsBusConfig
|
||||
|
||||
config = FeetechMotorsBusConfig(**kwargs)
|
||||
return FeetechMotorsBus(config)
|
||||
@@ -14,9 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
@@ -45,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -95,76 +94,7 @@ class SGDConfig(OptimizerConfig):
|
||||
return torch.optim.SGD(params, **kwargs)
|
||||
|
||||
|
||||
@OptimizerConfig.register_subclass("multi_adam")
|
||||
@dataclass
|
||||
class MultiAdamConfig(OptimizerConfig):
|
||||
"""Configuration for multiple Adam optimizers with different parameter groups.
|
||||
|
||||
This creates a dictionary of Adam optimizers, each with its own hyperparameters.
|
||||
|
||||
Args:
|
||||
lr: Default learning rate (used if not specified for a group)
|
||||
weight_decay: Default weight decay (used if not specified for a group)
|
||||
optimizer_groups: Dictionary mapping parameter group names to their hyperparameters
|
||||
grad_clip_norm: Gradient clipping norm
|
||||
"""
|
||||
|
||||
lr: float = 1e-3
|
||||
weight_decay: float = 0.0
|
||||
grad_clip_norm: float = 10.0
|
||||
optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)
|
||||
|
||||
def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Build multiple Adam optimizers.
|
||||
|
||||
Args:
|
||||
params_dict: Dictionary mapping parameter group names to lists of parameters
|
||||
The keys should match the keys in optimizer_groups
|
||||
|
||||
Returns:
|
||||
Dictionary mapping parameter group names to their optimizers
|
||||
"""
|
||||
optimizers = {}
|
||||
|
||||
for name, params in params_dict.items():
|
||||
# Get group-specific hyperparameters or use defaults
|
||||
group_config = self.optimizer_groups.get(name, {})
|
||||
|
||||
# Create optimizer with merged parameters (defaults + group-specific)
|
||||
optimizer_kwargs = {
|
||||
"lr": group_config.get("lr", self.lr),
|
||||
"betas": group_config.get("betas", (0.9, 0.999)),
|
||||
"eps": group_config.get("eps", 1e-5),
|
||||
"weight_decay": group_config.get("weight_decay", self.weight_decay),
|
||||
}
|
||||
|
||||
optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
state = optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
@@ -172,44 +102,11 @@ def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
||||
write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)
|
||||
|
||||
|
||||
def load_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
|
||||
"""Load optimizer state from disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to load the optimizer state from.
|
||||
|
||||
Returns:
|
||||
The updated optimizer(s) with loaded state.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
loaded_optimizers = {}
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
if optimizer_dir.exists():
|
||||
loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
loaded_optimizers[name] = opt
|
||||
return loaded_optimizers
|
||||
else:
|
||||
# Handle single optimizer
|
||||
return _load_single_optimizer_state(optimizer, save_dir)
|
||||
|
||||
|
||||
def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
"""Load a single optimizer's state from disk."""
|
||||
def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
|
||||
current_state_dict = optimizer.state_dict()
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
|
||||
# Handle case where 'state' key might not exist (for newly created optimizers)
|
||||
if "state" in state:
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
else:
|
||||
loaded_state_dict = {"state": {}}
|
||||
loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
|
||||
|
||||
if "param_groups" in current_state_dict:
|
||||
param_groups = deserialize_json_into_object(
|
||||
|
||||
@@ -49,11 +49,7 @@ class DiffuserSchedulerConfig(LRSchedulerConfig):
|
||||
def build(self, optimizer: Optimizer, num_training_steps: int) -> LambdaLR:
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
kwargs = {
|
||||
**asdict(self),
|
||||
"num_training_steps": num_training_steps,
|
||||
"optimizer": optimizer,
|
||||
}
|
||||
kwargs = {**asdict(self), "num_training_steps": num_training_steps, "optimizer": optimizer}
|
||||
return get_scheduler(**kwargs)
|
||||
|
||||
|
||||
@@ -75,10 +71,7 @@ class VQBeTSchedulerConfig(LRSchedulerConfig):
|
||||
progress = float(adjusted_step - self.num_warmup_steps) / float(
|
||||
max(1, num_training_steps - self.num_warmup_steps)
|
||||
)
|
||||
return max(
|
||||
0.0,
|
||||
0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)),
|
||||
)
|
||||
return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.num_cycles) * 2.0 * progress)))
|
||||
|
||||
return LambdaLR(optimizer, lr_lambda, -1)
|
||||
|
||||
|
||||
@@ -241,9 +241,7 @@ class ACTTemporalEnsembler:
|
||||
# Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor
|
||||
# operations later.
|
||||
self.ensembled_actions_count = torch.ones(
|
||||
(self.chunk_size, 1),
|
||||
dtype=torch.long,
|
||||
device=self.ensembled_actions.device,
|
||||
(self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device
|
||||
)
|
||||
else:
|
||||
# self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute
|
||||
@@ -255,10 +253,7 @@ class ACTTemporalEnsembler:
|
||||
# The last action, which has no prior online average, needs to get concatenated onto the end.
|
||||
self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1)
|
||||
self.ensembled_actions_count = torch.cat(
|
||||
[
|
||||
self.ensembled_actions_count,
|
||||
torch.ones_like(self.ensembled_actions_count[-1:]),
|
||||
]
|
||||
[self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])]
|
||||
)
|
||||
# "Consume" the first action.
|
||||
action, self.ensembled_actions, self.ensembled_actions_count = (
|
||||
@@ -338,11 +333,7 @@ class ACT(nn.Module):
|
||||
# Backbone for image feature extraction.
|
||||
if self.config.image_features:
|
||||
backbone_model = getattr(torchvision.models, config.vision_backbone)(
|
||||
replace_stride_with_dilation=[
|
||||
False,
|
||||
False,
|
||||
config.replace_final_stride_with_dilation,
|
||||
],
|
||||
replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation],
|
||||
weights=config.pretrained_backbone_weights,
|
||||
norm_layer=FrozenBatchNorm2d,
|
||||
)
|
||||
@@ -436,11 +427,7 @@ class ACT(nn.Module):
|
||||
action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D)
|
||||
|
||||
if self.config.robot_state_feature:
|
||||
vae_encoder_input = [
|
||||
cls_embed,
|
||||
robot_state_embed,
|
||||
action_embed,
|
||||
] # (B, S+2, D)
|
||||
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
|
||||
else:
|
||||
vae_encoder_input = [cls_embed, action_embed]
|
||||
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
|
||||
@@ -553,10 +540,7 @@ class ACTEncoder(nn.Module):
|
||||
self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
pos_embed: Tensor | None = None,
|
||||
key_padding_mask: Tensor | None = None,
|
||||
self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask)
|
||||
@@ -619,10 +603,7 @@ class ACTDecoder(nn.Module):
|
||||
) -> Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x,
|
||||
encoder_out,
|
||||
decoder_pos_embed=decoder_pos_embed,
|
||||
encoder_pos_embed=encoder_pos_embed,
|
||||
x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed
|
||||
)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
@@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -209,10 +209,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
# ========= inference ============
|
||||
def conditional_sample(
|
||||
self,
|
||||
batch_size: int,
|
||||
global_cond: Tensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None
|
||||
) -> Tensor:
|
||||
device = get_device_from_parameters(self)
|
||||
dtype = get_dtype_from_parameters(self)
|
||||
@@ -241,8 +238,8 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2]
|
||||
global_cond_feats = [batch[OBS_ROBOT]]
|
||||
batch_size, n_obs_steps = batch[OBS_STATE].shape[:2]
|
||||
global_cond_feats = [batch[OBS_STATE]]
|
||||
# Extract image features.
|
||||
if self.config.image_features:
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
@@ -257,10 +254,7 @@ class DiffusionModel(nn.Module):
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list,
|
||||
"(n b s) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
@@ -270,15 +264,12 @@ class DiffusionModel(nn.Module):
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features,
|
||||
"(b s n) ... -> b s (n ...)",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self.config.env_state_feature:
|
||||
global_cond_feats.append(batch[OBS_ENV])
|
||||
global_cond_feats.append(batch[OBS_ENV_STATE])
|
||||
|
||||
# Concatenate features then flatten to (B, global_cond_dim).
|
||||
return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1)
|
||||
@@ -524,9 +515,7 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
@@ -644,14 +633,10 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
self.mid_modules = nn.ModuleList(
|
||||
[
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
DiffusionConditionalResidualBlock1d(
|
||||
config.down_dims[-1],
|
||||
config.down_dims[-1],
|
||||
**common_res_block_kwargs,
|
||||
config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -24,7 +24,6 @@ from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.common.envs.utils import env_to_policy_features
|
||||
from lerobot.common.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
@@ -60,14 +59,6 @@ def get_policy_class(name: str) -> PreTrainedPolicy:
|
||||
from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
|
||||
|
||||
return PI0FASTPolicy
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy
|
||||
elif name == "hilserl_classifier":
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
return Classifier
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -85,8 +76,6 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return PI0Config(**kwargs)
|
||||
elif policy_type == "pi0fast":
|
||||
return PI0FASTConfig(**kwargs)
|
||||
elif policy_type == "hilserl_classifier":
|
||||
return ClassifierConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig
|
||||
from lerobot.common.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass(name="hilserl_classifier")
|
||||
@dataclass
|
||||
class ClassifierConfig(PreTrainedConfig):
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
name: str = "hilserl_classifier"
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "helper2424/resnet10"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
learning_rate: float = 1e-4
|
||||
normalization_mode = None
|
||||
# output_features: Dict[str, PolicyFeature] = field(
|
||||
# default_factory=lambda: {"next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,))}
|
||||
# )
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> List | None:
|
||||
return None
|
||||
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
return AdamWConfig(
|
||||
lr=self.learning_rate,
|
||||
weight_decay=0.01,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate feature configurations."""
|
||||
# Classifier doesn't need specific feature validation
|
||||
pass
|
||||
@@ -1,237 +0,0 @@
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.constants import OBS_IMAGE
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
|
||||
ClassifierConfig,
|
||||
)
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logits: Tensor,
|
||||
probabilities: Optional[Tensor] = None,
|
||||
hidden_states: Optional[Tensor] = None,
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(PreTrainedPolicy):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
name = "hilserl_classifier"
|
||||
config_class = ClassifierConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ClassifierConfig,
|
||||
dataset_stats: Dict[str, Dict[str, Tensor]] | None = None,
|
||||
):
|
||||
from transformers import AutoModel
|
||||
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Initialize normalization (standardized with the policy framework)
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
|
||||
# Set up encoder
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(
|
||||
self.config.hidden_dim,
|
||||
1 if self.config.num_classes == 2 else self.config.num_classes,
|
||||
),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(x)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(x)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def extract_images_and_labels(self, batch: Dict[str, Tensor]) -> Tuple[list, Tensor]:
|
||||
"""Extract image tensors and label tensors from batch."""
|
||||
# Find image keys in input features
|
||||
image_keys = [key for key in self.config.input_features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
# Extract the images and labels
|
||||
images = [batch[key] for key in image_keys]
|
||||
labels = batch["next.reward"]
|
||||
|
||||
return images, labels
|
||||
|
||||
def predict(self, xs: list) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier for inference."""
|
||||
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def forward(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
|
||||
"""Standard forward pass for training compatible with train.py."""
|
||||
# Normalize inputs if needed
|
||||
batch = self.normalize_inputs(batch)
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Extract images and labels
|
||||
images, labels = self.extract_images_and_labels(batch)
|
||||
|
||||
# Get predictions
|
||||
outputs = self.predict(images)
|
||||
|
||||
# Calculate loss
|
||||
if self.config.num_classes == 2:
|
||||
# Binary classification
|
||||
loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels)
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
# Multi-class classification
|
||||
loss = nn.functional.cross_entropy(outputs.logits, labels.long())
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
|
||||
# Calculate accuracy for logging
|
||||
correct = (predictions == labels).sum().item()
|
||||
total = labels.size(0)
|
||||
accuracy = 100 * correct / total
|
||||
|
||||
# Return loss and metrics for logging
|
||||
output_dict = {
|
||||
"accuracy": accuracy,
|
||||
"correct": correct,
|
||||
"total": total,
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def predict_reward(self, batch, threshold=0.6):
|
||||
"""Legacy method for compatibility."""
|
||||
images, _ = self.extract_images_and_labels(batch)
|
||||
if self.config.num_classes == 2:
|
||||
probs = self.predict(images).probabilities
|
||||
logging.debug(f"Predicted reward images: {probs}")
|
||||
return (probs > threshold).float()
|
||||
else:
|
||||
return torch.argmax(self.predict(images).probabilities, dim=1)
|
||||
|
||||
# Methods required by PreTrainedPolicy abstract class
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
"""Return optimizer parameters for the policy."""
|
||||
return {
|
||||
"params": self.parameters(),
|
||||
"lr": getattr(self.config, "learning_rate", 1e-4),
|
||||
"weight_decay": getattr(self.config, "weight_decay", 0.01),
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""Reset any stateful components (required by PreTrainedPolicy)."""
|
||||
# Classifier doesn't have stateful components that need resetting
|
||||
pass
|
||||
|
||||
def select_action(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||
"""Return action (class prediction) based on input observation."""
|
||||
images, _ = self.extract_images_and_labels(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = self.predict(images)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
# For binary classification return 0 or 1
|
||||
return (outputs.probabilities > 0.5).float()
|
||||
else:
|
||||
# For multi-class return the predicted class
|
||||
return torch.argmax(outputs.probabilities, dim=1)
|
||||
@@ -79,46 +79,28 @@ def create_stats_buffers(
|
||||
)
|
||||
|
||||
# TODO(aliberts, rcadene): harmonize this to only use one framework (np or torch)
|
||||
if stats and key in stats:
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
if "mean" not in stats[key] or "std" not in stats[key]:
|
||||
raise ValueError(
|
||||
f"Missing 'mean' or 'std' in stats for key {key} with MEAN_STD normalization"
|
||||
)
|
||||
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected for 'mean', but type is '{type_}' instead."
|
||||
)
|
||||
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
if "min" not in stats[key] or "max" not in stats[key]:
|
||||
raise ValueError(
|
||||
f"Missing 'min' or 'max' in stats for key {key} with MIN_MAX normalization"
|
||||
)
|
||||
|
||||
if isinstance(stats[key]["min"], np.ndarray):
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
elif isinstance(stats[key]["min"], torch.Tensor):
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
else:
|
||||
type_ = type(stats[key]["min"])
|
||||
raise ValueError(
|
||||
f"np.ndarray or torch.Tensor expected for 'min', but type is '{type_}' instead."
|
||||
)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
|
||||
stats_buffers[key] = buffer
|
||||
return stats_buffers
|
||||
@@ -167,13 +149,12 @@ class Normalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
# @torch.no_grad
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
if key not in batch:
|
||||
# FIXME(aliberts, rcadene): This might lead to silent fail!
|
||||
# NOTE: (azouitine) This continues help us for instantiation SACPolicy
|
||||
continue
|
||||
|
||||
norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY)
|
||||
@@ -242,7 +223,7 @@ class Unnormalize(nn.Module):
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
# TODO(rcadene): should we remove torch.no_grad?
|
||||
# @torch.no_grad
|
||||
@torch.no_grad
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch) # shallow copy avoids mutating the input batch
|
||||
for key, ft in self.features.items():
|
||||
|
||||
@@ -61,11 +61,7 @@ from lerobot.common.policies.pi0.conversion_scripts.conversion_utils import (
|
||||
)
|
||||
from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
|
||||
|
||||
PRECISIONS = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
"float16": torch.float16,
|
||||
}
|
||||
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
|
||||
@@ -48,32 +48,18 @@ def flex_attention_forward(
|
||||
|
||||
key_states = key_states[:, :, :, None, :]
|
||||
key_states = key_states.expand(
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size,
|
||||
key_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :]
|
||||
value_states = value_states.expand(
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size,
|
||||
value_states.shape[1],
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
query_states = query_states.transpose(1, 2)
|
||||
|
||||
@@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from lerobot.common.constants import ACTION, OBS_ROBOT
|
||||
from lerobot.common.constants import ACTION, OBS_STATE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.common.policies.pi0.paligemma_with_expert import (
|
||||
@@ -69,11 +69,7 @@ from lerobot.common.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor,
|
||||
dimension: int,
|
||||
min_period: float,
|
||||
max_period: float,
|
||||
device="cpu",
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if dimension % 2 != 0:
|
||||
@@ -275,7 +271,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self.eval()
|
||||
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
@@ -307,7 +303,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT])
|
||||
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
@@ -384,7 +380,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
def prepare_language(self, batch) -> tuple[Tensor, Tensor]:
|
||||
"""Tokenize the text input"""
|
||||
device = batch[OBS_ROBOT].device
|
||||
device = batch[OBS_STATE].device
|
||||
tasks = batch["task"]
|
||||
|
||||
# PaliGemma prompt has to end with a new line
|
||||
@@ -431,7 +427,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
|
||||
def prepare_state(self, batch):
|
||||
"""Pad state"""
|
||||
state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim)
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
return state
|
||||
|
||||
def prepare_action(self, batch):
|
||||
@@ -581,11 +577,7 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = create_sinusoidal_pos_embedding(
|
||||
timestep,
|
||||
self.config.proj_width,
|
||||
min_period=4e-3,
|
||||
max_period=4.0,
|
||||
device=device,
|
||||
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
|
||||
)
|
||||
time_emb = time_emb.type(dtype=dtype)
|
||||
|
||||
@@ -617,15 +609,7 @@ class PI0FlowMatching(nn.Module):
|
||||
return embs, pad_masks, att_masks
|
||||
|
||||
def forward(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
actions,
|
||||
noise=None,
|
||||
time=None,
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
@@ -671,11 +655,7 @@ class PI0FlowMatching(nn.Module):
|
||||
device = state.device
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (
|
||||
bsize,
|
||||
self.config.n_action_steps,
|
||||
self.config.max_action_dim,
|
||||
)
|
||||
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
|
||||
@@ -293,18 +293,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
# in `transformers`. (molbap)
|
||||
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
|
||||
value_states = torch.cat(
|
||||
[past_key_values[layer_idx]["value_states"], value_states],
|
||||
dim=1,
|
||||
[past_key_values[layer_idx]["value_states"], value_states], dim=1
|
||||
)
|
||||
|
||||
attention_interface = self.get_attention_interface()
|
||||
att_output = attention_interface(
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
|
||||
@@ -364,24 +358,12 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
return attention_interface
|
||||
|
||||
def flash_attention_forward(
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
raise NotImplementedError("FA2 is not implemented (yet)")
|
||||
|
||||
def eager_attention_forward(
|
||||
self,
|
||||
attention_mask,
|
||||
batch_size,
|
||||
head_dim,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
):
|
||||
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
|
||||
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
|
||||
@@ -393,31 +375,17 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
sequence_length = key_states.shape[1]
|
||||
|
||||
key_states = key_states[:, :, :, None, :].expand(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
key_states = key_states.reshape(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
value_states = value_states[:, :, :, None, :].expand(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads,
|
||||
num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
|
||||
)
|
||||
value_states = value_states.reshape(
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_key_value_heads * num_key_value_groups,
|
||||
head_dim,
|
||||
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
|
||||
)
|
||||
|
||||
# Attention here is upcasted to float32 to match the original eager implementation.
|
||||
|
||||
@@ -1,229 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.optim.optimizers import MultiAdamConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConcurrencyConfig:
|
||||
actor: str = "threads"
|
||||
learner: str = "threads"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorLearnerConfig:
|
||||
learner_host: str = "127.0.0.1"
|
||||
learner_port: int = 50051
|
||||
policy_parameters_push_frequency: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class CriticNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
final_activation: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActorNetworkConfig:
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
activate_final: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyConfig:
|
||||
use_tanh_squash: bool = True
|
||||
log_std_min: float = 1e-5
|
||||
log_std_max: float = 10.0
|
||||
init_final: float = 0.05
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACConfig(PreTrainedConfig):
|
||||
"""Soft Actor-Critic (SAC) configuration.
|
||||
|
||||
SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy
|
||||
reinforcement learning framework. It learns a policy and a Q-function simultaneously
|
||||
using experience collected from the environment.
|
||||
|
||||
This configuration class contains all the parameters needed to define a SAC agent,
|
||||
including network architectures, optimization settings, and algorithm-specific
|
||||
hyperparameters.
|
||||
|
||||
Args:
|
||||
actor_network: Configuration for the actor network architecture.
|
||||
critic_network: Configuration for the critic network architecture.
|
||||
policy: Configuration for the policy parameters.
|
||||
n_obs_steps: Number of observation steps to consider.
|
||||
normalization_mapping: Mapping of feature types to normalization modes.
|
||||
dataset_stats: Statistics for normalizing different types of inputs.
|
||||
input_features: Dictionary of input features with their types and shapes.
|
||||
output_features: Dictionary of output features with their types and shapes.
|
||||
camera_number: Number of cameras used for visual observations.
|
||||
device: Device to run the model on (e.g., "cuda", "cpu").
|
||||
storage_device: Device to store the model on.
|
||||
vision_encoder_name: Name of the vision encoder model.
|
||||
freeze_vision_encoder: Whether to freeze the vision encoder during training.
|
||||
image_encoder_hidden_dim: Hidden dimension size for the image encoder.
|
||||
shared_encoder: Whether to use a shared encoder for actor and critic.
|
||||
num_discrete_actions: Number of discrete actions, eg for gripper actions.
|
||||
image_embedding_pooling_dim: Dimension of the image embedding pooling.
|
||||
concurrency: Configuration for concurrency settings.
|
||||
actor_learner: Configuration for actor-learner architecture.
|
||||
online_steps: Number of steps for online training.
|
||||
online_env_seed: Seed for the online environment.
|
||||
online_buffer_capacity: Capacity of the online replay buffer.
|
||||
offline_buffer_capacity: Capacity of the offline replay buffer.
|
||||
async_prefetch: Whether to use asynchronous prefetching for the buffers.
|
||||
online_step_before_learning: Number of steps before learning starts.
|
||||
policy_update_freq: Frequency of policy updates.
|
||||
discount: Discount factor for the SAC algorithm.
|
||||
temperature_init: Initial temperature value.
|
||||
num_critics: Number of critics in the ensemble.
|
||||
num_subsample_critics: Number of subsampled critics for training.
|
||||
critic_lr: Learning rate for the critic network.
|
||||
actor_lr: Learning rate for the actor network.
|
||||
temperature_lr: Learning rate for the temperature parameter.
|
||||
critic_target_update_weight: Weight for the critic target update.
|
||||
utd_ratio: Update-to-data ratio for the UTD algorithm.
|
||||
state_encoder_hidden_dim: Hidden dimension size for the state encoder.
|
||||
latent_dim: Dimension of the latent space.
|
||||
target_entropy: Target entropy for the SAC algorithm.
|
||||
use_backup_entropy: Whether to use backup entropy for the SAC algorithm.
|
||||
grad_clip_norm: Gradient clipping norm for the SAC algorithm.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ENV": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
"observation.state": {
|
||||
"min": [0.0, 0.0],
|
||||
"max": [1.0, 1.0],
|
||||
},
|
||||
"action": {
|
||||
"min": [0.0, 0.0, 0.0],
|
||||
"max": [1.0, 1.0, 1.0],
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Architecture specifics
|
||||
camera_number: int = 1
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
# Set to "helper2424/resnet10" for hil serl
|
||||
vision_encoder_name: str | None = None
|
||||
freeze_vision_encoder: bool = True
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = True
|
||||
num_discrete_actions: int | None = None
|
||||
image_embedding_pooling_dim: int = 8
|
||||
|
||||
# Training parameter
|
||||
online_steps: int = 1000000
|
||||
online_env_seed: int = 10000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
async_prefetch: bool = False
|
||||
online_step_before_learning: int = 100
|
||||
policy_update_freq: int = 1
|
||||
|
||||
# SAC algorithm parameters
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
critic_target_update_weight: float = 0.005
|
||||
utd_ratio: int = 1 # If you want enable utd_ratio, you need to set it to >1
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
grad_clip_norm: float = 40.0
|
||||
|
||||
# Network configuration
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig)
|
||||
policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig)
|
||||
grasp_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# Any validation specific to SAC configuration
|
||||
|
||||
def get_optimizer_preset(self) -> MultiAdamConfig:
|
||||
return MultiAdamConfig(
|
||||
weight_decay=0.0,
|
||||
optimizer_groups={
|
||||
"actor": {"lr": self.actor_lr},
|
||||
"critic": {"lr": self.critic_lr},
|
||||
"temperature": {"lr": self.temperature_lr},
|
||||
},
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> None:
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
has_image = any(key.startswith("observation.image") for key in self.input_features)
|
||||
has_state = "observation.state" in self.input_features
|
||||
|
||||
if not (has_state or has_image):
|
||||
raise ValueError(
|
||||
"You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features"
|
||||
)
|
||||
|
||||
if "action" not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def image_features(self) -> list[str]:
|
||||
return [key for key in self.input_features if "image" in key]
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return None # SAC typically predicts one action at a time
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
File diff suppressed because it is too large
Load Diff
@@ -35,15 +35,11 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.constants import OBS_ENV, OBS_ROBOT
|
||||
from lerobot.common.constants import OBS_ENV_STATE, OBS_STATE
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
|
||||
|
||||
class TDMPCPolicy(PreTrainedPolicy):
|
||||
@@ -67,11 +63,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
config_class = TDMPCConfig
|
||||
name = "tdmpc"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: TDMPCConfig,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
):
|
||||
def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
@@ -197,20 +189,13 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(
|
||||
z,
|
||||
"b d -> n b d",
|
||||
n=self.config.n_gaussian_samples + self.config.n_pi_samples,
|
||||
)
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon,
|
||||
batch_size,
|
||||
self.config.action_feature.shape[0],
|
||||
device=device,
|
||||
self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
@@ -306,10 +291,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
if self.config.q_ensemble_size > 2:
|
||||
G += (
|
||||
running_discount
|
||||
* torch.min(
|
||||
terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))],
|
||||
dim=0,
|
||||
)[0]
|
||||
* torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[
|
||||
0
|
||||
]
|
||||
)
|
||||
else:
|
||||
G += running_discount * torch.min(terminal_values, dim=0)[0]
|
||||
@@ -345,10 +329,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
# Apply random image augmentations.
|
||||
if self.config.image_features and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(
|
||||
random_shifts_aug,
|
||||
max_random_shift_ratio=self.config.max_random_shift_ratio,
|
||||
),
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
@@ -572,10 +553,7 @@ class TDMPCTOLD(nn.Module):
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.Linear(
|
||||
config.latent_dim + config.action_feature.shape[0],
|
||||
config.mlp_dim,
|
||||
),
|
||||
nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim),
|
||||
nn.LayerNorm(config.mlp_dim),
|
||||
nn.Tanh(),
|
||||
nn.Linear(config.mlp_dim, config.mlp_dim),
|
||||
@@ -724,26 +702,11 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
5,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
config.image_encoder_hidden_dim,
|
||||
config.image_encoder_hidden_dim,
|
||||
3,
|
||||
stride=2,
|
||||
),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_shape = (1, *next(iter(config.image_features.values())).shape)
|
||||
@@ -786,14 +749,13 @@ class TDMPCObservationEncoder(nn.Module):
|
||||
if self.config.image_features:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(
|
||||
self.image_enc_layers,
|
||||
obs_dict[next(iter(self.config.image_features))],
|
||||
self.image_enc_layers, obs_dict[next(iter(self.config.image_features))]
|
||||
)
|
||||
)
|
||||
if self.config.env_state_feature:
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV]))
|
||||
feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE]))
|
||||
if self.config.robot_state_feature:
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT]))
|
||||
feat.append(self.state_enc_layers(obs_dict[OBS_STATE]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
@@ -834,9 +796,7 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float):
|
||||
"""Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param."""
|
||||
for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True):
|
||||
for (n_p_ema, p_ema), (n_p, p) in zip(
|
||||
ema_module.named_parameters(recurse=False),
|
||||
module.named_parameters(recurse=False),
|
||||
strict=True,
|
||||
ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True
|
||||
):
|
||||
assert n_p_ema == n_p, "Parameter names don't match for EMA model update"
|
||||
if isinstance(p, dict):
|
||||
|
||||
@@ -193,12 +193,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list:
|
||||
return list(
|
||||
range(
|
||||
1 - self.n_obs_steps,
|
||||
self.n_action_pred_token + self.action_chunk_size - 1,
|
||||
)
|
||||
)
|
||||
return list(range(1 - self.n_obs_steps, self.n_action_pred_token + self.action_chunk_size - 1))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -29,11 +29,7 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.common.policies.utils import (
|
||||
get_device_from_parameters,
|
||||
get_output_shape,
|
||||
populate_queues,
|
||||
)
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues
|
||||
from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.common.policies.vqbet.vqbet_utils import GPT, ResidualVQ
|
||||
|
||||
@@ -328,8 +324,7 @@ class VQBeTModel(nn.Module):
|
||||
|
||||
# To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT.
|
||||
self.state_projector = MLP(
|
||||
config.robot_state_feature.shape[0],
|
||||
hidden_channels=[self.config.gpt_input_dim],
|
||||
config.robot_state_feature.shape[0], hidden_channels=[self.config.gpt_input_dim]
|
||||
)
|
||||
self.rgb_feature_projector = MLP(
|
||||
self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim]
|
||||
@@ -359,11 +354,7 @@ class VQBeTModel(nn.Module):
|
||||
)
|
||||
# Separate batch and sequence dims.
|
||||
img_features = einops.rearrange(
|
||||
img_features,
|
||||
"(b s n) ... -> b s n ...",
|
||||
b=batch_size,
|
||||
s=n_obs_steps,
|
||||
n=self.num_images,
|
||||
img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images
|
||||
)
|
||||
|
||||
# Arrange prior and current observation step tokens as shown in the class docstring.
|
||||
@@ -400,11 +391,7 @@ class VQBeTModel(nn.Module):
|
||||
# Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional).
|
||||
if len_additional_action_token > 0:
|
||||
features = torch.cat(
|
||||
[
|
||||
features[:, historical_act_pred_index],
|
||||
features[:, -len_additional_action_token:],
|
||||
],
|
||||
dim=1,
|
||||
[features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1
|
||||
)
|
||||
else:
|
||||
features = features[:, historical_act_pred_index]
|
||||
@@ -527,13 +514,7 @@ class VQBeTHead(nn.Module):
|
||||
|
||||
cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin(
|
||||
torch.cat(
|
||||
(
|
||||
x,
|
||||
F.one_hot(
|
||||
sampled_primary_centers,
|
||||
num_classes=self.config.vqvae_n_embed,
|
||||
),
|
||||
),
|
||||
(x, F.one_hot(sampled_primary_centers, num_classes=self.config.vqvae_n_embed)),
|
||||
axis=1,
|
||||
)
|
||||
)
|
||||
@@ -551,9 +532,7 @@ class VQBeTHead(nn.Module):
|
||||
else:
|
||||
cbet_logits = self.map_to_cbet_preds_bin(x)
|
||||
cbet_logits = einops.rearrange(
|
||||
cbet_logits,
|
||||
"(NT) (G C) -> (NT) G C",
|
||||
G=self.vqvae_model.vqvae_num_layers,
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
@@ -751,9 +730,7 @@ class VQBeTRgbEncoder(nn.Module):
|
||||
|
||||
|
||||
def _replace_submodules(
|
||||
root_module: nn.Module,
|
||||
predicate: Callable[[nn.Module], bool],
|
||||
func: Callable[[nn.Module], nn.Module],
|
||||
root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module]
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -377,10 +377,7 @@ class ResidualVQ(nn.Module):
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
VectorQuantize(
|
||||
dim=codebook_dim,
|
||||
codebook_dim=codebook_dim,
|
||||
accept_image_fmap=accept_image_fmap,
|
||||
**kwargs,
|
||||
dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs
|
||||
)
|
||||
for _ in range(num_quantizers)
|
||||
]
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import (
|
||||
CameraConfig,
|
||||
IntelRealSenseCameraConfig,
|
||||
OpenCVCameraConfig,
|
||||
)
|
||||
|
||||
|
||||
# Defines a camera type
|
||||
class Camera(Protocol):
|
||||
def connect(self): ...
|
||||
def read(self, temporary_color: str | None = None) -> np.ndarray: ...
|
||||
def async_read(self) -> np.ndarray: ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]:
|
||||
cameras = {}
|
||||
|
||||
for key, cfg in camera_configs.items():
|
||||
if cfg.type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
cameras[key] = OpenCVCamera(cfg)
|
||||
|
||||
elif cfg.type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
IntelRealSenseCamera,
|
||||
)
|
||||
|
||||
cameras[key] = IntelRealSenseCamera(cfg)
|
||||
else:
|
||||
raise ValueError(f"The camera type '{cfg.type}' is not valid.")
|
||||
|
||||
return cameras
|
||||
|
||||
|
||||
def make_camera(camera_type, **kwargs) -> Camera:
|
||||
if camera_type == "opencv":
|
||||
from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
|
||||
|
||||
config = OpenCVCameraConfig(**kwargs)
|
||||
return OpenCVCamera(config)
|
||||
|
||||
elif camera_type == "intelrealsense":
|
||||
from lerobot.common.robot_devices.cameras.intelrealsense import (
|
||||
IntelRealSenseCamera,
|
||||
)
|
||||
|
||||
config = IntelRealSenseCameraConfig(**kwargs)
|
||||
return IntelRealSenseCamera(config)
|
||||
|
||||
else:
|
||||
raise ValueError(f"The camera type '{camera_type}' is not valid.")
|
||||
@@ -1,881 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 2.0
|
||||
BAUDRATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# The following bounds define the lower and upper joints range (after calibration).
|
||||
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
|
||||
# which corresponds to a half rotation on the left and half rotation on the right.
|
||||
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
|
||||
# an error is raised.
|
||||
LOWER_BOUND_DEGREE = -270
|
||||
UPPER_BOUND_DEGREE = 270
|
||||
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
|
||||
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
|
||||
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
|
||||
# [-10, 110] until an error is raised.
|
||||
LOWER_BOUND_LINEAR = -10
|
||||
UPPER_BOUND_LINEAR = 110
|
||||
|
||||
HALF_TURN_DEGREE = 180
|
||||
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270
|
||||
# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150
|
||||
|
||||
# data_name: (address, size_byte)
|
||||
X_SERIES_CONTROL_TABLE = {
|
||||
"Model_Number": (0, 2),
|
||||
"Model_Information": (2, 4),
|
||||
"Firmware_Version": (6, 1),
|
||||
"ID": (7, 1),
|
||||
"Baud_Rate": (8, 1),
|
||||
"Return_Delay_Time": (9, 1),
|
||||
"Drive_Mode": (10, 1),
|
||||
"Operating_Mode": (11, 1),
|
||||
"Secondary_ID": (12, 1),
|
||||
"Protocol_Type": (13, 1),
|
||||
"Homing_Offset": (20, 4),
|
||||
"Moving_Threshold": (24, 4),
|
||||
"Temperature_Limit": (31, 1),
|
||||
"Max_Voltage_Limit": (32, 2),
|
||||
"Min_Voltage_Limit": (34, 2),
|
||||
"PWM_Limit": (36, 2),
|
||||
"Current_Limit": (38, 2),
|
||||
"Acceleration_Limit": (40, 4),
|
||||
"Velocity_Limit": (44, 4),
|
||||
"Max_Position_Limit": (48, 4),
|
||||
"Min_Position_Limit": (52, 4),
|
||||
"Shutdown": (63, 1),
|
||||
"Torque_Enable": (64, 1),
|
||||
"LED": (65, 1),
|
||||
"Status_Return_Level": (68, 1),
|
||||
"Registered_Instruction": (69, 1),
|
||||
"Hardware_Error_Status": (70, 1),
|
||||
"Velocity_I_Gain": (76, 2),
|
||||
"Velocity_P_Gain": (78, 2),
|
||||
"Position_D_Gain": (80, 2),
|
||||
"Position_I_Gain": (82, 2),
|
||||
"Position_P_Gain": (84, 2),
|
||||
"Feedforward_2nd_Gain": (88, 2),
|
||||
"Feedforward_1st_Gain": (90, 2),
|
||||
"Bus_Watchdog": (98, 1),
|
||||
"Goal_PWM": (100, 2),
|
||||
"Goal_Current": (102, 2),
|
||||
"Goal_Velocity": (104, 4),
|
||||
"Profile_Acceleration": (108, 4),
|
||||
"Profile_Velocity": (112, 4),
|
||||
"Goal_Position": (116, 4),
|
||||
"Realtime_Tick": (120, 2),
|
||||
"Moving": (122, 1),
|
||||
"Moving_Status": (123, 1),
|
||||
"Present_PWM": (124, 2),
|
||||
"Present_Current": (126, 2),
|
||||
"Present_Velocity": (128, 4),
|
||||
"Present_Position": (132, 4),
|
||||
"Velocity_Trajectory": (136, 4),
|
||||
"Position_Trajectory": (140, 4),
|
||||
"Present_Input_Voltage": (144, 2),
|
||||
"Present_Temperature": (146, 1),
|
||||
}
|
||||
|
||||
X_SERIES_BAUDRATE_TABLE = {
|
||||
0: 9_600,
|
||||
1: 57_600,
|
||||
2: 115_200,
|
||||
3: 1_000_000,
|
||||
4: 2_000_000,
|
||||
5: 3_000_000,
|
||||
6: 4_000_000,
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"x_series": X_SERIES_CONTROL_TABLE,
|
||||
"xl330-m077": X_SERIES_CONTROL_TABLE,
|
||||
"xl330-m288": X_SERIES_CONTROL_TABLE,
|
||||
"xl430-w250": X_SERIES_CONTROL_TABLE,
|
||||
"xm430-w350": X_SERIES_CONTROL_TABLE,
|
||||
"xm540-w270": X_SERIES_CONTROL_TABLE,
|
||||
"xc430-w150": X_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"x_series": 4096,
|
||||
"xl330-m077": 4096,
|
||||
"xl330-m288": 4096,
|
||||
"xl430-w250": 4096,
|
||||
"xm430-w350": 4096,
|
||||
"xm540-w270": 4096,
|
||||
"xc430-w150": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"x_series": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m077": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl330-m288": X_SERIES_BAUDRATE_TABLE,
|
||||
"xl430-w250": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm430-w350": X_SERIES_BAUDRATE_TABLE,
|
||||
"xm540-w270": X_SERIES_BAUDRATE_TABLE,
|
||||
"xc430-w150": X_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
NUM_READ_RETRY = 10
|
||||
NUM_WRITE_RETRY = 10
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
"""
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes, mock=False):
|
||||
if mock:
|
||||
return value
|
||||
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
return group_key
|
||||
|
||||
|
||||
def get_result_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
rslt_name = f"{fn_name}_{group_key}"
|
||||
return rslt_name
|
||||
|
||||
|
||||
def get_queue_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
queue_name = f"{fn_name}_{group_key}"
|
||||
return queue_name
|
||||
|
||||
|
||||
def get_log_name(var_name, fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
log_name = f"{var_name}_{fn_name}_{group_key}"
|
||||
return log_name
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
class DriveMode(enum.Enum):
|
||||
NON_INVERTED = 0
|
||||
INVERTED = 1
|
||||
|
||||
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
class JointOutOfRangeError(Exception):
|
||||
def __init__(self, message="Joint is out of range"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DynamixelMotorsBus:
|
||||
"""
|
||||
The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
|
||||
A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
>>> Finding all available ports for the MotorBus.
|
||||
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
>>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
>>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example of usage for 1 motor connected to the bus:
|
||||
```python
|
||||
motor_name = "gripper"
|
||||
motor_index = 6
|
||||
motor_model = "xl330-m288"
|
||||
|
||||
config = DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = DynamixelMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# when done, consider disconnecting
|
||||
motors_bus.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DynamixelMotorsBusConfig,
|
||||
):
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
self.is_connected = False
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
self.port_handler = dxl.PortHandler(self.port)
|
||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
|
||||
)
|
||||
raise
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
self.port_handler = dxl.PortHandler(self.port)
|
||||
self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
||||
# a ConnectionError will be raised anyway.
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def find_motor_indices(self, possible_ids=None, num_retry=2):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
"""
|
||||
try:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
except JointOutOfRangeError as e:
|
||||
print(e)
|
||||
self.autocorrect_calibration(values, motor_names)
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
||||
values = values.astype(np.float32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an
|
||||
# opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31] to
|
||||
# nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048])
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range [-resolution//2, resolution//2] to
|
||||
# universal float32 centered degree range [-180, 180]
|
||||
# (e.g. 2048 / (4096 // 2) * 180 = 180)
|
||||
values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
|
||||
|
||||
if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
|
||||
raise JointOutOfRangeError(
|
||||
f"Wrong motor position range detected for {name}. "
|
||||
f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
|
||||
f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
|
||||
f"but present value is {values[i]} degree. "
|
||||
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Rescale the present position to a nominal range [0, 100] %,
|
||||
# useful for joints with linear motions like Aloha gripper
|
||||
values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
|
||||
if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
|
||||
raise JointOutOfRangeError(
|
||||
f"Wrong motor position range detected for {name}. "
|
||||
f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
|
||||
f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
|
||||
f"but present value is {values[i]} %. "
|
||||
"This might be due to a cable connection issue creating an artificial jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
|
||||
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
|
||||
|
||||
Known issues:
|
||||
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
|
||||
#2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha).
|
||||
#3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration
|
||||
or by human error during manual calibration.
|
||||
|
||||
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
|
||||
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
|
||||
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
|
||||
|
||||
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
||||
values = values.astype(np.float32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an
|
||||
# opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution
|
||||
upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
|
||||
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
if low_factor < upp_factor:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
# Convert from nominal 0-centered degree range [-180, 180] to
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
# Remove drive mode, which is the rotation direction of the motor, to come back to
|
||||
# actual motor rotation direction which can be arbitrary.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Convert from nominal lnear range of [0, 100] % to
|
||||
# actual motor range of values which can be arbitrary.
|
||||
values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
return_list = True
|
||||
if not isinstance(motor_ids, list):
|
||||
return_list = False
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = group.txRxPacket()
|
||||
if comm == dxl.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != dxl.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
return values
|
||||
else:
|
||||
return values[0]
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
for name in motor_names:
|
||||
motor_idx, model = self.motors[name]
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = dxl.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
|
||||
for _ in range(NUM_READ_RETRY):
|
||||
comm = self.group_readers[group_key].txRxPacket()
|
||||
if comm == dxl.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != dxl.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
# Convert to signed int to use range [-2048, 2048] for our motor positions.
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
if not isinstance(motor_ids, list):
|
||||
motor_ids = [motor_ids]
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
group.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = group.txPacket()
|
||||
if comm == dxl.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != dxl.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
|
||||
if isinstance(values, (int, float, np.integer)):
|
||||
values = [int(values)] * len(motor_names)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
for name in motor_names:
|
||||
motor_idx, model = self.motors[name]
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.revert_calibration(values, motor_names)
|
||||
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
self.group_writers[group_key].changeParam(idx, data)
|
||||
|
||||
comm = self.group_writers[group_key].txPacket()
|
||||
if comm != dxl.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
# log the utc time when the write has been completed
|
||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
|
||||
)
|
||||
|
||||
if self.port_handler is not None:
|
||||
self.port_handler.closePort()
|
||||
self.port_handler = None
|
||||
|
||||
self.packet_handler = None
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
@@ -1,906 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
|
||||
PROTOCOL_VERSION = 0
|
||||
BAUDRATE = 1_000_000
|
||||
TIMEOUT_MS = 1000
|
||||
|
||||
MAX_ID_RANGE = 252
|
||||
|
||||
# The following bounds define the lower and upper joints range (after calibration).
|
||||
# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees
|
||||
# which corresponds to a half rotation on the left and half rotation on the right.
|
||||
# Some joints might require higher range, so we allow up to [-270, 270] degrees until
|
||||
# an error is raised.
|
||||
LOWER_BOUND_DEGREE = -270
|
||||
UPPER_BOUND_DEGREE = 270
|
||||
# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper),
|
||||
# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully
|
||||
# closed, and 100% is fully open. To account for slight calibration issue, we allow up to
|
||||
# [-10, 110] until an error is raised.
|
||||
LOWER_BOUND_LINEAR = -10
|
||||
UPPER_BOUND_LINEAR = 110
|
||||
|
||||
HALF_TURN_DEGREE = 180
|
||||
|
||||
|
||||
# See this link for STS3215 Memory Table:
|
||||
# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true
|
||||
# data_name: (address, size_byte)
|
||||
SCS_SERIES_CONTROL_TABLE = {
|
||||
"Model": (3, 2),
|
||||
"ID": (5, 1),
|
||||
"Baud_Rate": (6, 1),
|
||||
"Return_Delay": (7, 1),
|
||||
"Response_Status_Level": (8, 1),
|
||||
"Min_Angle_Limit": (9, 2),
|
||||
"Max_Angle_Limit": (11, 2),
|
||||
"Max_Temperature_Limit": (13, 1),
|
||||
"Max_Voltage_Limit": (14, 1),
|
||||
"Min_Voltage_Limit": (15, 1),
|
||||
"Max_Torque_Limit": (16, 2),
|
||||
"Phase": (18, 1),
|
||||
"Unloading_Condition": (19, 1),
|
||||
"LED_Alarm_Condition": (20, 1),
|
||||
"P_Coefficient": (21, 1),
|
||||
"D_Coefficient": (22, 1),
|
||||
"I_Coefficient": (23, 1),
|
||||
"Minimum_Startup_Force": (24, 2),
|
||||
"CW_Dead_Zone": (26, 1),
|
||||
"CCW_Dead_Zone": (27, 1),
|
||||
"Protection_Current": (28, 2),
|
||||
"Angular_Resolution": (30, 1),
|
||||
"Offset": (31, 2),
|
||||
"Mode": (33, 1),
|
||||
"Protective_Torque": (34, 1),
|
||||
"Protection_Time": (35, 1),
|
||||
"Overload_Torque": (36, 1),
|
||||
"Speed_closed_loop_P_proportional_coefficient": (37, 1),
|
||||
"Over_Current_Protection_Time": (38, 1),
|
||||
"Velocity_closed_loop_I_integral_coefficient": (39, 1),
|
||||
"Torque_Enable": (40, 1),
|
||||
"Acceleration": (41, 1),
|
||||
"Goal_Position": (42, 2),
|
||||
"Goal_Time": (44, 2),
|
||||
"Goal_Speed": (46, 2),
|
||||
"Torque_Limit": (48, 2),
|
||||
"Lock": (55, 1),
|
||||
"Present_Position": (56, 2),
|
||||
"Present_Speed": (58, 2),
|
||||
"Present_Load": (60, 2),
|
||||
"Present_Voltage": (62, 1),
|
||||
"Present_Temperature": (63, 1),
|
||||
"Status": (65, 1),
|
||||
"Moving": (66, 1),
|
||||
"Present_Current": (69, 2),
|
||||
# Not in the Memory Table
|
||||
"Maximum_Acceleration": (85, 2),
|
||||
}
|
||||
|
||||
SCS_SERIES_BAUDRATE_TABLE = {
|
||||
0: 1_000_000,
|
||||
1: 500_000,
|
||||
2: 250_000,
|
||||
3: 128_000,
|
||||
4: 115_200,
|
||||
5: 57_600,
|
||||
6: 38_400,
|
||||
7: 19_200,
|
||||
}
|
||||
|
||||
CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"]
|
||||
|
||||
|
||||
MODEL_CONTROL_TABLE = {
|
||||
"scs_series": SCS_SERIES_CONTROL_TABLE,
|
||||
"sts3215": SCS_SERIES_CONTROL_TABLE,
|
||||
}
|
||||
|
||||
MODEL_RESOLUTION = {
|
||||
"scs_series": 4096,
|
||||
"sts3215": 4096,
|
||||
}
|
||||
|
||||
MODEL_BAUDRATE_TABLE = {
|
||||
"scs_series": SCS_SERIES_BAUDRATE_TABLE,
|
||||
"sts3215": SCS_SERIES_BAUDRATE_TABLE,
|
||||
}
|
||||
|
||||
# High number of retries is needed for feetech compared to dynamixel motors.
|
||||
NUM_READ_RETRY = 20
|
||||
NUM_WRITE_RETRY = 20
|
||||
|
||||
|
||||
def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
|
||||
"""This function converts the degree range to the step range for indicating motors rotation.
|
||||
It assumes a motor achieves a full rotation by going from -180 degree position to +180.
|
||||
The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
|
||||
"""
|
||||
resolutions = [MODEL_RESOLUTION[model] for model in models]
|
||||
steps = degrees / 180 * np.array(resolutions) / 2
|
||||
steps = steps.astype(int)
|
||||
return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes, mock=False):
|
||||
if mock:
|
||||
return value
|
||||
|
||||
import scservo_sdk as scs
|
||||
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_LOBYTE(scs.SCS_HIWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_HIWORD(value)),
|
||||
]
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def get_group_sync_key(data_name, motor_names):
|
||||
group_key = f"{data_name}_" + "_".join(motor_names)
|
||||
return group_key
|
||||
|
||||
|
||||
def get_result_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
rslt_name = f"{fn_name}_{group_key}"
|
||||
return rslt_name
|
||||
|
||||
|
||||
def get_queue_name(fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
queue_name = f"{fn_name}_{group_key}"
|
||||
return queue_name
|
||||
|
||||
|
||||
def get_log_name(var_name, fn_name, data_name, motor_names):
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
log_name = f"{var_name}_{fn_name}_{group_key}"
|
||||
return log_name
|
||||
|
||||
|
||||
def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
if len(set(all_bytes)) != 1:
|
||||
raise NotImplementedError(
|
||||
f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer."
|
||||
)
|
||||
|
||||
|
||||
class TorqueMode(enum.Enum):
|
||||
ENABLED = 1
|
||||
DISABLED = 0
|
||||
|
||||
|
||||
class DriveMode(enum.Enum):
|
||||
NON_INVERTED = 0
|
||||
INVERTED = 1
|
||||
|
||||
|
||||
class CalibrationMode(enum.Enum):
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
DEGREE = 0
|
||||
# Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100]
|
||||
LINEAR = 1
|
||||
|
||||
|
||||
class JointOutOfRangeError(Exception):
|
||||
def __init__(self, message="Joint is out of range"):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class FeetechMotorsBus:
|
||||
"""
|
||||
The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on
|
||||
the python feetech sdk to communicate with the motors. For more info, see the [feetech SDK Documentation](https://emanual.robotis.com/docs/en/software/feetech/feetech_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20).
|
||||
|
||||
A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)).
|
||||
To find the port, you can run our utility script:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
>>> Finding all available ports for the MotorsBus.
|
||||
>>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
>>> Remove the usb cable from your FeetechMotorsBus and press Enter when done.
|
||||
>>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751.
|
||||
>>> Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example of usage for 1 motor connected to the bus:
|
||||
```python
|
||||
motor_name = "gripper"
|
||||
motor_index = 6
|
||||
motor_model = "sts3215"
|
||||
|
||||
config = FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={motor_name: (motor_index, motor_model)},
|
||||
)
|
||||
motors_bus = FeetechMotorsBus(config)
|
||||
motors_bus.connect()
|
||||
|
||||
position = motors_bus.read("Present_Position")
|
||||
|
||||
# move from a few motor steps as an example
|
||||
few_steps = 30
|
||||
motors_bus.write("Goal_Position", position + few_steps)
|
||||
|
||||
# when done, consider disconnecting
|
||||
motors_bus.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FeetechMotorsBusConfig,
|
||||
):
|
||||
self.port = config.port
|
||||
self.motors = config.motors
|
||||
self.mock = config.mock
|
||||
|
||||
self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE)
|
||||
self.model_resolution = deepcopy(MODEL_RESOLUTION)
|
||||
|
||||
self.port_handler = None
|
||||
self.packet_handler = None
|
||||
self.calibration = None
|
||||
self.is_connected = False
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.logs = {}
|
||||
|
||||
self.track_positions = {}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice."
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
try:
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print(
|
||||
"\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n"
|
||||
)
|
||||
raise
|
||||
|
||||
# Allow to read and write
|
||||
self.is_connected = True
|
||||
|
||||
self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS)
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
self.port_handler = scs.PortHandler(self.port)
|
||||
self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION)
|
||||
|
||||
if not self.port_handler.openPort():
|
||||
raise OSError(f"Failed to open port '{self.port}'.")
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def are_motors_configured(self):
|
||||
# Only check the motor indices and not baudrate, since if the motor baudrates are incorrect,
|
||||
# a ConnectionError will be raised anyway.
|
||||
try:
|
||||
return (self.motor_indices == self.read("ID")).all()
|
||||
except ConnectionError as e:
|
||||
print(e)
|
||||
return False
|
||||
|
||||
def find_motor_indices(self, possible_ids=None, num_retry=2):
|
||||
if possible_ids is None:
|
||||
possible_ids = range(MAX_ID_RANGE)
|
||||
|
||||
indices = []
|
||||
for idx in tqdm.tqdm(possible_ids):
|
||||
try:
|
||||
present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
|
||||
except ConnectionError:
|
||||
continue
|
||||
|
||||
if idx != present_idx:
|
||||
# sanity check
|
||||
raise OSError(
|
||||
"Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged."
|
||||
)
|
||||
indices.append(idx)
|
||||
|
||||
return indices
|
||||
|
||||
def set_bus_baudrate(self, baudrate):
|
||||
present_bus_baudrate = self.port_handler.getBaudRate()
|
||||
if present_bus_baudrate != baudrate:
|
||||
print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
|
||||
self.port_handler.setBaudRate(baudrate)
|
||||
|
||||
if self.port_handler.getBaudRate() != baudrate:
|
||||
raise OSError("Failed to write bus baud rate.")
|
||||
|
||||
@property
|
||||
def motor_names(self) -> list[str]:
|
||||
return list(self.motors.keys())
|
||||
|
||||
@property
|
||||
def motor_models(self) -> list[str]:
|
||||
return [model for _, model in self.motors.values()]
|
||||
|
||||
@property
|
||||
def motor_indices(self) -> list[int]:
|
||||
return [idx for idx, _ in self.motors.values()]
|
||||
|
||||
def set_calibration(self, calibration: dict[str, list]):
|
||||
self.calibration = calibration
|
||||
|
||||
def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct.
|
||||
|
||||
For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
|
||||
"""
|
||||
try:
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
except JointOutOfRangeError as e:
|
||||
print(e)
|
||||
self.autocorrect_calibration(values, motor_names)
|
||||
values = self.apply_calibration(values, motor_names)
|
||||
return values
|
||||
|
||||
def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
|
||||
a "zero position" at 0 degree.
|
||||
|
||||
Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor
|
||||
rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range.
|
||||
|
||||
Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation
|
||||
when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and
|
||||
at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830,
|
||||
or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor.
|
||||
To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work
|
||||
in the centered nominal degree range ]-180, 180[.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
||||
values = values.astype(np.float32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
# Update direction of rotation of the motor to match between leader and follower.
|
||||
# In fact, the motor of the leader for a given joint can be assembled in an
|
||||
# opposite direction in term of rotation than the motor of the follower on the same joint.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from range [-2**31, 2**31[ to
|
||||
# nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[)
|
||||
values[i] += homing_offset
|
||||
|
||||
# Convert from range ]-resolution, resolution[ to
|
||||
# universal float32 centered degree range ]-180, 180[
|
||||
values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE
|
||||
|
||||
if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE):
|
||||
raise JointOutOfRangeError(
|
||||
f"Wrong motor position range detected for {name}. "
|
||||
f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), "
|
||||
f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, "
|
||||
f"but present value is {values[i]} degree. "
|
||||
"This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Rescale the present position to a nominal range [0, 100] %,
|
||||
# useful for joints with linear motions like Aloha gripper
|
||||
values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
|
||||
if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR):
|
||||
raise JointOutOfRangeError(
|
||||
f"Wrong motor position range detected for {name}. "
|
||||
f"Expected to be in nominal range of [0, 100] % (a full linear translation), "
|
||||
f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, "
|
||||
f"but present value is {values[i]} %. "
|
||||
"This might be due to a cable connection issue creating an artificial jump in motor values. "
|
||||
"You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""This function automatically detects issues with values of motors after calibration, and correct for these issues.
|
||||
|
||||
Some motors might have values outside of expected maximum bounds after calibration.
|
||||
For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given
|
||||
a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position.
|
||||
|
||||
Known issues:
|
||||
#1: Motor value randomly shifts of a full turn, caused by hardware/connection errors.
|
||||
#2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha).
|
||||
#3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration
|
||||
or by human error during manual calibration.
|
||||
|
||||
Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn.
|
||||
Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`,
|
||||
that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue.
|
||||
|
||||
Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
"""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
# Convert from unsigned int32 original range [0, 2**32] to signed float32 range
|
||||
values = values.astype(np.float32)
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
# Convert from initial range to range [-180, 180] degrees
|
||||
calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [-180, 180] degrees
|
||||
# values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE
|
||||
# - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE
|
||||
# (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution
|
||||
low_factor = (
|
||||
-HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
upp_factor = (
|
||||
HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset
|
||||
) / resolution
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Convert from initial range to range [0, 100] in %
|
||||
calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100
|
||||
in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR)
|
||||
|
||||
# Solve this inequality to find the factor to shift the range into [0, 100] %
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100
|
||||
# values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100
|
||||
# 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
if low_factor < upp_factor:
|
||||
factor = math.ceil(low_factor)
|
||||
|
||||
if factor > upp_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
else:
|
||||
factor = math.ceil(upp_factor)
|
||||
|
||||
if factor > low_factor:
|
||||
raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]")
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees"
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
self.calibration["homing_offset"][calib_idx] += resolution * factor
|
||||
|
||||
def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
|
||||
"""Inverse of `apply_calibration`."""
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
calib_idx = self.calibration["motor_names"].index(name)
|
||||
calib_mode = self.calibration["calib_mode"][calib_idx]
|
||||
|
||||
if CalibrationMode[calib_mode] == CalibrationMode.DEGREE:
|
||||
drive_mode = self.calibration["drive_mode"][calib_idx]
|
||||
homing_offset = self.calibration["homing_offset"][calib_idx]
|
||||
_, model = self.motors[name]
|
||||
resolution = self.model_resolution[model]
|
||||
|
||||
# Convert from nominal 0-centered degree range [-180, 180] to
|
||||
# 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096)
|
||||
values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2)
|
||||
|
||||
# Subtract the homing offsets to come back to actual motor range of values
|
||||
# which can be arbitrary.
|
||||
values[i] -= homing_offset
|
||||
|
||||
# Remove drive mode, which is the rotation direction of the motor, to come back to
|
||||
# actual motor rotation direction which can be arbitrary.
|
||||
if drive_mode:
|
||||
values[i] *= -1
|
||||
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
start_pos = self.calibration["start_pos"][calib_idx]
|
||||
end_pos = self.calibration["end_pos"][calib_idx]
|
||||
|
||||
# Convert from nominal lnear range of [0, 100] % to
|
||||
# actual motor range of values which can be arbitrary.
|
||||
values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos
|
||||
|
||||
values = np.round(values).astype(np.int32)
|
||||
return values
|
||||
|
||||
def avoid_rotation_reset(self, values, motor_names, data_name):
|
||||
if data_name not in self.track_positions:
|
||||
self.track_positions[data_name] = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
# Assume False at initialization
|
||||
"below_zero": [False] * len(self.motor_names),
|
||||
"above_max": [False] * len(self.motor_names),
|
||||
}
|
||||
|
||||
track = self.track_positions[data_name]
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
for i, name in enumerate(motor_names):
|
||||
idx = self.motor_names.index(name)
|
||||
|
||||
if track["prev"][idx] is None:
|
||||
track["prev"][idx] = values[i]
|
||||
continue
|
||||
|
||||
# Detect a full rotation occurred
|
||||
if abs(track["prev"][idx] - values[i]) > 2048:
|
||||
# Position went below 0 and got reset to 4095
|
||||
if track["prev"][idx] < values[i]:
|
||||
# So we set negative value by adding a full rotation
|
||||
values[i] -= 4096
|
||||
|
||||
# Position went above 4095 and got reset to 0
|
||||
elif track["prev"][idx] > values[i]:
|
||||
# So we add a full rotation
|
||||
values[i] += 4096
|
||||
|
||||
track["prev"][idx] = values[i]
|
||||
|
||||
return values
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
return_list = True
|
||||
if not isinstance(motor_ids, list):
|
||||
return_list = False
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = group.txRxPacket()
|
||||
if comm == scs.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != scs.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
return values
|
||||
else:
|
||||
return values[0]
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
for name in motor_names:
|
||||
motor_idx, model = self.motors[name]
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
# Very Important to flush the buffer!
|
||||
self.port_handler.ser.reset_output_buffer()
|
||||
self.port_handler.ser.reset_input_buffer()
|
||||
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
|
||||
for _ in range(NUM_READ_RETRY):
|
||||
comm = self.group_readers[group_key].txRxPacket()
|
||||
if comm == scs.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != scs.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Read failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
values.append(value)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
# Convert to signed int to use range [-2048, 2048] for our motor positions.
|
||||
if data_name in CONVERT_UINT32_TO_INT32_REQUIRED:
|
||||
values = values.astype(np.int32)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED:
|
||||
values = self.avoid_rotation_reset(values, motor_names, data_name)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.apply_calibration_autocorrect(values, motor_names)
|
||||
|
||||
# log the number of seconds it took to read the data from the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# log the utc time at which the data was received
|
||||
ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
return values
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
if not isinstance(motor_ids, list):
|
||||
motor_ids = [motor_ids]
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
group.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
comm = group.txPacket()
|
||||
if comm == scs.COMM_SUCCESS:
|
||||
break
|
||||
|
||||
if comm != scs.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
def write(
|
||||
self,
|
||||
data_name,
|
||||
values: int | float | np.ndarray,
|
||||
motor_names: str | list[str] | None = None,
|
||||
):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`."
|
||||
)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
if motor_names is None:
|
||||
motor_names = self.motor_names
|
||||
|
||||
if isinstance(motor_names, str):
|
||||
motor_names = [motor_names]
|
||||
|
||||
if isinstance(values, (int, float, np.integer)):
|
||||
values = [int(values)] * len(motor_names)
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
motor_ids = []
|
||||
models = []
|
||||
for name in motor_names:
|
||||
motor_idx, model = self.motors[name]
|
||||
motor_ids.append(motor_idx)
|
||||
models.append(model)
|
||||
|
||||
if data_name in CALIBRATION_REQUIRED and self.calibration is not None:
|
||||
values = self.revert_calibration(values, motor_names)
|
||||
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = scs.GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
self.group_writers[group_key].changeParam(idx, data)
|
||||
|
||||
comm = self.group_writers[group_key].txPacket()
|
||||
if comm != scs.COMM_SUCCESS:
|
||||
raise ConnectionError(
|
||||
f"Write failed due to communication error on port {self.port} for group_key {group_key}: "
|
||||
f"{self.packet_handler.getTxRxResult(comm)}"
|
||||
)
|
||||
|
||||
# log the number of seconds it took to write the data to the motors
|
||||
delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
|
||||
self.logs[delta_ts_name] = time.perf_counter() - start_time
|
||||
|
||||
# TODO(rcadene): should we log the time before sending the write command?
|
||||
# log the utc time when the write has been completed
|
||||
ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names)
|
||||
self.logs[ts_utc_name] = capture_timestamp_utc()
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first."
|
||||
)
|
||||
|
||||
if self.port_handler is not None:
|
||||
self.port_handler.closePort()
|
||||
self.port_handler = None
|
||||
|
||||
self.packet_handler = None
|
||||
self.group_readers = {}
|
||||
self.group_writers = {}
|
||||
self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
@@ -1,613 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Sequence
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.robot_devices.cameras.configs import (
|
||||
CameraConfig,
|
||||
IntelRealSenseCameraConfig,
|
||||
OpenCVCameraConfig,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.configs import (
|
||||
DynamixelMotorsBusConfig,
|
||||
FeetechMotorsBusConfig,
|
||||
MotorsBusConfig,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
|
||||
# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction
|
||||
@dataclass
|
||||
class ManipulatorRobotConfig(RobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
|
||||
|
||||
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
|
||||
# as the number of motors in your follower arms (assumes all follower arms have the same number of
|
||||
# motors).
|
||||
max_relative_target: list[float] | float | None = None
|
||||
|
||||
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
|
||||
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mock:
|
||||
for arm in self.leader_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for arm in self.follower_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for cam in self.cameras.values():
|
||||
if not cam.mock:
|
||||
cam.mock = True
|
||||
|
||||
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
|
||||
raise ValueError(
|
||||
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
|
||||
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
|
||||
f"`max_relative_target` list has as many parameters as there are motors per arm. "
|
||||
"Note: This feature does not yet work with robots where different follower arms have "
|
||||
"different numbers of motors."
|
||||
)
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaRobotConfig(ManipulatorRobotConfig):
|
||||
# Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been
|
||||
# properly assembled, no manual calibration step is expected. If you need to run manual calibration,
|
||||
# simply update this path to ".cache/calibration/aloha"
|
||||
calibration_dir: str = ".cache/calibration/aloha_default"
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: int | None = 5
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
# window_x
|
||||
port="/dev/ttyDXL_leader_left",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm430-w350"],
|
||||
"shoulder": [2, "xm430-w350"],
|
||||
"shoulder_shadow": [3, "xm430-w350"],
|
||||
"elbow": [4, "xm430-w350"],
|
||||
"elbow_shadow": [5, "xm430-w350"],
|
||||
"forearm_roll": [6, "xm430-w350"],
|
||||
"wrist_angle": [7, "xm430-w350"],
|
||||
"wrist_rotate": [8, "xl430-w250"],
|
||||
"gripper": [9, "xc430-w150"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
# window_x
|
||||
port="/dev/ttyDXL_leader_right",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm430-w350"],
|
||||
"shoulder": [2, "xm430-w350"],
|
||||
"shoulder_shadow": [3, "xm430-w350"],
|
||||
"elbow": [4, "xm430-w350"],
|
||||
"elbow_shadow": [5, "xm430-w350"],
|
||||
"forearm_roll": [6, "xm430-w350"],
|
||||
"wrist_angle": [7, "xm430-w350"],
|
||||
"wrist_rotate": [8, "xl430-w250"],
|
||||
"gripper": [9, "xc430-w150"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/ttyDXL_follower_left",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm540-w270"],
|
||||
"shoulder": [2, "xm540-w270"],
|
||||
"shoulder_shadow": [3, "xm540-w270"],
|
||||
"elbow": [4, "xm540-w270"],
|
||||
"elbow_shadow": [5, "xm540-w270"],
|
||||
"forearm_roll": [6, "xm540-w270"],
|
||||
"wrist_angle": [7, "xm540-w270"],
|
||||
"wrist_rotate": [8, "xm430-w350"],
|
||||
"gripper": [9, "xm430-w350"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/ttyDXL_follower_right",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"waist": [1, "xm540-w270"],
|
||||
"shoulder": [2, "xm540-w270"],
|
||||
"shoulder_shadow": [3, "xm540-w270"],
|
||||
"elbow": [4, "xm540-w270"],
|
||||
"elbow_shadow": [5, "xm540-w270"],
|
||||
"forearm_roll": [6, "xm540-w270"],
|
||||
"wrist_angle": [7, "xm540-w270"],
|
||||
"wrist_rotate": [8, "xm430-w350"],
|
||||
"gripper": [9, "xm430-w350"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Troubleshooting: If one of your IntelRealSense cameras freeze during
|
||||
# data recording due to bandwidth limit, you might need to plug the camera
|
||||
# on another USB hub or PCIe card.
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"cam_high": IntelRealSenseCameraConfig(
|
||||
serial_number=128422271347,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_low": IntelRealSenseCameraConfig(
|
||||
serial_number=130322270656,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_left_wrist": IntelRealSenseCameraConfig(
|
||||
serial_number=218622272670,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"cam_right_wrist": IntelRealSenseCameraConfig(
|
||||
serial_number=130322272300,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("koch")
|
||||
@dataclass
|
||||
class KochRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: float = 35.156
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("koch_bimanual")
|
||||
@dataclass
|
||||
class KochBimanualRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/koch_bimanual"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0085511",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0031751",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl330-m077"],
|
||||
"shoulder_lift": [2, "xl330-m077"],
|
||||
"elbow_flex": [3, "xl330-m077"],
|
||||
"wrist_flex": [4, "xl330-m077"],
|
||||
"wrist_roll": [5, "xl330-m077"],
|
||||
"gripper": [6, "xl330-m077"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"left": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
"right": DynamixelMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem575E0032081",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "xl430-w250"],
|
||||
"shoulder_lift": [2, "xl430-w250"],
|
||||
"elbow_flex": [3, "xl330-m288"],
|
||||
"wrist_flex": [4, "xl330-m288"],
|
||||
"wrist_roll": [5, "xl330-m288"],
|
||||
"gripper": [6, "xl330-m288"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# ~ Koch specific settings ~
|
||||
# Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible
|
||||
# to squeeze the gripper and have it spring back to an open position on its own.
|
||||
gripper_open_degree: float = 35.156
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("moss")
|
||||
@dataclass
|
||||
class MossRobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/moss"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431091",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0076891",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100")
|
||||
@dataclass
|
||||
class So100RobotConfig(ManipulatorRobotConfig):
|
||||
calibration_dir: str = ".cache/calibration/so100"
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760433331",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem58760431631",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"laptop": OpenCVCameraConfig(
|
||||
camera_index=0,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
"phone": OpenCVCameraConfig(
|
||||
camera_index=1,
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("stretch")
|
||||
@dataclass
|
||||
class StretchRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"navigation": OpenCVCameraConfig(
|
||||
camera_index="/dev/hello-nav-head-camera",
|
||||
fps=10,
|
||||
width=1280,
|
||||
height=720,
|
||||
rotation=-90,
|
||||
),
|
||||
"head": IntelRealSenseCameraConfig(
|
||||
name="Intel RealSense D435I",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
rotation=90,
|
||||
),
|
||||
"wrist": IntelRealSenseCameraConfig(
|
||||
name="Intel RealSense D405",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "192.168.0.193"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"front": OpenCVCameraConfig(
|
||||
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
|
||||
),
|
||||
"wrist": OpenCVCameraConfig(
|
||||
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
"forward": "w",
|
||||
"backward": "s",
|
||||
"left": "a",
|
||||
"right": "d",
|
||||
"rotate_left": "z",
|
||||
"rotate_right": "x",
|
||||
# Speed control
|
||||
"speed_up": "r",
|
||||
"speed_down": "f",
|
||||
# quit teleop
|
||||
"quit": "q",
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
@@ -1,509 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Logic to calibrate a robot arm built with feetech motors"""
|
||||
# TODO(rcadene, aliberts): move this logic into the robot code when refactoring
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.robot_devices.motors.feetech import (
|
||||
CalibrationMode,
|
||||
TorqueMode,
|
||||
convert_degrees_to_steps,
|
||||
)
|
||||
from lerobot.common.robot_devices.motors.utils import MotorsBus
|
||||
|
||||
URL_TEMPLATE = (
|
||||
"https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp"
|
||||
)
|
||||
|
||||
# The following positions are provided in nominal degree range ]-180, +180[
|
||||
# For more info on these constants, see comments in the code where they get used.
|
||||
ZERO_POSITION_DEGREE = 0
|
||||
ROTATED_POSITION_DEGREE = 90
|
||||
|
||||
|
||||
def assert_drive_mode(drive_mode):
|
||||
# `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted.
|
||||
if not np.all(np.isin(drive_mode, [0, 1])):
|
||||
raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})")
|
||||
|
||||
|
||||
def apply_drive_mode(position, drive_mode):
|
||||
assert_drive_mode(drive_mode)
|
||||
# Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted,
|
||||
# to [-1, 1] with 1 indicates original rotation direction and -1 inverted.
|
||||
signed_drive_mode = -(drive_mode * 2 - 1)
|
||||
position *= signed_drive_mode
|
||||
return position
|
||||
|
||||
|
||||
def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None):
|
||||
count = 0
|
||||
while True:
|
||||
present_pos = arm.read("Present_Position", motor_name)
|
||||
if positive_direction:
|
||||
# Move +100 steps every time. Lower the steps to lower the speed at which the arm moves.
|
||||
arm.write("Goal_Position", present_pos + 100, motor_name)
|
||||
else:
|
||||
arm.write("Goal_Position", present_pos - 100, motor_name)
|
||||
|
||||
if while_move_hook is not None:
|
||||
while_move_hook()
|
||||
|
||||
present_pos = arm.read("Present_Position", motor_name).item()
|
||||
present_speed = arm.read("Present_Speed", motor_name).item()
|
||||
present_current = arm.read("Present_Current", motor_name).item()
|
||||
# present_load = arm.read("Present_Load", motor_name).item()
|
||||
# present_voltage = arm.read("Present_Voltage", motor_name).item()
|
||||
# present_temperature = arm.read("Present_Temperature", motor_name).item()
|
||||
|
||||
# print(f"{present_pos=}")
|
||||
# print(f"{present_speed=}")
|
||||
# print(f"{present_current=}")
|
||||
# print(f"{present_load=}")
|
||||
# print(f"{present_voltage=}")
|
||||
# print(f"{present_temperature=}")
|
||||
|
||||
if present_speed == 0 and present_current > 40:
|
||||
count += 1
|
||||
if count > 100 or present_current > 300:
|
||||
return present_pos
|
||||
else:
|
||||
count = 0
|
||||
|
||||
|
||||
def move_to_calibrate(
|
||||
arm,
|
||||
motor_name,
|
||||
invert_drive_mode=False,
|
||||
positive_first=True,
|
||||
in_between_move_hook=None,
|
||||
while_move_hook=None,
|
||||
):
|
||||
initial_pos = arm.read("Present_Position", motor_name)
|
||||
|
||||
if positive_first:
|
||||
p_present_pos = move_until_block(
|
||||
arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
|
||||
)
|
||||
else:
|
||||
n_present_pos = move_until_block(
|
||||
arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
|
||||
)
|
||||
|
||||
if in_between_move_hook is not None:
|
||||
in_between_move_hook()
|
||||
|
||||
if positive_first:
|
||||
n_present_pos = move_until_block(
|
||||
arm, motor_name, positive_direction=False, while_move_hook=while_move_hook
|
||||
)
|
||||
else:
|
||||
p_present_pos = move_until_block(
|
||||
arm, motor_name, positive_direction=True, while_move_hook=while_move_hook
|
||||
)
|
||||
|
||||
zero_pos = (n_present_pos + p_present_pos) / 2
|
||||
|
||||
calib_data = {
|
||||
"initial_pos": initial_pos,
|
||||
"homing_offset": zero_pos if invert_drive_mode else -zero_pos,
|
||||
"invert_drive_mode": invert_drive_mode,
|
||||
"drive_mode": -1 if invert_drive_mode else 0,
|
||||
"zero_pos": zero_pos,
|
||||
"start_pos": n_present_pos if invert_drive_mode else p_present_pos,
|
||||
"end_pos": p_present_pos if invert_drive_mode else n_present_pos,
|
||||
}
|
||||
return calib_data
|
||||
|
||||
|
||||
def apply_offset(calib, offset):
|
||||
calib["zero_pos"] += offset
|
||||
if calib["drive_mode"]:
|
||||
calib["homing_offset"] += offset
|
||||
else:
|
||||
calib["homing_offset"] -= offset
|
||||
return calib
|
||||
|
||||
|
||||
def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
if robot_type == "so100":
|
||||
return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type)
|
||||
elif robot_type == "moss":
|
||||
return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type)
|
||||
else:
|
||||
raise ValueError(robot_type)
|
||||
|
||||
|
||||
def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
if not (robot_type == "so100" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
initial_acceleration = arm.read("Acceleration")
|
||||
arm.write("Lock", 0)
|
||||
arm.write("Acceleration", 10)
|
||||
time.sleep(1)
|
||||
|
||||
arm.write("Torque_Enable", TorqueMode.ENABLED.value)
|
||||
|
||||
print(f'{arm.read("Present_Position", "elbow_flex")=}')
|
||||
|
||||
calib = {}
|
||||
|
||||
init_wf_pos = arm.read("Present_Position", "wrist_flex")
|
||||
init_sl_pos = arm.read("Present_Position", "shoulder_lift")
|
||||
init_ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||
arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex")
|
||||
arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift")
|
||||
arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate shoulder_pan")
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate gripper")
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_flex")
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
||||
|
||||
def in_between_move_hook():
|
||||
nonlocal arm, calib
|
||||
time.sleep(2)
|
||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||
sl_pos = arm.read("Present_Position", "shoulder_lift")
|
||||
arm.write("Goal_Position", ef_pos + 1024, "elbow_flex")
|
||||
arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm,
|
||||
"elbow_flex",
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
nonlocal arm, calib
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
||||
|
||||
print("Calibrate shoulder_lift")
|
||||
calib["shoulder_lift"] = move_to_calibrate(
|
||||
arm,
|
||||
"shoulder_lift",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
)
|
||||
# add an 30 steps as offset to align with body
|
||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
||||
|
||||
def while_move_hook():
|
||||
nonlocal arm, calib
|
||||
positions = {
|
||||
"shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
||||
"elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700),
|
||||
"wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800),
|
||||
"gripper": round(calib["gripper"]["end_pos"]),
|
||||
}
|
||||
arm.write("Goal_Position", list(positions.values()), list(positions.keys()))
|
||||
|
||||
arm.write(
|
||||
"Goal_Position",
|
||||
round(calib["shoulder_lift"]["zero_pos"] - 1600),
|
||||
"shoulder_lift",
|
||||
)
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex")
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex")
|
||||
time.sleep(2)
|
||||
arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper")
|
||||
time.sleep(2)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(
|
||||
arm,
|
||||
"wrist_roll",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
while_move_hook=while_move_hook,
|
||||
)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex")
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
|
||||
calib_modes = []
|
||||
for name in arm.motor_names:
|
||||
if name == "gripper":
|
||||
calib_modes.append(CalibrationMode.LINEAR.name)
|
||||
else:
|
||||
calib_modes.append(CalibrationMode.DEGREE.name)
|
||||
|
||||
calib_dict = {
|
||||
"homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
|
||||
"drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
|
||||
"start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
|
||||
"end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
|
||||
"calib_mode": calib_modes,
|
||||
"motor_names": arm.motor_names,
|
||||
}
|
||||
|
||||
# Re-enable original accerlation
|
||||
arm.write("Lock", 0)
|
||||
arm.write("Acceleration", initial_acceleration)
|
||||
time.sleep(1)
|
||||
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
if not (robot_type == "moss" and arm_type == "follower"):
|
||||
raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to initial position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# Lower the acceleration of the motors (in [0,254])
|
||||
initial_acceleration = arm.read("Acceleration")
|
||||
arm.write("Lock", 0)
|
||||
arm.write("Acceleration", 10)
|
||||
time.sleep(1)
|
||||
|
||||
arm.write("Torque_Enable", TorqueMode.ENABLED.value)
|
||||
|
||||
sl_pos = arm.read("Present_Position", "shoulder_lift")
|
||||
arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift")
|
||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||
arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex")
|
||||
time.sleep(2)
|
||||
|
||||
calib = {}
|
||||
|
||||
print("Calibrate shoulder_pan")
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate gripper")
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_flex")
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)
|
||||
|
||||
wr_pos = arm.read("Present_Position", "wrist_roll")
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", wr_pos - 1024, "wrist_roll")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
|
||||
calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
|
||||
arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper")
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll")
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex")
|
||||
|
||||
def in_between_move_elbow_flex_hook():
|
||||
nonlocal arm, calib
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex")
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm,
|
||||
"elbow_flex",
|
||||
invert_drive_mode=True,
|
||||
in_between_move_hook=in_between_move_elbow_flex_hook,
|
||||
)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
|
||||
def in_between_move_shoulder_lift_hook():
|
||||
nonlocal arm, calib
|
||||
sl = arm.read("Present_Position", "shoulder_lift")
|
||||
arm.write("Goal_Position", sl - 1500, "shoulder_lift")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate shoulder_lift")
|
||||
calib["shoulder_lift"] = move_to_calibrate(
|
||||
arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook
|
||||
)
|
||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
time.sleep(1)
|
||||
arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift")
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex")
|
||||
time.sleep(2)
|
||||
|
||||
calib_modes = []
|
||||
for name in arm.motor_names:
|
||||
if name == "gripper":
|
||||
calib_modes.append(CalibrationMode.LINEAR.name)
|
||||
else:
|
||||
calib_modes.append(CalibrationMode.DEGREE.name)
|
||||
|
||||
calib_dict = {
|
||||
"homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names],
|
||||
"drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names],
|
||||
"start_pos": [calib[name]["start_pos"] for name in arm.motor_names],
|
||||
"end_pos": [calib[name]["end_pos"] for name in arm.motor_names],
|
||||
"calib_mode": calib_modes,
|
||||
"motor_names": arm.motor_names,
|
||||
}
|
||||
|
||||
# Re-enable original accerlation
|
||||
arm.write("Lock", 0)
|
||||
arm.write("Acceleration", initial_acceleration)
|
||||
time.sleep(1)
|
||||
|
||||
return calib_dict
|
||||
|
||||
|
||||
def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str):
|
||||
"""This function ensures that a neural network trained on data collected on a given robot
|
||||
can work on another robot. For instance before calibration, setting a same goal position
|
||||
for each motor of two different robots will get two very different positions. But after calibration,
|
||||
the two robots will move to the same position.To this end, this function computes the homing offset
|
||||
and the drive mode for each motor of a given robot.
|
||||
|
||||
Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps
|
||||
to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions
|
||||
being 0. During the calibration process, you will need to manually move the robot to this "zero position".
|
||||
|
||||
Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled
|
||||
in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot
|
||||
to the "rotated position".
|
||||
|
||||
After calibration, the homing offsets and drive modes are stored in a cache.
|
||||
|
||||
Example of usage:
|
||||
```python
|
||||
run_arm_calibration(arm, "so100", "left", "follower")
|
||||
```
|
||||
"""
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run calibration, the torque must be disabled on all motors.")
|
||||
|
||||
print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...")
|
||||
|
||||
print("\nMove arm to zero position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
# We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed.
|
||||
# It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will
|
||||
# correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position.
|
||||
zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Compute homing offset so that `present_position + homing_offset ~= target_position`.
|
||||
zero_pos = arm.read("Present_Position")
|
||||
homing_offset = zero_target_pos - zero_pos
|
||||
|
||||
# The rotated target position corresponds to a rotation of a quarter turn from the zero position.
|
||||
# This allows to identify the rotation direction of each motor.
|
||||
# For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction
|
||||
# is inverted. However, for the calibration being successful, we need everyone to follow the same target position.
|
||||
# Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which
|
||||
# corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view
|
||||
# of the previous motor in the kinetic chain.
|
||||
print("\nMove arm to rotated target position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated"))
|
||||
input("Press Enter to continue...")
|
||||
|
||||
rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models)
|
||||
|
||||
# Find drive mode by rotating each motor by a quarter of a turn.
|
||||
# Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0).
|
||||
rotated_pos = arm.read("Present_Position")
|
||||
drive_mode = (rotated_pos < zero_pos).astype(np.int32)
|
||||
|
||||
# Re-compute homing offset to take into account drive mode
|
||||
rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode)
|
||||
homing_offset = rotated_target_pos - rotated_drived_pos
|
||||
|
||||
print("\nMove arm to rest position")
|
||||
print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
|
||||
input("Press Enter to continue...")
|
||||
print()
|
||||
|
||||
# Joints with rotational motions are expressed in degrees in nominal range of [-180, 180]
|
||||
calib_modes = []
|
||||
for name in arm.motor_names:
|
||||
if name == "gripper":
|
||||
calib_modes.append(CalibrationMode.LINEAR.name)
|
||||
else:
|
||||
calib_modes.append(CalibrationMode.DEGREE.name)
|
||||
|
||||
calib_dict = {
|
||||
"homing_offset": homing_offset.tolist(),
|
||||
"drive_mode": drive_mode.tolist(),
|
||||
"start_pos": zero_pos.tolist(),
|
||||
"end_pos": rotated_pos.tolist(),
|
||||
"calib_mode": calib_modes,
|
||||
"motor_names": arm.motor_names,
|
||||
}
|
||||
return calib_dict
|
||||
@@ -1,208 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from dataclasses import replace
|
||||
|
||||
import torch
|
||||
from stretch_body.gamepad_teleop import GamePadTeleop
|
||||
from stretch_body.robot import Robot as StretchAPI
|
||||
from stretch_body.robot_params import RobotParams
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import StretchRobotConfig
|
||||
|
||||
|
||||
class StretchRobot(StretchAPI):
|
||||
"""Wrapper of stretch_body.robot.Robot"""
|
||||
|
||||
def __init__(self, config: StretchRobotConfig | None = None, **kwargs):
|
||||
super().__init__()
|
||||
if config is None:
|
||||
self.config = StretchRobotConfig(**kwargs)
|
||||
else:
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
|
||||
self.robot_type = self.config.type
|
||||
self.cameras = self.config.cameras
|
||||
self.is_connected = False
|
||||
self.teleop = None
|
||||
self.logs = {}
|
||||
|
||||
# TODO(aliberts): test this
|
||||
RobotParams.set_logging_level("WARNING")
|
||||
RobotParams.set_logging_formatter("brief_console_formatter")
|
||||
|
||||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.startup()
|
||||
if not self.is_connected:
|
||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||
raise ConnectionError()
|
||||
|
||||
for name in self.cameras:
|
||||
self.cameras[name].connect()
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.run_calibration()
|
||||
|
||||
def run_calibration(self) -> None:
|
||||
if not self.is_homed():
|
||||
self.home()
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
if self.teleop is None:
|
||||
self.teleop = GamePadTeleop(robot_instance=False)
|
||||
self.teleop.startup(robot=self)
|
||||
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
action = self.teleop.gamepad_controller.get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
self.teleop.do_motion(robot=self)
|
||||
self.push_command()
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
state = torch.as_tensor(list(state.values()))
|
||||
action = torch.as_tensor(list(action.values()))
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict, action_dict = {}, {}
|
||||
obs_dict["observation.state"] = state
|
||||
action_dict["action"] = action
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def get_state(self) -> dict:
|
||||
status = self.get_status()
|
||||
return {
|
||||
"head_pan.pos": status["head"]["head_pan"]["pos"],
|
||||
"head_tilt.pos": status["head"]["head_tilt"]["pos"],
|
||||
"lift.pos": status["lift"]["pos"],
|
||||
"arm.pos": status["arm"]["pos"],
|
||||
"wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"],
|
||||
"wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"],
|
||||
"wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"],
|
||||
"gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"],
|
||||
"base_x.vel": status["base"]["x_vel"],
|
||||
"base_y.vel": status["base"]["y_vel"],
|
||||
"base_theta.vel": status["base"]["theta_vel"],
|
||||
}
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
state = torch.as_tensor(list(state.values()))
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].async_read()
|
||||
images[name] = torch.from_numpy(images[name])
|
||||
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(aliberts): return ndarrays instead of torch.Tensors
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
if self.teleop is None:
|
||||
self.teleop = GamePadTeleop(robot_instance=False)
|
||||
self.teleop.startup(robot=self)
|
||||
|
||||
if self.action_keys is None:
|
||||
dummy_action = self.teleop.gamepad_controller.get_state()
|
||||
self.action_keys = list(dummy_action.keys())
|
||||
|
||||
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
self.teleop.do_motion(state=action_dict, robot=self)
|
||||
self.push_command()
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
|
||||
# TODO(aliberts): return action_sent when motion is limited
|
||||
return action
|
||||
|
||||
def print_logs(self) -> None:
|
||||
pass
|
||||
# TODO(aliberts): move robot-specific logs logic here
|
||||
|
||||
def teleop_safety_stop(self) -> None:
|
||||
if self.teleop is not None:
|
||||
self.teleop._safety_stop(robot=self)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.stop()
|
||||
if self.teleop is not None:
|
||||
self.teleop.gamepad_controller.stop()
|
||||
self.teleop.stop()
|
||||
|
||||
if len(self.cameras) > 0:
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
def __del__(self):
|
||||
self.disconnect()
|
||||
@@ -1,88 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robot_devices.robots.configs import (
|
||||
AlohaRobotConfig,
|
||||
KochBimanualRobotConfig,
|
||||
KochRobotConfig,
|
||||
LeKiwiRobotConfig,
|
||||
ManipulatorRobotConfig,
|
||||
MossRobotConfig,
|
||||
RobotConfig,
|
||||
So100RobotConfig,
|
||||
StretchRobotConfig,
|
||||
)
|
||||
|
||||
|
||||
def get_arm_id(name, arm_type):
|
||||
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
|
||||
like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
|
||||
"""
|
||||
return f"{name}_{arm_type}"
|
||||
|
||||
|
||||
class Robot(Protocol):
|
||||
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
|
||||
robot_type: str
|
||||
features: dict
|
||||
|
||||
def connect(self): ...
|
||||
def run_calibration(self): ...
|
||||
def teleop_step(self, record_data=False): ...
|
||||
def capture_observation(self): ...
|
||||
def send_action(self, action): ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
if robot_type == "aloha":
|
||||
return AlohaRobotConfig(**kwargs)
|
||||
elif robot_type == "koch":
|
||||
return KochRobotConfig(**kwargs)
|
||||
elif robot_type == "koch_bimanual":
|
||||
return KochBimanualRobotConfig(**kwargs)
|
||||
elif robot_type == "moss":
|
||||
return MossRobotConfig(**kwargs)
|
||||
elif robot_type == "so100":
|
||||
return So100RobotConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
return StretchRobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
return LeKiwiRobotConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig):
|
||||
if isinstance(config, ManipulatorRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot
|
||||
|
||||
return ManipulatorRobot(config)
|
||||
elif isinstance(config, LeKiwiRobotConfig):
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import (
|
||||
MobileManipulator,
|
||||
)
|
||||
|
||||
return MobileManipulator(config)
|
||||
else:
|
||||
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
||||
|
||||
return StretchRobot(config)
|
||||
|
||||
|
||||
def make_robot(robot_type: str, **kwargs) -> Robot:
|
||||
config = make_robot_config(robot_type, **kwargs)
|
||||
return make_robot_from_config(config)
|
||||
4
lerobot/common/robots/__init__.py
Normal file
4
lerobot/common/robots/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .config import RobotConfig
|
||||
from .robot import Robot
|
||||
|
||||
__all__ = ["RobotConfig", "Robot"]
|
||||
17
lerobot/common/robots/config.py
Normal file
17
lerobot/common/robots/config.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class RobotConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
# Allows to distinguish between different robots of the same type
|
||||
id: str | None = None
|
||||
# Directory to store calibration file
|
||||
calibration_dir: Path | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
2
lerobot/common/robots/koch/__init__.py
Normal file
2
lerobot/common/robots/koch/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config_koch_follower import KochFollowerConfig
|
||||
from .koch_follower import KochFollower
|
||||
22
lerobot/common/robots/koch/config_koch_follower.py
Normal file
22
lerobot/common/robots/koch/config_koch_follower.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("koch_follower")
|
||||
@dataclass
|
||||
class KochFollowerConfig(RobotConfig):
|
||||
# Port to connect to the arm
|
||||
port: str
|
||||
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
230
lerobot/common/robots/koch/koch_follower.py
Normal file
230
lerobot/common/robots/koch/koch_follower.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_koch_follower import KochFollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KochFollower(Robot):
|
||||
"""
|
||||
- [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow
|
||||
expansion, developed by Alexander Koch from [Tau Robotics](https://tau-robotics.com)
|
||||
- [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss
|
||||
"""
|
||||
|
||||
config_class = KochFollowerConfig
|
||||
name = "koch_follower"
|
||||
|
||||
def __init__(self, config: KochFollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.arm = DynamixelMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_lift": Motor(2, "xl430-w250", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_flex": Motor(3, "xl330-m288", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_flex": Motor(4, "xl330-m288", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_roll": Motor(5, "xl330-m288", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(6, "xl330-m288", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def state_feature(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(self.arm),),
|
||||
"names": {"motors": list(self.arm.motors)},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
return self.state_feature
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
cam_ft[cam_key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
# TODO(aliberts): add cam.is_connected for cam in self.cameras
|
||||
return self.arm.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.arm.connect()
|
||||
if not self.is_calibrated:
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.arm.disable_torque()
|
||||
for name in self.arm.names:
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.arm.set_half_turn_homings()
|
||||
|
||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||
logger.info(
|
||||
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
||||
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||
)
|
||||
range_mins, range_maxes = self.arm.record_ranges_of_motion(unknown_range_motors)
|
||||
for name in full_turn_motors:
|
||||
range_mins[name] = 0
|
||||
range_maxes[name] = 4095
|
||||
|
||||
self.calibration = {}
|
||||
for name, motor in self.arm.motors.items():
|
||||
self.calibration[name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=homing_offsets[name],
|
||||
range_min=range_mins[name],
|
||||
range_max=range_maxes[name],
|
||||
)
|
||||
|
||||
self.arm.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
with self.arm.torque_disabled():
|
||||
self.arm.configure_motors()
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point
|
||||
for name in self.arm.names:
|
||||
if name != "gripper":
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
# Use 'position control current based' for gripper to be limited by the limit of the current. For
|
||||
# the follower gripper, it means it can grasp an object without forcing too much even tho, its
|
||||
# goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
# For the leader gripper, it means we can use it as a physical trigger, since we can force with
|
||||
# our finger to make it move, and it will move back to its original target position when we
|
||||
# release the force.
|
||||
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
|
||||
# Set better PID values to close the gap between recorded states and actions
|
||||
# TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor
|
||||
self.arm.write("Position_P_Gain", "elbow_flex", 1500)
|
||||
self.arm.write("Position_I_Gain", "elbow_flex", 0)
|
||||
self.arm.write("Position_D_Gain", "elbow_flex", 600)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict[OBS_STATE] = self.arm.sync_read("Present_Position")
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> dict[str, float]:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
`max_relative_target`. In this case, the action sent differs from original action.
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Args:
|
||||
action (dict[str, float]): The goal positions for the motors.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = action
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.arm.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# Send goal position to the arm
|
||||
self.arm.sync_write("Goal_Position", goal_pos)
|
||||
return goal_pos
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.arm.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
89
lerobot/common/robots/lekiwi/configuration_lekiwi.py
Normal file
89
lerobot/common/robots/lekiwi/configuration_lekiwi.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras.configs import CameraConfig
|
||||
from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.common.motors.configs import FeetechMotorsBusConfig, MotorsBusConfig
|
||||
from lerobot.common.robots.config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("lekiwi")
|
||||
@dataclass
|
||||
class LeKiwiRobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# Network Configuration
|
||||
ip: str = "192.168.0.193"
|
||||
port: int = 5555
|
||||
video_port: int = 5556
|
||||
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"front": OpenCVCameraConfig(
|
||||
camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90
|
||||
),
|
||||
"wrist": OpenCVCameraConfig(
|
||||
camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
calibration_dir: str = ".cache/calibration/lekiwi"
|
||||
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/tty.usbmodem585A0077581",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"main": FeetechMotorsBusConfig(
|
||||
port="/dev/ttyACM0",
|
||||
motors={
|
||||
# name: (index, model)
|
||||
"shoulder_pan": [1, "sts3215"],
|
||||
"shoulder_lift": [2, "sts3215"],
|
||||
"elbow_flex": [3, "sts3215"],
|
||||
"wrist_flex": [4, "sts3215"],
|
||||
"wrist_roll": [5, "sts3215"],
|
||||
"gripper": [6, "sts3215"],
|
||||
"left_wheel": (7, "sts3215"),
|
||||
"back_wheel": (8, "sts3215"),
|
||||
"right_wheel": (9, "sts3215"),
|
||||
},
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
teleop_keys: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
# Movement
|
||||
"forward": "w",
|
||||
"backward": "s",
|
||||
"left": "a",
|
||||
"right": "d",
|
||||
"rotate_left": "z",
|
||||
"rotate_right": "x",
|
||||
# Speed control
|
||||
"speed_up": "r",
|
||||
"speed_down": "f",
|
||||
# quit teleop
|
||||
"quit": "q",
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
@@ -21,7 +21,7 @@ from pathlib import Path
|
||||
import cv2
|
||||
import zmq
|
||||
|
||||
from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi
|
||||
from lerobot.common.robots.mobile_manipulator import LeKiwi
|
||||
|
||||
|
||||
def setup_zmq_sockets(config):
|
||||
@@ -61,9 +61,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
calib_dir.mkdir(parents=True, exist_ok=True)
|
||||
calib_file = calib_dir / "main_follower.json"
|
||||
try:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
||||
except ImportError:
|
||||
print("[WARNING] Calibration function not available. Skipping calibration.")
|
||||
return
|
||||
@@ -74,7 +72,7 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
calibration = run_full_arm_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
@@ -95,8 +93,8 @@ def run_lekiwi(robot_config):
|
||||
- Processes incoming commands (arm and wheel commands) and sends back sensor and camera data.
|
||||
"""
|
||||
# Import helper functions and classes
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.motors.feetech.feetech import FeetechMotorsBus, TorqueMode
|
||||
|
||||
# Initialize cameras from the robot configuration.
|
||||
cameras = make_cameras_from_configs(robot_config.cameras)
|
||||
@@ -118,14 +116,7 @@ def run_lekiwi(robot_config):
|
||||
robot = LeKiwi(motors_bus)
|
||||
|
||||
# Define the expected arm motor IDs.
|
||||
arm_motor_ids = [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
]
|
||||
arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"]
|
||||
|
||||
# Disable torque for each arm motor.
|
||||
for motor in arm_motor_ids:
|
||||
@@ -139,9 +130,7 @@ def run_lekiwi(robot_config):
|
||||
images_lock = threading.Lock()
|
||||
stop_event = threading.Event()
|
||||
cam_thread = threading.Thread(
|
||||
target=run_camera_capture,
|
||||
args=(cameras, images_lock, latest_images_dict, stop_event),
|
||||
daemon=True,
|
||||
target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True
|
||||
)
|
||||
cam_thread.start()
|
||||
|
||||
692
lerobot/common/robots/lekiwi/robot_lekiwi.py
Normal file
692
lerobot/common/robots/lekiwi/robot_lekiwi.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceNotConnectedError
|
||||
from lerobot.common.motors.feetech.feetech import TorqueMode
|
||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
||||
from lerobot.common.motors.motors_bus import MotorsBus
|
||||
from lerobot.common.motors.utils import make_motors_buses_from_configs
|
||||
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
||||
from lerobot.common.robots.utils import get_arm_id
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
# Only import if there's a valid X server or if we're not on a Pi
|
||||
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
|
||||
print("No DISPLAY set. Skipping pynput import.")
|
||||
raise ImportError("pynput blocked intentionally due to no display.")
|
||||
|
||||
from pynput import keyboard
|
||||
except ImportError:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
except Exception as e:
|
||||
keyboard = None
|
||||
PYNPUT_AVAILABLE = False
|
||||
print(f"Could not import pynput: {e}")
|
||||
|
||||
|
||||
class MobileManipulator:
|
||||
"""
|
||||
MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot.
|
||||
The robot includes a three omniwheel mobile base and a remote follower arm.
|
||||
The leader arm is connected locally (on the laptop) and its joint positions are recorded and then
|
||||
forwarded to the remote follower arm (after applying a safety clamp).
|
||||
In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LeKiwiRobotConfig):
|
||||
"""
|
||||
Expected keys in config:
|
||||
- ip, port, video_port for the remote connection.
|
||||
- calibration_dir, leader_arms, follower_arms, max_relative_target, etc.
|
||||
"""
|
||||
self.robot_type = config.type
|
||||
self.config = config
|
||||
self.remote_ip = config.ip
|
||||
self.remote_port = config.port
|
||||
self.remote_port_video = config.video_port
|
||||
self.calibration_dir = Path(self.config.calibration_dir)
|
||||
self.logs = {}
|
||||
|
||||
self.teleop_keys = self.config.teleop_keys
|
||||
|
||||
# For teleoperation, the leader arm (local) is used to record the desired arm pose.
|
||||
self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms)
|
||||
|
||||
self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms)
|
||||
|
||||
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||
|
||||
self.is_connected = False
|
||||
|
||||
self.last_frames = {}
|
||||
self.last_present_speed = {}
|
||||
self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32)
|
||||
|
||||
# Define three speed levels and a current index
|
||||
self.speed_levels = [
|
||||
{"xy": 0.1, "theta": 30}, # slow
|
||||
{"xy": 0.2, "theta": 60}, # medium
|
||||
{"xy": 0.3, "theta": 90}, # fast
|
||||
]
|
||||
self.speed_index = 0 # Start at slow
|
||||
|
||||
# ZeroMQ context and sockets.
|
||||
self.context = None
|
||||
self.cmd_socket = None
|
||||
self.video_socket = None
|
||||
|
||||
# Keyboard state for base teleoperation.
|
||||
self.running = True
|
||||
self.pressed_keys = {
|
||||
"forward": False,
|
||||
"backward": False,
|
||||
"left": False,
|
||||
"right": False,
|
||||
"rotate_left": False,
|
||||
"rotate_right": False,
|
||||
}
|
||||
|
||||
if PYNPUT_AVAILABLE:
|
||||
print("pynput is available - enabling local keyboard listener.")
|
||||
self.listener = keyboard.Listener(
|
||||
on_press=self.on_press,
|
||||
on_release=self.on_release,
|
||||
)
|
||||
self.listener.start()
|
||||
else:
|
||||
print("pynput not available - skipping local keyboard listener.")
|
||||
self.listener = None
|
||||
|
||||
def get_motor_names(self, arms: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors]
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
follower_arm_names = [
|
||||
"shoulder_pan",
|
||||
"shoulder_lift",
|
||||
"elbow_flex",
|
||||
"wrist_flex",
|
||||
"wrist_roll",
|
||||
"gripper",
|
||||
]
|
||||
observations = ["x_mm", "y_mm", "theta"]
|
||||
combined_names = follower_arm_names + observations
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(combined_names),),
|
||||
"names": combined_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(combined_names),),
|
||||
"names": combined_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
|
||||
@property
|
||||
def num_cameras(self):
|
||||
return len(self.cameras)
|
||||
|
||||
@property
|
||||
def available_arms(self):
|
||||
available = []
|
||||
for name in self.leader_arms:
|
||||
available.append(get_arm_id(name, "leader"))
|
||||
for name in self.follower_arms:
|
||||
available.append(get_arm_id(name, "follower"))
|
||||
return available
|
||||
|
||||
def on_press(self, key):
|
||||
try:
|
||||
# Movement
|
||||
if key.char == self.teleop_keys["forward"]:
|
||||
self.pressed_keys["forward"] = True
|
||||
elif key.char == self.teleop_keys["backward"]:
|
||||
self.pressed_keys["backward"] = True
|
||||
elif key.char == self.teleop_keys["left"]:
|
||||
self.pressed_keys["left"] = True
|
||||
elif key.char == self.teleop_keys["right"]:
|
||||
self.pressed_keys["right"] = True
|
||||
elif key.char == self.teleop_keys["rotate_left"]:
|
||||
self.pressed_keys["rotate_left"] = True
|
||||
elif key.char == self.teleop_keys["rotate_right"]:
|
||||
self.pressed_keys["rotate_right"] = True
|
||||
|
||||
# Quit teleoperation
|
||||
elif key.char == self.teleop_keys["quit"]:
|
||||
self.running = False
|
||||
return False
|
||||
|
||||
# Speed control
|
||||
elif key.char == self.teleop_keys["speed_up"]:
|
||||
self.speed_index = min(self.speed_index + 1, 2)
|
||||
print(f"Speed index increased to {self.speed_index}")
|
||||
elif key.char == self.teleop_keys["speed_down"]:
|
||||
self.speed_index = max(self.speed_index - 1, 0)
|
||||
print(f"Speed index decreased to {self.speed_index}")
|
||||
|
||||
except AttributeError:
|
||||
# e.g., if key is special like Key.esc
|
||||
if key == keyboard.Key.esc:
|
||||
self.running = False
|
||||
return False
|
||||
|
||||
def on_release(self, key):
|
||||
try:
|
||||
if hasattr(key, "char"):
|
||||
if key.char == self.teleop_keys["forward"]:
|
||||
self.pressed_keys["forward"] = False
|
||||
elif key.char == self.teleop_keys["backward"]:
|
||||
self.pressed_keys["backward"] = False
|
||||
elif key.char == self.teleop_keys["left"]:
|
||||
self.pressed_keys["left"] = False
|
||||
elif key.char == self.teleop_keys["right"]:
|
||||
self.pressed_keys["right"] = False
|
||||
elif key.char == self.teleop_keys["rotate_left"]:
|
||||
self.pressed_keys["rotate_left"] = False
|
||||
elif key.char == self.teleop_keys["rotate_right"]:
|
||||
self.pressed_keys["rotate_right"] = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
if not self.leader_arms:
|
||||
raise ValueError("MobileManipulator has no leader arm to connect.")
|
||||
for name in self.leader_arms:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.calibrate_leader()
|
||||
|
||||
# Set up ZeroMQ sockets to communicate with the remote mobile robot.
|
||||
self.context = zmq.Context()
|
||||
self.cmd_socket = self.context.socket(zmq.PUSH)
|
||||
connection_string = f"tcp://{self.remote_ip}:{self.remote_port}"
|
||||
self.cmd_socket.connect(connection_string)
|
||||
self.cmd_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
self.video_socket = self.context.socket(zmq.PULL)
|
||||
video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}"
|
||||
self.video_socket.connect(video_connection)
|
||||
self.video_socket.setsockopt(zmq.CONFLATE, 1)
|
||||
print(
|
||||
f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}."
|
||||
)
|
||||
self.is_connected = True
|
||||
|
||||
def load_or_run_calibration_(self, name, arm, arm_type):
|
||||
arm_id = get_arm_id(name, arm_type)
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
|
||||
def calibrate_leader(self):
|
||||
for name, arm in self.leader_arms.items():
|
||||
# Connect the bus
|
||||
arm.connect()
|
||||
|
||||
# Disable torque on all motors
|
||||
for motor_id in arm.motors:
|
||||
arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id)
|
||||
|
||||
# Now run calibration
|
||||
calibration = self.load_or_run_calibration_(name, arm, "leader")
|
||||
arm.set_calibration(calibration)
|
||||
|
||||
def calibrate_follower(self):
|
||||
for name, bus in self.follower_arms.items():
|
||||
bus.connect()
|
||||
|
||||
# Disable torque on all motors
|
||||
for motor_id in bus.motors:
|
||||
bus.write("Torque_Enable", 0, motor_id)
|
||||
|
||||
# Then filter out wheels
|
||||
arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")}
|
||||
if not arm_only_dict:
|
||||
continue
|
||||
|
||||
original_motors = bus.motors
|
||||
bus.motors = arm_only_dict
|
||||
|
||||
calibration = self.load_or_run_calibration_(name, bus, "follower")
|
||||
bus.set_calibration(calibration)
|
||||
|
||||
bus.motors = original_motors
|
||||
|
||||
def _get_data(self):
|
||||
"""
|
||||
Polls the video socket for up to 15 ms. If data arrives, decode only
|
||||
the *latest* message, returning frames, speed, and arm state. If
|
||||
nothing arrives for any field, use the last known values.
|
||||
"""
|
||||
frames = {}
|
||||
present_speed = {}
|
||||
remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32)
|
||||
|
||||
# Poll up to 15 ms
|
||||
poller = zmq.Poller()
|
||||
poller.register(self.video_socket, zmq.POLLIN)
|
||||
socks = dict(poller.poll(15))
|
||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||
# No new data arrived → reuse ALL old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Drain all messages, keep only the last
|
||||
last_msg = None
|
||||
while True:
|
||||
try:
|
||||
obs_string = self.video_socket.recv_string(zmq.NOBLOCK)
|
||||
last_msg = obs_string
|
||||
except zmq.Again:
|
||||
break
|
||||
|
||||
if not last_msg:
|
||||
# No new message → also reuse old
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Decode only the final message
|
||||
try:
|
||||
observation = json.loads(last_msg)
|
||||
|
||||
images_dict = observation.get("images", {})
|
||||
new_speed = observation.get("present_speed", {})
|
||||
new_arm_state = observation.get("follower_arm_state", None)
|
||||
|
||||
# Convert images
|
||||
for cam_name, image_b64 in images_dict.items():
|
||||
if image_b64:
|
||||
jpg_data = base64.b64decode(image_b64)
|
||||
np_arr = np.frombuffer(jpg_data, dtype=np.uint8)
|
||||
frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if frame_candidate is not None:
|
||||
frames[cam_name] = frame_candidate
|
||||
|
||||
# If remote_arm_state is None and frames is None there is no message then use the previous message
|
||||
if new_arm_state is not None and frames is not None:
|
||||
self.last_frames = frames
|
||||
|
||||
remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32)
|
||||
self.last_remote_arm_state = remote_arm_state_tensor
|
||||
|
||||
present_speed = new_speed
|
||||
self.last_present_speed = new_speed
|
||||
else:
|
||||
frames = self.last_frames
|
||||
|
||||
remote_arm_state_tensor = self.last_remote_arm_state
|
||||
|
||||
present_speed = self.last_present_speed
|
||||
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
return frames, present_speed, remote_arm_state_tensor
|
||||
|
||||
def _process_present_speed(self, present_speed: dict) -> torch.Tensor:
|
||||
state_tensor = torch.zeros(3, dtype=torch.int32)
|
||||
if present_speed:
|
||||
decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()}
|
||||
if "1" in decoded:
|
||||
state_tensor[0] = decoded["1"]
|
||||
if "2" in decoded:
|
||||
state_tensor[1] = decoded["2"]
|
||||
if "3" in decoded:
|
||||
state_tensor[2] = decoded["3"]
|
||||
return state_tensor
|
||||
|
||||
def teleop_step(
|
||||
self, record_data: bool = False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||
|
||||
speed_setting = self.speed_levels[self.speed_index]
|
||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||
theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90
|
||||
|
||||
# Prepare to assign the position of the leader to the follower
|
||||
arm_positions = []
|
||||
for name in self.leader_arms:
|
||||
pos = self.leader_arms[name].read("Present_Position")
|
||||
pos_tensor = torch.from_numpy(pos).float()
|
||||
# Instead of pos_tensor.item(), use tolist() to convert the entire tensor to a list
|
||||
arm_positions.extend(pos_tensor.tolist())
|
||||
|
||||
# (The rest of your code for generating wheel commands remains unchanged)
|
||||
x_cmd = 0.0 # m/s forward/backward
|
||||
y_cmd = 0.0 # m/s lateral
|
||||
theta_cmd = 0.0 # deg/s rotation
|
||||
if self.pressed_keys["forward"]:
|
||||
x_cmd += xy_speed
|
||||
if self.pressed_keys["backward"]:
|
||||
x_cmd -= xy_speed
|
||||
if self.pressed_keys["left"]:
|
||||
y_cmd += xy_speed
|
||||
if self.pressed_keys["right"]:
|
||||
y_cmd -= xy_speed
|
||||
if self.pressed_keys["rotate_left"]:
|
||||
theta_cmd += theta_speed
|
||||
if self.pressed_keys["rotate_right"]:
|
||||
theta_cmd -= theta_speed
|
||||
|
||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions}
|
||||
self.cmd_socket.send_string(json.dumps(message))
|
||||
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
obs_dict = self.capture_observation()
|
||||
|
||||
arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32)
|
||||
|
||||
wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands)
|
||||
wheel_velocity_mm = (
|
||||
wheel_velocity_tuple[0] * 1000.0,
|
||||
wheel_velocity_tuple[1] * 1000.0,
|
||||
wheel_velocity_tuple[2],
|
||||
)
|
||||
wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32)
|
||||
action_tensor = torch.cat([arm_state_tensor, wheel_tensor])
|
||||
action_dict = {"action": action_tensor}
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||
|
||||
body_state = self.wheel_raw_to_body(present_speed)
|
||||
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
|
||||
obs_dict = {"observation.state": combined_state_tensor}
|
||||
|
||||
# Loop over each configured camera
|
||||
for cam_name, cam in self.cameras.items():
|
||||
frame = frames.get(cam_name, None)
|
||||
if frame is None:
|
||||
# Create a black image using the camera's configured width, height, and channels
|
||||
frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8)
|
||||
obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame)
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
# Ensure the action tensor has at least 9 elements:
|
||||
# - First 6: arm positions.
|
||||
# - Last 3: base commands.
|
||||
if action.numel() < 9:
|
||||
# Pad with zeros if there are not enough elements.
|
||||
padded = torch.zeros(9, dtype=action.dtype)
|
||||
padded[: action.numel()] = action
|
||||
action = padded
|
||||
|
||||
# Extract arm and base actions.
|
||||
arm_actions = action[:6].flatten()
|
||||
base_actions = action[6:].flatten()
|
||||
|
||||
x_cmd_mm = base_actions[0].item() # mm/s
|
||||
y_cmd_mm = base_actions[1].item() # mm/s
|
||||
theta_cmd = base_actions[2].item() # deg/s
|
||||
|
||||
# Convert mm/s to m/s for the kinematics calculations.
|
||||
x_cmd = x_cmd_mm / 1000.0 # m/s
|
||||
y_cmd = y_cmd_mm / 1000.0 # m/s
|
||||
|
||||
# Compute wheel commands from body commands.
|
||||
wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
arm_positions_list = arm_actions.tolist()
|
||||
|
||||
message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list}
|
||||
self.cmd_socket.send_string(json.dumps(message))
|
||||
|
||||
return action
|
||||
|
||||
def print_logs(self):
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError("Not connected.")
|
||||
if self.cmd_socket:
|
||||
stop_cmd = {
|
||||
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
||||
"arm_positions": {},
|
||||
}
|
||||
self.cmd_socket.send_string(json.dumps(stop_cmd))
|
||||
self.cmd_socket.close()
|
||||
if self.video_socket:
|
||||
self.video_socket.close()
|
||||
if self.context:
|
||||
self.context.term()
|
||||
if PYNPUT_AVAILABLE:
|
||||
self.listener.stop()
|
||||
self.is_connected = False
|
||||
print("[INFO] Disconnected from remote robot.")
|
||||
|
||||
def __del__(self):
|
||||
if getattr(self, "is_connected", False):
|
||||
self.disconnect()
|
||||
if PYNPUT_AVAILABLE:
|
||||
self.listener.stop()
|
||||
|
||||
@staticmethod
|
||||
def degps_to_raw(degps: float) -> int:
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
speed_in_steps = abs(degps) * steps_per_deg
|
||||
speed_int = int(round(speed_in_steps))
|
||||
if speed_int > 0x7FFF:
|
||||
speed_int = 0x7FFF
|
||||
if degps < 0:
|
||||
return speed_int | 0x8000
|
||||
else:
|
||||
return speed_int & 0x7FFF
|
||||
|
||||
@staticmethod
|
||||
def raw_to_degps(raw_speed: int) -> float:
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
magnitude = raw_speed & 0x7FFF
|
||||
degps = magnitude / steps_per_deg
|
||||
if raw_speed & 0x8000:
|
||||
degps = -degps
|
||||
return degps
|
||||
|
||||
def body_to_wheel_raw(
|
||||
self,
|
||||
x_cmd: float,
|
||||
y_cmd: float,
|
||||
theta_cmd: float,
|
||||
wheel_radius: float = 0.05,
|
||||
base_radius: float = 0.125,
|
||||
max_raw: int = 3000,
|
||||
) -> dict:
|
||||
"""
|
||||
Convert desired body-frame velocities into wheel raw commands.
|
||||
|
||||
Parameters:
|
||||
x_cmd : Linear velocity in x (m/s).
|
||||
y_cmd : Linear velocity in y (m/s).
|
||||
theta_cmd : Rotational velocity (deg/s).
|
||||
wheel_radius: Radius of each wheel (meters).
|
||||
base_radius : Distance from the center of rotation to each wheel (meters).
|
||||
max_raw : Maximum allowed raw command (ticks) per wheel.
|
||||
|
||||
Returns:
|
||||
A dictionary with wheel raw commands:
|
||||
{"left_wheel": value, "back_wheel": value, "right_wheel": value}.
|
||||
|
||||
Notes:
|
||||
- Internally, the method converts theta_cmd to rad/s for the kinematics.
|
||||
- The raw command is computed from the wheels angular speed in deg/s
|
||||
using degps_to_raw(). If any command exceeds max_raw, all commands
|
||||
are scaled down proportionally.
|
||||
"""
|
||||
# Convert rotational velocity from deg/s to rad/s.
|
||||
theta_rad = theta_cmd * (np.pi / 180.0)
|
||||
# Create the body velocity vector [x, y, theta_rad].
|
||||
velocity_vector = np.array([x_cmd, y_cmd, theta_rad])
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
# Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed.
|
||||
# The third column (base_radius) accounts for the effect of rotation.
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
|
||||
# Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s).
|
||||
wheel_linear_speeds = m.dot(velocity_vector)
|
||||
wheel_angular_speeds = wheel_linear_speeds / wheel_radius
|
||||
|
||||
# Convert wheel angular speeds from rad/s to deg/s.
|
||||
wheel_degps = wheel_angular_speeds * (180.0 / np.pi)
|
||||
|
||||
# Scaling
|
||||
steps_per_deg = 4096.0 / 360.0
|
||||
raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps]
|
||||
max_raw_computed = max(raw_floats)
|
||||
if max_raw_computed > max_raw:
|
||||
scale = max_raw / max_raw_computed
|
||||
wheel_degps = wheel_degps * scale
|
||||
|
||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||
|
||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||
|
||||
def wheel_raw_to_body(
|
||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
) -> tuple:
|
||||
"""
|
||||
Convert wheel raw command feedback back into body-frame velocities.
|
||||
|
||||
Parameters:
|
||||
wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel").
|
||||
wheel_radius: Radius of each wheel (meters).
|
||||
base_radius : Distance from the robot center to each wheel (meters).
|
||||
|
||||
Returns:
|
||||
A tuple (x_cmd, y_cmd, theta_cmd) where:
|
||||
x_cmd : Linear velocity in x (m/s).
|
||||
y_cmd : Linear velocity in y (m/s).
|
||||
theta_cmd : Rotational velocity in deg/s.
|
||||
"""
|
||||
# Extract the raw values in order.
|
||||
raw_list = [
|
||||
int(wheel_raw.get("left_wheel", 0)),
|
||||
int(wheel_raw.get("back_wheel", 0)),
|
||||
int(wheel_raw.get("right_wheel", 0)),
|
||||
]
|
||||
|
||||
# Convert each raw command back to an angular speed in deg/s.
|
||||
wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list])
|
||||
# Convert from deg/s to rad/s.
|
||||
wheel_radps = wheel_degps * (np.pi / 180.0)
|
||||
# Compute each wheel’s linear speed (m/s) from its angular speed.
|
||||
wheel_linear_speeds = wheel_radps * wheel_radius
|
||||
|
||||
# Define the wheel mounting angles with a -90° offset.
|
||||
angles = np.radians(np.array([240, 120, 0]) - 90)
|
||||
m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles])
|
||||
|
||||
# Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds.
|
||||
m_inv = np.linalg.inv(m)
|
||||
velocity_vector = m_inv.dot(wheel_linear_speeds)
|
||||
x_cmd, y_cmd, theta_rad = velocity_vector
|
||||
theta_cmd = theta_rad * (180.0 / np.pi)
|
||||
return (x_cmd, y_cmd, theta_cmd)
|
||||
|
||||
|
||||
class LeKiwi:
|
||||
def __init__(self, motor_bus):
|
||||
"""
|
||||
Initializes the LeKiwi with Feetech motors bus.
|
||||
"""
|
||||
self.motor_bus = motor_bus
|
||||
self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"]
|
||||
|
||||
# Initialize motors in velocity mode.
|
||||
self.motor_bus.write("Lock", 0)
|
||||
self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids)
|
||||
self.motor_bus.write("Lock", 1)
|
||||
print("Motors set to velocity mode.")
|
||||
|
||||
def read_velocity(self):
|
||||
"""
|
||||
Reads the raw speeds for all wheels. Returns a dictionary with motor names:
|
||||
"""
|
||||
raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids)
|
||||
return {
|
||||
"left_wheel": int(raw_speeds[0]),
|
||||
"back_wheel": int(raw_speeds[1]),
|
||||
"right_wheel": int(raw_speeds[2]),
|
||||
}
|
||||
|
||||
def set_velocity(self, command_speeds):
|
||||
"""
|
||||
Sends raw velocity commands (16-bit encoded values) directly to the motor bus.
|
||||
The order of speeds must correspond to self.motor_ids.
|
||||
"""
|
||||
self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids)
|
||||
|
||||
def stop(self):
|
||||
"""Stops the robot by setting all motor speeds to zero."""
|
||||
self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids)
|
||||
print("Motors stopped.")
|
||||
@@ -18,48 +18,66 @@ and send orders to its motors.
|
||||
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
|
||||
# calibration procedure, to make it easy for people to add their own robot.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.utils import (
|
||||
MotorsBus,
|
||||
make_motors_buses_from_configs,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import (
|
||||
RobotDeviceAlreadyConnectedError,
|
||||
RobotDeviceNotConnectedError,
|
||||
)
|
||||
from lerobot.common.cameras.configs import CameraConfig
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors.configs import MotorsBusConfig
|
||||
from lerobot.common.motors.motors_bus import MotorsBus
|
||||
from lerobot.common.motors.utils import make_motors_buses_from_configs
|
||||
from lerobot.common.robots.config import RobotConfig
|
||||
from lerobot.common.robots.utils import ensure_safe_goal_position, get_arm_id
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_pos: torch.Tensor,
|
||||
present_pos: torch.Tensor,
|
||||
max_relative_target: float | list[float],
|
||||
):
|
||||
# Cap relative action target magnitude for safety.
|
||||
diff = goal_pos - present_pos
|
||||
max_relative_target = torch.tensor(max_relative_target)
|
||||
safe_diff = torch.minimum(diff, max_relative_target)
|
||||
safe_diff = torch.maximum(safe_diff, -max_relative_target)
|
||||
safe_goal_pos = present_pos + safe_diff
|
||||
@dataclass
|
||||
class ManipulatorRobotConfig(RobotConfig):
|
||||
leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {})
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=lambda: {})
|
||||
|
||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||
logging.debug(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f" requested relative goal position target: {diff}\n"
|
||||
f" clamped relative goal position target: {safe_diff}"
|
||||
)
|
||||
# Optionally limit the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length
|
||||
# as the number of motors in your follower arms (assumes all follower arms have the same number of
|
||||
# motors).
|
||||
max_relative_target: list[float] | float | None = None
|
||||
|
||||
return safe_goal_pos
|
||||
# Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it
|
||||
# possible to squeeze the gripper and have it spring back to an open position on its own. If None, the
|
||||
# gripper is not put in torque mode.
|
||||
gripper_open_degree: float | None = None
|
||||
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.mock:
|
||||
for arm in self.leader_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for arm in self.follower_arms.values():
|
||||
if not arm.mock:
|
||||
arm.mock = True
|
||||
for cam in self.cameras.values():
|
||||
if not cam.mock:
|
||||
cam.mock = True
|
||||
|
||||
if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence):
|
||||
for name in self.follower_arms:
|
||||
if len(self.follower_arms[name].motors) != len(self.max_relative_target):
|
||||
raise ValueError(
|
||||
f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has "
|
||||
f"{len(self.follower_arms[name].motors)} motors. Please make sure that the "
|
||||
f"`max_relative_target` list has as many parameters as there are motors per arm. "
|
||||
"Note: This feature does not yet work with robots where different follower arms have "
|
||||
"different numbers of motors."
|
||||
)
|
||||
|
||||
|
||||
class ManipulatorRobot:
|
||||
@@ -232,7 +250,7 @@ class ManipulatorRobot:
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
raise RobotDeviceAlreadyConnectedError(
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
@@ -250,9 +268,9 @@ class ManipulatorRobot:
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
from lerobot.common.motors.dynamixel.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
from lerobot.common.motors.feetech.feetech import TorqueMode
|
||||
|
||||
# We assume that at connection time, arms are in a rest position, and torque can
|
||||
# be safely disabled to run calibration and/or set robot preset configurations.
|
||||
@@ -261,8 +279,6 @@ class ManipulatorRobot:
|
||||
for name in self.leader_arms:
|
||||
self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value)
|
||||
|
||||
self.activate_calibration()
|
||||
|
||||
# Set robot preset (e.g. torque in leader gripper for Koch v1.1)
|
||||
if self.robot_type in ["koch", "koch_bimanual"]:
|
||||
self.set_koch_robot_preset()
|
||||
@@ -299,54 +315,9 @@ class ManipulatorRobot:
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def activate_calibration(self):
|
||||
"""After calibration all motors function in human interpretable ranges.
|
||||
Rotations are expressed in degrees in nominal range of [-180, 180],
|
||||
and linear motions (like gripper of Aloha) in nominal range of [0, 100].
|
||||
"""
|
||||
|
||||
def load_or_run_calibration_(name, arm, arm_type):
|
||||
arm_id = get_arm_id(name, arm_type)
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import (
|
||||
run_arm_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
|
||||
for name, arm in self.follower_arms.items():
|
||||
calibration = load_or_run_calibration_(name, arm, "follower")
|
||||
arm.set_calibration(calibration)
|
||||
for name, arm in self.leader_arms.items():
|
||||
calibration = load_or_run_calibration_(name, arm, "leader")
|
||||
arm.set_calibration(calibration)
|
||||
|
||||
def set_koch_robot_preset(self):
|
||||
def set_operating_mode_(arm):
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
from lerobot.common.motors.dynamixel.dynamixel import TorqueMode
|
||||
|
||||
if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any():
|
||||
raise ValueError("To run set robot preset, the torque must be disabled on all motors.")
|
||||
@@ -444,9 +415,6 @@ class ManipulatorRobot:
|
||||
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
||||
self.follower_arms[name].write("I_Coefficient", 0)
|
||||
self.follower_arms[name].write("D_Coefficient", 32)
|
||||
# Close the write lock so that Maximum_Acceleration gets written to EPROM address,
|
||||
# which is mandatory for Maximum_Acceleration to take effect after rebooting.
|
||||
self.follower_arms[name].write("Lock", 0)
|
||||
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
|
||||
# the motors. Note: this configuration is not in the official STS3215 Memory Table
|
||||
self.follower_arms[name].write("Maximum_Acceleration", 254)
|
||||
@@ -456,7 +424,7 @@ class ManipulatorRobot:
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
@@ -474,14 +442,6 @@ class ManipulatorRobot:
|
||||
before_fwrite_t = time.perf_counter()
|
||||
goal_pos = leader_pos[name]
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
# if self.config.joint_position_relative_bounds is not None:
|
||||
# goal_pos = torch.clamp(
|
||||
# goal_pos,
|
||||
# self.config.joint_position_relative_bounds["min"],
|
||||
# self.config.joint_position_relative_bounds["max"],
|
||||
# )
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
@@ -544,7 +504,7 @@ class ManipulatorRobot:
|
||||
def capture_observation(self):
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
@@ -590,7 +550,7 @@ class ManipulatorRobot:
|
||||
action: tensor containing the concatenated goal positions for the follower arms.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
@@ -603,14 +563,6 @@ class ManipulatorRobot:
|
||||
goal_pos = action[from_idx:to_idx]
|
||||
from_idx = to_idx
|
||||
|
||||
# If specified, clip the goal positions within predefined bounds specified in the config of the robot
|
||||
# if self.config.joint_position_relative_bounds is not None:
|
||||
# goal_pos = torch.clamp(
|
||||
# goal_pos,
|
||||
# self.config.joint_position_relative_bounds["min"],
|
||||
# self.config.joint_position_relative_bounds["max"],
|
||||
# )
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
@@ -633,7 +585,7 @@ class ManipulatorRobot:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError(
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
|
||||
@@ -23,18 +23,14 @@ import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
from lerobot.common.robot_devices.motors.utils import (
|
||||
MotorsBus,
|
||||
make_motors_buses_from_configs,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import get_arm_id
|
||||
from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.errors import DeviceNotConnectedError
|
||||
from lerobot.common.motors.feetech.feetech import TorqueMode
|
||||
from lerobot.common.motors.feetech.feetech_calibration import run_full_arm_calibration
|
||||
from lerobot.common.motors.motors_bus import MotorsBus
|
||||
from lerobot.common.motors.utils import make_motors_buses_from_configs
|
||||
from lerobot.common.robots.lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
||||
from lerobot.common.robots.utils import get_arm_id
|
||||
|
||||
PYNPUT_AVAILABLE = True
|
||||
try:
|
||||
@@ -271,7 +267,7 @@ class MobileManipulator:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
calibration = run_full_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
@@ -329,11 +325,7 @@ class MobileManipulator:
|
||||
socks = dict(poller.poll(15))
|
||||
if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN:
|
||||
# No new data arrived → reuse ALL old data
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Drain all messages, keep only the last
|
||||
last_msg = None
|
||||
@@ -346,11 +338,7 @@ class MobileManipulator:
|
||||
|
||||
if not last_msg:
|
||||
# No new message → also reuse old
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
# Decode only the final message
|
||||
try:
|
||||
@@ -388,11 +376,7 @@ class MobileManipulator:
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
return (
|
||||
self.last_frames,
|
||||
self.last_present_speed,
|
||||
self.last_remote_arm_state,
|
||||
)
|
||||
return (self.last_frames, self.last_present_speed, self.last_remote_arm_state)
|
||||
|
||||
return frames, present_speed, remote_arm_state_tensor
|
||||
|
||||
@@ -412,7 +396,7 @@ class MobileManipulator:
|
||||
self, record_data: bool = False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||
raise DeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.")
|
||||
|
||||
speed_setting = self.speed_levels[self.speed_index]
|
||||
xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4
|
||||
@@ -472,17 +456,13 @@ class MobileManipulator:
|
||||
and a camera frame.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
frames, present_speed, remote_arm_state_tensor = self._get_data()
|
||||
|
||||
body_state = self.wheel_raw_to_body(present_speed)
|
||||
|
||||
body_state_mm = (
|
||||
body_state[0] * 1000.0,
|
||||
body_state[1] * 1000.0,
|
||||
body_state[2],
|
||||
) # Convert x,y to mm/s
|
||||
body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s
|
||||
wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32)
|
||||
combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0)
|
||||
|
||||
@@ -500,7 +480,7 @@ class MobileManipulator:
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
raise DeviceNotConnectedError("Not connected. Run `connect()` first.")
|
||||
|
||||
# Ensure the action tensor has at least 9 elements:
|
||||
# - First 6: arm positions.
|
||||
@@ -538,7 +518,7 @@ class MobileManipulator:
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise RobotDeviceNotConnectedError("Not connected.")
|
||||
raise DeviceNotConnectedError("Not connected.")
|
||||
if self.cmd_socket:
|
||||
stop_cmd = {
|
||||
"raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0},
|
||||
@@ -641,11 +621,7 @@ class MobileManipulator:
|
||||
# Convert each wheel’s angular speed (deg/s) to a raw integer.
|
||||
wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps]
|
||||
|
||||
return {
|
||||
"left_wheel": wheel_raw[0],
|
||||
"back_wheel": wheel_raw[1],
|
||||
"right_wheel": wheel_raw[2],
|
||||
}
|
||||
return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]}
|
||||
|
||||
def wheel_raw_to_body(
|
||||
self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125
|
||||
@@ -138,7 +138,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
--id 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
@@ -150,7 +150,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
--id 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
4
lerobot/common/robots/moss/__init__.py
Normal file
4
lerobot/common/robots/moss/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .configuration_moss import MossRobotConfig
|
||||
from .robot_moss import MossRobot
|
||||
|
||||
__all__ = ["MossRobotConfig", "MossRobot"]
|
||||
30
lerobot/common/robots/moss/configuration_moss.py
Normal file
30
lerobot/common/robots/moss/configuration_moss.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("moss")
|
||||
@dataclass
|
||||
class MossRobotConfig(RobotConfig):
|
||||
# Port to connect to the robot
|
||||
port: str
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
mock: bool = False
|
||||
|
||||
# motors
|
||||
shoulder_pan: tuple = (1, "sts3215")
|
||||
shoulder_lift: tuple = (2, "sts3215")
|
||||
elbow_flex: tuple = (3, "sts3215")
|
||||
wrist_flex: tuple = (4, "sts3215")
|
||||
wrist_roll: tuple = (5, "sts3215")
|
||||
gripper: tuple = (6, "sts3215")
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
223
lerobot/common/robots/moss/robot_moss.py
Normal file
223
lerobot/common/robots/moss/robot_moss.py
Normal file
@@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors import TorqueMode
|
||||
from lerobot.common.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
apply_feetech_offsets_from_calibration,
|
||||
run_full_arm_calibration,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .configuration_moss import MossRobotConfig
|
||||
|
||||
|
||||
class MossRobot(Robot):
|
||||
"""
|
||||
[Moss Arm](https://github.com/jess-moss/moss-robot-arms) designed by Jess Moss
|
||||
"""
|
||||
|
||||
config_class = MossRobotConfig
|
||||
name = "moss"
|
||||
|
||||
def __init__(self, config: MossRobotConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.robot_type = config.type
|
||||
|
||||
self.arm = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": config.shoulder_pan,
|
||||
"shoulder_lift": config.shoulder_lift,
|
||||
"elbow_flex": config.elbow_flex,
|
||||
"wrist_flex": config.wrist_flex,
|
||||
"wrist_roll": config.wrist_roll,
|
||||
"gripper": config.gripper,
|
||||
},
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
@property
|
||||
def state_feature(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(self.arm),),
|
||||
"names": {"motors": list(self.arm.motors)},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
return self.state_feature
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
cam_ft[cam_key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"ManipulatorRobot is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
logging.info("Connecting arm.")
|
||||
self.arm.connect()
|
||||
|
||||
# We assume that at connection time, arm is in a rest position,
|
||||
# and torque can be safely disabled to run calibration.
|
||||
self.arm.write("Torque_Enable", TorqueMode.DISABLED.value)
|
||||
self.calibrate()
|
||||
|
||||
# Mode=0 for Position Control
|
||||
self.arm.write("Mode", 0)
|
||||
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
|
||||
self.arm.write("P_Coefficient", 16)
|
||||
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
||||
self.arm.write("I_Coefficient", 0)
|
||||
self.arm.write("D_Coefficient", 32)
|
||||
# Close the write lock so that Maximum_Acceleration gets written to EPROM address,
|
||||
# which is mandatory for Maximum_Acceleration to take effect after rebooting.
|
||||
self.arm.write("Lock", 0)
|
||||
# Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of
|
||||
# the motors. Note: this configuration is not in the official STS3215 Memory Table
|
||||
self.arm.write("Maximum_Acceleration", 254)
|
||||
self.arm.write("Acceleration", 254)
|
||||
|
||||
logging.info("Activating torque.")
|
||||
self.arm.write("Torque_Enable", TorqueMode.ENABLED.value)
|
||||
|
||||
# Check arm can be read
|
||||
self.arm.read("Present_Position")
|
||||
|
||||
# Connect the cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""After calibration all motors function in human interpretable ranges.
|
||||
Rotations are expressed in degrees in nominal range of [-180, 180],
|
||||
and linear motions (like gripper of Aloha) in nominal range of [0, 100].
|
||||
"""
|
||||
if self.calibration_fpath.exists():
|
||||
with open(self.calibration_fpath) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
logging.info(f"Missing calibration file '{self.calibration_fpath}'")
|
||||
calibration = run_full_arm_calibration(self.arm, self.robot_type, self.name, "follower")
|
||||
|
||||
logging.info(f"Calibration is done! Saving calibration file '{self.calibration_fpath}'")
|
||||
self.calibration_fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.calibration_fpath, "w") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
self.arm.set_calibration(calibration)
|
||||
apply_feetech_offsets_from_calibration(self.arm, calibration)
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read arm position
|
||||
before_read_t = time.perf_counter()
|
||||
obs_dict[OBS_STATE] = self.arm.read("Present_Position")
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
before_camread_t = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: np.ndarray) -> np.ndarray:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
`max_relative_target`. In this case, the action sent differs from original action.
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Args:
|
||||
action (np.ndarray): array containing the goal positions for the motors.
|
||||
|
||||
Raises:
|
||||
RobotDeviceNotConnectedError: if robot is not connected.
|
||||
|
||||
Returns:
|
||||
np.ndarray: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
goal_pos = action
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.arm.read("Present_Position")
|
||||
goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target)
|
||||
|
||||
# Send goal position to the arm
|
||||
self.arm.write("Goal_Position", goal_pos.astype(np.int32))
|
||||
|
||||
return goal_pos
|
||||
|
||||
def print_logs(self):
|
||||
# TODO(aliberts): move robot-specific logs logic here
|
||||
pass
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
|
||||
self.arm.disconnect()
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
95
lerobot/common/robots/robot.py
Normal file
95
lerobot/common/robots/robot.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import abc
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.common.constants import HF_LEROBOT_CALIBRATION, ROBOTS
|
||||
from lerobot.common.motors import MotorCalibration
|
||||
|
||||
from .config import RobotConfig
|
||||
|
||||
|
||||
# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ?
|
||||
# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23
|
||||
class Robot(abc.ABC):
|
||||
"""The main LeRobot class for implementing robots."""
|
||||
|
||||
# Set these in ALL subclasses
|
||||
config_class: RobotConfig
|
||||
name: str
|
||||
|
||||
def __init__(self, config: RobotConfig):
|
||||
self.robot_type = self.name
|
||||
self.id = config.id
|
||||
self.calibration_dir = (
|
||||
config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name
|
||||
)
|
||||
self.calibration_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.calibration_fpath = self.calibration_dir / f"{self.id}.json"
|
||||
self.calibration: dict[str, MotorCalibration] = {}
|
||||
if self.calibration_fpath.is_file():
|
||||
self._load_calibration()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.id} {self.__class__.__name__}"
|
||||
|
||||
# TODO(aliberts): create a proper Feature class for this that links with datasets
|
||||
@abc.abstractproperty
|
||||
def state_feature(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def action_feature(self) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def is_connected(self) -> bool:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def connect(self) -> None:
|
||||
"""Connects to the robot."""
|
||||
pass
|
||||
|
||||
@abc.abstractproperty
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def calibrate(self) -> None:
|
||||
"""Calibrates the robot."""
|
||||
pass
|
||||
|
||||
def _load_calibration(self, fpath: Path | None = None) -> None:
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath) as f, draccus.config_type("json"):
|
||||
self.calibration = draccus.load(dict[str, MotorCalibration], f)
|
||||
|
||||
def _save_calibration(self, fpath: Path | None = None) -> None:
|
||||
fpath = self.calibration_fpath if fpath is None else fpath
|
||||
with open(fpath, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self.calibration, f, indent=4)
|
||||
|
||||
@abc.abstractmethod
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
"""Gets observation from the robot."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Sends actions to the robot."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnects from the robot."""
|
||||
pass
|
||||
@@ -191,7 +191,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
--id 1
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
@@ -204,7 +204,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
--id 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
2
lerobot/common/robots/so100/__init__.py
Normal file
2
lerobot/common/robots/so100/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
from .so100_follower import SO100Follower
|
||||
22
lerobot/common/robots/so100/config_so100_follower.py
Normal file
22
lerobot/common/robots/so100/config_so100_follower.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("so100_follower")
|
||||
@dataclass
|
||||
class SO100FollowerConfig(RobotConfig):
|
||||
# Port to connect to the arm
|
||||
port: str
|
||||
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
215
lerobot/common/robots/so100/so100_follower.py
Normal file
215
lerobot/common/robots/so100/so100_follower.py
Normal file
@@ -0,0 +1,215 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_so100_follower import SO100FollowerConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SO100Follower(Robot):
|
||||
"""
|
||||
[SO-100 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
|
||||
config_class = SO100FollowerConfig
|
||||
name = "so100_follower"
|
||||
|
||||
def __init__(self, config: SO100FollowerConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.arm = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
calibration=self.calibration,
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def state_feature(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(self.arm),),
|
||||
"names": {"motors": list(self.arm.motors)},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
return self.state_feature
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
cam_ft[cam_key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
# TODO(aliberts): add cam.is_connected for cam in self.cameras
|
||||
return self.arm.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.arm.connect()
|
||||
if not self.is_calibrated:
|
||||
self.calibrate()
|
||||
|
||||
# Connect the cameras
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.arm.disable_torque()
|
||||
for name in self.arm.names:
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
||||
|
||||
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.arm.set_half_turn_homings()
|
||||
|
||||
full_turn_motor = "wrist_roll"
|
||||
unknown_range_motors = [name for name in self.arm.names if name != full_turn_motor]
|
||||
logger.info(
|
||||
f"Move all joints except '{full_turn_motor}' sequentially through their "
|
||||
"entire ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||
)
|
||||
range_mins, range_maxes = self.arm.record_ranges_of_motion(unknown_range_motors)
|
||||
range_mins[full_turn_motor] = 0
|
||||
range_maxes[full_turn_motor] = 4095
|
||||
|
||||
self.calibration = {}
|
||||
for name, motor in self.arm.motors.items():
|
||||
self.calibration[name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=homing_offsets[name],
|
||||
range_min=range_mins[name],
|
||||
range_max=range_maxes[name],
|
||||
)
|
||||
|
||||
self.arm.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print("Calibration saved to", self.calibration_fpath)
|
||||
|
||||
def configure(self) -> None:
|
||||
with self.arm.torque_disabled():
|
||||
self.arm.configure_motors()
|
||||
for name in self.arm.names:
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.POSITION.value)
|
||||
# Set P_Coefficient to lower value to avoid shakiness (Default is 32)
|
||||
self.arm.write("P_Coefficient", name, 16)
|
||||
# Set I_Coefficient and D_Coefficient to default value 0 and 32
|
||||
self.arm.write("I_Coefficient", name, 0)
|
||||
self.arm.write("D_Coefficient", name, 32)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict[OBS_STATE] = self.arm.sync_read("Present_Position")
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
`max_relative_target`. In this case, the action sent differs from original action.
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Raises:
|
||||
RobotDeviceNotConnectedError: if robot is not connected.
|
||||
|
||||
Returns:
|
||||
the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = action
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.arm.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# Send goal position to the arm
|
||||
self.arm.sync_write("Goal_Position", goal_pos)
|
||||
return goal_pos
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.arm.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
44
lerobot/common/robots/stretch3/configuration_stretch3.py
Normal file
44
lerobot/common/robots/stretch3/configuration_stretch3.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
from lerobot.common.cameras.intel import RealSenseCameraConfig
|
||||
from lerobot.common.cameras.opencv import OpenCVCameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("stretch3")
|
||||
@dataclass
|
||||
class Stretch3RobotConfig(RobotConfig):
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
max_relative_target: int | None = None
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(
|
||||
default_factory=lambda: {
|
||||
"navigation": OpenCVCameraConfig(
|
||||
camera_index="/dev/hello-nav-head-camera",
|
||||
fps=10,
|
||||
width=1280,
|
||||
height=720,
|
||||
rotation=-90,
|
||||
),
|
||||
"head": RealSenseCameraConfig(
|
||||
name="Intel RealSense D435I",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
rotation=90,
|
||||
),
|
||||
"wrist": RealSenseCameraConfig(
|
||||
name="Intel RealSense D405",
|
||||
fps=30,
|
||||
width=640,
|
||||
height=480,
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
mock: bool = False
|
||||
183
lerobot/common/robots/stretch3/robot_stretch3.py
Normal file
183
lerobot/common/robots/stretch3/robot_stretch3.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from stretch_body.gamepad_teleop import GamePadTeleop
|
||||
from stretch_body.robot import Robot as StretchAPI
|
||||
from stretch_body.robot_params import RobotParams
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.datasets.utils import get_nested_item
|
||||
|
||||
from ..robot import Robot
|
||||
from .configuration_stretch3 import Stretch3RobotConfig
|
||||
|
||||
# {lerobot_keys: stretch.api.keys}
|
||||
STRETCH_MOTORS = {
|
||||
"head_pan.pos": "head.head_pan.pos",
|
||||
"head_tilt.pos": "head.head_tilt.pos",
|
||||
"lift.pos": "lift.pos",
|
||||
"arm.pos": "arm.pos",
|
||||
"wrist_pitch.pos": "end_of_arm.wrist_pitch.pos",
|
||||
"wrist_roll.pos": "end_of_arm.wrist_roll.pos",
|
||||
"wrist_yaw.pos": "end_of_arm.wrist_yaw.pos",
|
||||
"gripper.pos": "end_of_arm.stretch_gripper.pos",
|
||||
"base_x.vel": "base.x_vel",
|
||||
"base_y.vel": "base.y_vel",
|
||||
"base_theta.vel": "base.theta_vel",
|
||||
}
|
||||
|
||||
|
||||
class Stretch3Robot(Robot):
|
||||
"""[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot."""
|
||||
|
||||
config_class = Stretch3RobotConfig
|
||||
name = "stretch3"
|
||||
|
||||
def __init__(self, config: Stretch3RobotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
self.config = config
|
||||
self.robot_type = self.config.type
|
||||
|
||||
self.api = StretchAPI()
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
self.teleop = None # TODO remove
|
||||
|
||||
# TODO(aliberts): test this
|
||||
RobotParams.set_logging_level("WARNING")
|
||||
RobotParams.set_logging_formatter("brief_console_formatter")
|
||||
|
||||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
|
||||
@property
|
||||
def state_feature(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(STRETCH_MOTORS),),
|
||||
"names": {"motors": list(STRETCH_MOTORS)},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
return self.state_feature
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
cam_ft[cam_key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
def connect(self) -> None:
|
||||
self.is_connected = self.api.startup()
|
||||
if not self.is_connected:
|
||||
print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'")
|
||||
raise ConnectionError()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
self.is_connected = self.is_connected and cam.is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
if not self.api.is_homed():
|
||||
self.api.home()
|
||||
|
||||
def _get_state(self) -> dict:
|
||||
status = self.api.get_status()
|
||||
return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()}
|
||||
|
||||
def get_observation(self) -> dict[str, np.ndarray]:
|
||||
obs_dict = {}
|
||||
|
||||
# Read Stretch state
|
||||
before_read_t = time.perf_counter()
|
||||
state = self._get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
state = np.asarray(list(state.values()))
|
||||
obs_dict[OBS_STATE] = state
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
before_camread_t = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"]
|
||||
self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: np.ndarray) -> np.ndarray:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
if self.teleop is None:
|
||||
self.teleop = GamePadTeleop(robot_instance=False)
|
||||
self.teleop.startup(robot=self)
|
||||
|
||||
if self.action_keys is None:
|
||||
dummy_action = self.teleop.gamepad_controller.get_state()
|
||||
self.action_keys = list(dummy_action.keys())
|
||||
|
||||
action_dict = dict(zip(self.action_keys, action.tolist(), strict=True))
|
||||
|
||||
before_write_t = time.perf_counter()
|
||||
self.teleop.do_motion(state=action_dict, robot=self)
|
||||
self.push_command()
|
||||
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||
|
||||
# TODO(aliberts): return action_sent when motion is limited
|
||||
return action
|
||||
|
||||
def print_logs(self) -> None:
|
||||
pass
|
||||
# TODO(aliberts): move robot-specific logs logic here
|
||||
|
||||
def teleop_safety_stop(self) -> None:
|
||||
if self.teleop is not None:
|
||||
self.teleop._safety_stop(robot=self)
|
||||
|
||||
def disconnect(self) -> None:
|
||||
self.api.stop()
|
||||
if self.teleop is not None:
|
||||
self.teleop.gamepad_controller.stop()
|
||||
self.teleop.stop()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
self.is_connected = False
|
||||
117
lerobot/common/robots/utils.py
Normal file
117
lerobot/common/robots/utils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import logging
|
||||
from pprint import pformat
|
||||
from typing import Protocol
|
||||
|
||||
from lerobot.common.robots import RobotConfig
|
||||
|
||||
|
||||
def get_arm_id(name, arm_type):
|
||||
"""Returns the string identifier of a robot arm. For instance, for a bimanual manipulator
|
||||
like Aloha, it could be left_follower, right_follower, left_leader, or right_leader.
|
||||
"""
|
||||
return f"{name}_{arm_type}"
|
||||
|
||||
|
||||
# TODO(aliberts): Remove and point to lerobot.common.robots.Robot
|
||||
class Robot(Protocol):
|
||||
robot_type: str
|
||||
features: dict
|
||||
|
||||
def connect(self): ...
|
||||
def run_calibration(self): ...
|
||||
def teleop_step(self, record_data=False): ...
|
||||
def capture_observation(self): ...
|
||||
def send_action(self, action): ...
|
||||
def disconnect(self): ...
|
||||
|
||||
|
||||
def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
||||
if robot_type == "aloha":
|
||||
from .aloha.configuration_aloha import AlohaRobotConfig
|
||||
|
||||
return AlohaRobotConfig(**kwargs)
|
||||
elif robot_type == "koch_follower":
|
||||
from .koch.config_koch_follower import KochFollowerConfig
|
||||
|
||||
return KochFollowerConfig(**kwargs)
|
||||
# elif robot_type == "koch_bimanual":
|
||||
# return KochBimanualRobotConfig(**kwargs)
|
||||
elif robot_type == "moss":
|
||||
from .moss.configuration_moss import MossRobotConfig
|
||||
|
||||
return MossRobotConfig(**kwargs)
|
||||
elif robot_type == "so100_leader":
|
||||
from .so100.config_so100_follower import SO100FollowerConfig
|
||||
|
||||
return SO100FollowerConfig(**kwargs)
|
||||
elif robot_type == "stretch":
|
||||
from .stretch3.configuration_stretch3 import Stretch3RobotConfig
|
||||
|
||||
return Stretch3RobotConfig(**kwargs)
|
||||
elif robot_type == "lekiwi":
|
||||
from .lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
||||
|
||||
return LeKiwiRobotConfig(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||
|
||||
|
||||
def make_robot_from_config(config: RobotConfig):
|
||||
from .lekiwi.configuration_lekiwi import LeKiwiRobotConfig
|
||||
from .manipulator import ManipulatorRobotConfig
|
||||
|
||||
if isinstance(config, ManipulatorRobotConfig):
|
||||
from lerobot.common.robots.manipulator import ManipulatorRobot
|
||||
|
||||
return ManipulatorRobot(config)
|
||||
elif isinstance(config, LeKiwiRobotConfig):
|
||||
from lerobot.common.robots.mobile_manipulator import MobileManipulator
|
||||
|
||||
return MobileManipulator(config)
|
||||
else:
|
||||
from lerobot.common.robots.stretch3.robot_stretch3 import Stretch3Robot
|
||||
|
||||
return Stretch3Robot(config)
|
||||
|
||||
|
||||
def make_robot(robot_type: str, **kwargs) -> Robot:
|
||||
config = make_robot_config(robot_type, **kwargs)
|
||||
return make_robot_from_config(config)
|
||||
|
||||
|
||||
def ensure_safe_goal_position(
|
||||
goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float]
|
||||
) -> dict[str, float]:
|
||||
"""Caps relative action target magnitude for safety."""
|
||||
|
||||
if isinstance(max_relative_target, float):
|
||||
diff_cap = dict.fromkeys(goal_present_pos, max_relative_target)
|
||||
elif isinstance(max_relative_target, dict):
|
||||
if not set(goal_present_pos) == set(max_relative_target):
|
||||
raise ValueError("max_relative_target keys must match those of goal_present_pos.")
|
||||
diff_cap = max_relative_target
|
||||
else:
|
||||
raise TypeError(max_relative_target)
|
||||
|
||||
warnings_dict = {}
|
||||
safe_goal_positions = {}
|
||||
for key, (goal_pos, present_pos) in goal_present_pos.items():
|
||||
diff = goal_pos - present_pos
|
||||
max_diff = diff_cap[key]
|
||||
safe_diff = min(diff, max_diff)
|
||||
safe_diff = max(safe_diff, -max_diff)
|
||||
safe_goal_pos = present_pos + safe_diff
|
||||
safe_goal_positions[key] = safe_goal_pos
|
||||
if abs(safe_goal_pos - goal_pos) > 1e-4:
|
||||
warnings_dict[key] = {
|
||||
"original goal_pos": goal_pos,
|
||||
"safe goal_pos": safe_goal_pos,
|
||||
}
|
||||
|
||||
if warnings_dict:
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f"{pformat(warnings_dict, indent=4)}"
|
||||
)
|
||||
|
||||
return safe_goal_positions
|
||||
2
lerobot/common/robots/viperx/__init__.py
Normal file
2
lerobot/common/robots/viperx/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config_viperx import ViperXConfig
|
||||
from .viperx import ViperX
|
||||
31
lerobot/common/robots/viperx/config_viperx.py
Normal file
31
lerobot/common/robots/viperx/config_viperx.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.common.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("viperx")
|
||||
@dataclass
|
||||
class ViperXConfig(RobotConfig):
|
||||
port: str # Port to connect to the arm
|
||||
|
||||
disable_torque_on_disconnect: bool = True
|
||||
|
||||
# /!\ FOR SAFETY, READ THIS /!\
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
|
||||
# the number of motors in your follower arms.
|
||||
# For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default.
|
||||
# When you feel more confident with teleoperation or running the policy, you can extend
|
||||
# this safety limit and even removing it by setting it to `null`.
|
||||
# Also, everything is expected to work safely out-of-the-box, but we highly advise to
|
||||
# first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml),
|
||||
# then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully
|
||||
max_relative_target: int | None = 5
|
||||
|
||||
# cameras
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
# Troubleshooting: If one of your IntelRealSense cameras freeze during
|
||||
# data recording due to bandwidth limit, you might need to plug the camera
|
||||
# on another USB hub or PCIe card.
|
||||
229
lerobot/common/robots/viperx/viperx.py
Normal file
229
lerobot/common/robots/viperx/viperx.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Contains logic to instantiate a robot, read information from its motors and cameras,
|
||||
and send orders to its motors.
|
||||
"""
|
||||
# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated
|
||||
# calibration procedure, to make it easy for people to add their own robot.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.common.cameras.utils import make_cameras_from_configs
|
||||
from lerobot.common.constants import OBS_IMAGES, OBS_STATE
|
||||
from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.common.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
from .config_viperx import ViperXConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ViperX(Robot):
|
||||
"""
|
||||
[ViperX](https://www.trossenrobotics.com/viperx-300) developed by Trossen Robotics
|
||||
"""
|
||||
|
||||
config_class = ViperXConfig
|
||||
name = "viperx"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ViperXConfig,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.arm = DynamixelMotorsBus(
|
||||
port=self.config.port,
|
||||
motors={
|
||||
"waist": Motor(1, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"shoulder_shadow": Motor(3, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"elbow": Motor(4, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"elbow_shadow": Motor(5, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"forearm_roll": Motor(6, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_angle": Motor(7, "xm540-w270", MotorNormMode.RANGE_M100_100),
|
||||
"wrist_rotate": Motor(8, "xm430-w350", MotorNormMode.RANGE_M100_100),
|
||||
"gripper": Motor(9, "xm430-w350", MotorNormMode.RANGE_0_100),
|
||||
},
|
||||
)
|
||||
self.cameras = make_cameras_from_configs(config.cameras)
|
||||
|
||||
@property
|
||||
def state_feature(self) -> dict:
|
||||
return {
|
||||
"dtype": "float32",
|
||||
"shape": (len(self.arm),),
|
||||
"names": {"motors": list(self.arm.motors)},
|
||||
}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> dict:
|
||||
return self.state_feature
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict[str, dict]:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
# TODO(aliberts): add cam.is_connected for cam in self.cameras
|
||||
return self.arm.is_connected
|
||||
|
||||
def connect(self) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.arm.connect()
|
||||
if not self.is_calibrated:
|
||||
self.calibrate()
|
||||
|
||||
for cam in self.cameras.values():
|
||||
cam.connect()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch
|
||||
logger.info(f"\nRunning calibration of {self}")
|
||||
self.arm.disable_torque()
|
||||
for name in self.arm.names:
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
input("Move robot to the middle of its range of motion and press ENTER....")
|
||||
homing_offsets = self.arm.set_half_turn_homings()
|
||||
|
||||
full_turn_motors = ["shoulder_pan", "wrist_roll"]
|
||||
unknown_range_motors = [name for name in self.arm.names if name not in full_turn_motors]
|
||||
logger.info(
|
||||
f"Move all joints except {full_turn_motors} sequentially through their entire "
|
||||
"ranges of motion.\nRecording positions. Press ENTER to stop..."
|
||||
)
|
||||
range_mins, range_maxes = self.arm.record_ranges_of_motion(unknown_range_motors)
|
||||
for name in full_turn_motors:
|
||||
range_mins[name] = 0
|
||||
range_maxes[name] = 4095
|
||||
|
||||
self.calibration = {}
|
||||
for name, motor in self.arm.motors.items():
|
||||
self.calibration[name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=0,
|
||||
homing_offset=homing_offsets[name],
|
||||
range_min=range_mins[name],
|
||||
range_max=range_maxes[name],
|
||||
)
|
||||
|
||||
self.arm.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
logger.info(f"Calibration saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
with self.arm.torque_disabled():
|
||||
self.arm.configure_motors()
|
||||
|
||||
# Set secondary/shadow ID for shoulder and elbow. These joints have two motors.
|
||||
# As a result, if only one of them is required to move to a certain position,
|
||||
# the other will follow. This is to avoid breaking the motors.
|
||||
self.arm.write("Secondary_ID", "shoulder_shadow", 2)
|
||||
self.arm.write("Secondary_ID", "elbow_shadow", 4)
|
||||
|
||||
# Set a velocity limit of 131 as advised by Trossen Robotics
|
||||
# TODO(aliberts): remove as it's actually useless in position control
|
||||
self.arm.write("Velocity_Limit", 131)
|
||||
|
||||
# Use 'extended position mode' for all motors except gripper, because in joint mode the servos
|
||||
# can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling
|
||||
# the arm, you could end up with a servo with a position 0 or 4095 at a crucial point.
|
||||
# See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11
|
||||
for name in self.arm.names:
|
||||
if name != "gripper":
|
||||
self.arm.write("Operating_Mode", name, OperatingMode.EXTENDED_POSITION.value)
|
||||
|
||||
# Use 'position control current based' for follower gripper to be limited by the limit of the
|
||||
# current. It can grasp an object without forcing too much even tho, it's goal position is a
|
||||
# complete grasp (both gripper fingers are ordered to join and reach a touch).
|
||||
self.arm.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
"""The returned observations do not have a batch dimension."""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict[OBS_STATE] = self.arm.sync_read("Present_Position")
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read state: {dt_ms:.1f}ms")
|
||||
|
||||
# Capture images from cameras
|
||||
for cam_key, cam in self.cameras.items():
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
def send_action(self, action: dict[str, float]) -> dict[str, float]:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
The relative action magnitude may be clipped depending on the configuration parameter
|
||||
`max_relative_target`. In this case, the action sent differs from original action.
|
||||
Thus, this function always returns the action actually sent.
|
||||
|
||||
Args:
|
||||
action (dict[str, float]): The goal positions for the motors.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = action
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
# /!\ Slower fps expected due to reading from the follower.
|
||||
if self.config.max_relative_target is not None:
|
||||
present_pos = self.arm.sync_read("Present_Position")
|
||||
goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
|
||||
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
|
||||
|
||||
# Send goal position to the arm
|
||||
self.arm.sync_write("Goal_Position", goal_pos)
|
||||
return goal_pos
|
||||
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.arm.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
logger.info(f"{self} disconnected.")
|
||||
4
lerobot/common/teleoperators/__init__.py
Normal file
4
lerobot/common/teleoperators/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .config import TeleoperatorConfig
|
||||
from .teleoperator import Teleoperator
|
||||
|
||||
__all__ = ["TeleoperatorConfig", "Teleoperator"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user