forked from tangger/lerobot
Compare commits
15 Commits
user/pepij
...
fix/lint_w
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e511e7eda5 | ||
|
|
f5ed3723f0 | ||
|
|
b104be0d04 | ||
|
|
f9e4a1f5c4 | ||
|
|
0eb56cec14 | ||
|
|
e59ef036e1 | ||
|
|
9b380eaf67 | ||
|
|
1187604ba0 | ||
|
|
5c6f2d2cd0 | ||
|
|
652fedf69c | ||
|
|
85214ec303 | ||
|
|
dffa5a18db | ||
|
|
301f152a34 | ||
|
|
0ed08c0b1f | ||
|
|
254bc707e7 |
@@ -73,7 +73,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/artifacts
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
2
.github/workflows/test-docker-build.yml
vendored
2
.github/workflows/test-docker-build.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
|
||||
- name: Get changed files
|
||||
id: changed-files
|
||||
uses: tj-actions/changed-files@3f54ebb830831fc121d3263c1857cfbdc310cdb9 #v42
|
||||
uses: tj-actions/changed-files@v44
|
||||
with:
|
||||
files: docker/**
|
||||
json: "true"
|
||||
|
||||
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -126,7 +126,7 @@ jobs:
|
||||
# portaudio19-dev is needed to install pyaudio
|
||||
run: |
|
||||
sudo apt-get update && \
|
||||
sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev
|
||||
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
|
||||
|
||||
- name: Install uv and python
|
||||
uses: astral-sh/setup-uv@v5
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -78,7 +78,7 @@ pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
!tests/artifacts
|
||||
!tests/data
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
|
||||
@@ -12,17 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
exclude: "tests/artifacts/.*\\.safetensors$"
|
||||
exclude: ^(tests/data)
|
||||
default_language_version:
|
||||
python: python3.10
|
||||
repos:
|
||||
##### Meta #####
|
||||
- repo: meta
|
||||
hooks:
|
||||
- id: check-useless-excludes
|
||||
- id: check-hooks-apply
|
||||
|
||||
|
||||
##### Style / Misc. #####
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
@@ -35,37 +28,31 @@ repos:
|
||||
- id: check-toml
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.30.2
|
||||
rev: v1.30.0
|
||||
hooks:
|
||||
- id: typos
|
||||
args: [--force-exclude]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.19.1
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.9.10
|
||||
rev: v0.9.9
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
|
||||
##### Security #####
|
||||
- repo: https://github.com/gitleaks/gitleaks
|
||||
rev: v8.24.0
|
||||
hooks:
|
||||
- id: gitleaks
|
||||
|
||||
- repo: https://github.com/woodruffw/zizmor-pre-commit
|
||||
rev: v1.4.1
|
||||
hooks:
|
||||
- id: zizmor
|
||||
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
|
||||
@@ -291,7 +291,7 @@ sudo apt-get install git-lfs
|
||||
git lfs install
|
||||
```
|
||||
|
||||
Pull artifacts if they're not in [tests/artifacts](tests/artifacts)
|
||||
Pull artifacts if they're not in [tests/data](tests/data)
|
||||
```bash
|
||||
git lfs pull
|
||||
```
|
||||
|
||||
@@ -232,8 +232,8 @@ python lerobot/scripts/eval.py \
|
||||
--env.type=pusht \
|
||||
--eval.batch_size=10 \
|
||||
--eval.n_episodes=10 \
|
||||
--policy.use_amp=false \
|
||||
--policy.device=cuda
|
||||
--use_amp=false \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
Note: After training your own policy, you can re-evaluate the checkpoints with:
|
||||
|
||||
@@ -51,7 +51,7 @@ For a comprehensive list and documentation of these parameters, see the ffmpeg d
|
||||
### Decoding parameters
|
||||
**Decoder**
|
||||
We tested two video decoding backends from torchvision:
|
||||
- `pyav`
|
||||
- `pyav` (default)
|
||||
- `video_reader` (requires to build torchvision from source)
|
||||
|
||||
**Requested timestamps**
|
||||
|
||||
@@ -67,7 +67,8 @@ def parse_int_or_none(value) -> int | None:
|
||||
def check_datasets_formats(repo_ids: list) -> None:
|
||||
for repo_id in repo_ids:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
# TODO(Steven): Seems this API has changed
|
||||
if dataset.video:
|
||||
raise ValueError(
|
||||
f"Use only image dataset for running this benchmark. Video dataset provided: {repo_id}"
|
||||
)
|
||||
|
||||
@@ -454,8 +454,8 @@ Next, you'll need to calibrate your SO-100 robot to ensure that the leader and f
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
@@ -470,8 +470,8 @@ python lerobot/scripts/control_robot.py \
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
@@ -571,25 +571,18 @@ python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_so100_test \
|
||||
--job_name=act_so100_test \
|
||||
--policy.device=cuda \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
--config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
## K. Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
|
||||
@@ -366,8 +366,8 @@ Now we have to calibrate the leader arm and the follower arm. The wheel motors d
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/lekiwi/mobile_calib_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/lekiwi/mobile_calib_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration:
|
||||
@@ -385,8 +385,8 @@ If you have the **wired** LeKiwi version please run all commands including this
|
||||
### Calibrate leader arm
|
||||
Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script (on your laptop/pc) to launch manual calibration:
|
||||
@@ -416,22 +416,22 @@ python lerobot/scripts/control_robot.py \
|
||||
|
||||
You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below:
|
||||
| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) |
|
||||
| ---------- | ------------------ | ---------------------- |
|
||||
| Fast | 0.4 | 90 |
|
||||
| Medium | 0.25 | 60 |
|
||||
| Slow | 0.1 | 30 |
|
||||
|------------|-------------------|-----------------------|
|
||||
| Fast | 0.4 | 90 |
|
||||
| Medium | 0.25 | 60 |
|
||||
| Slow | 0.1 | 30 |
|
||||
|
||||
|
||||
| Key | Action |
|
||||
| --- | -------------- |
|
||||
| W | Move forward |
|
||||
| A | Move left |
|
||||
| S | Move backward |
|
||||
| D | Move right |
|
||||
| Z | Turn left |
|
||||
| X | Turn right |
|
||||
| R | Increase speed |
|
||||
| F | Decrease speed |
|
||||
| Key | Action |
|
||||
|------|--------------------------------|
|
||||
| W | Move forward |
|
||||
| A | Move left |
|
||||
| S | Move backward |
|
||||
| D | Move right |
|
||||
| Z | Turn left |
|
||||
| X | Turn right |
|
||||
| R | Increase speed |
|
||||
| F | Decrease speed |
|
||||
|
||||
> [!TIP]
|
||||
> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py).
|
||||
@@ -549,14 +549,14 @@ python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_lekiwi_test \
|
||||
--job_name=act_lekiwi_test \
|
||||
--policy.device=cuda \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`.
|
||||
|
||||
@@ -176,8 +176,8 @@ Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
@@ -192,8 +192,8 @@ python lerobot/scripts/control_robot.py \
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
@@ -293,14 +293,14 @@ python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_moss_test \
|
||||
--job_name=act_moss_test \
|
||||
--policy.device=cuda \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--device=cpu` (`--device=mps` respectively). However, be advised that the code executes much slower on cpu.
|
||||
|
||||
|
||||
## The training script
|
||||
|
||||
@@ -386,14 +386,14 @@ When you connect your robot for the first time, the [`ManipulatorRobot`](../lero
|
||||
|
||||
Here are the positions you'll move the follower arm to:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/koch/follower_zero.webp?raw=true" alt="Koch v1.1 follower arm zero position" title="Koch v1.1 follower arm zero position" style="width:100%;"> | <img src="../media/koch/follower_rotated.webp?raw=true" alt="Koch v1.1 follower arm rotated position" title="Koch v1.1 follower arm rotated position" style="width:100%;"> | <img src="../media/koch/follower_rest.webp?raw=true" alt="Koch v1.1 follower arm rest position" title="Koch v1.1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
And here are the corresponding positions for the leader arm:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/koch/leader_zero.webp?raw=true" alt="Koch v1.1 leader arm zero position" title="Koch v1.1 leader arm zero position" style="width:100%;"> | <img src="../media/koch/leader_rotated.webp?raw=true" alt="Koch v1.1 leader arm rotated position" title="Koch v1.1 leader arm rotated position" style="width:100%;"> | <img src="../media/koch/leader_rest.webp?raw=true" alt="Koch v1.1 leader arm rest position" title="Koch v1.1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details.
|
||||
@@ -898,14 +898,14 @@ python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_koch_test \
|
||||
--job_name=act_koch_test \
|
||||
--policy.device=cuda \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
@@ -135,14 +135,14 @@ python lerobot/scripts/train.py \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/act_aloha_test \
|
||||
--job_name=act_aloha_test \
|
||||
--policy.device=cuda \
|
||||
--device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)
|
||||
|
||||
@@ -222,7 +222,7 @@ def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = T
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||
repo_id = "lerobot/pusht"
|
||||
repository_id = "lerobot/pusht"
|
||||
|
||||
modes = ["video", "image", "keypoints"]
|
||||
# Uncomment if you want to try with a specific mode
|
||||
@@ -230,13 +230,13 @@ if __name__ == "__main__":
|
||||
# modes = ["image"]
|
||||
# modes = ["keypoints"]
|
||||
|
||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for mode in modes:
|
||||
if mode in ["image", "keypoints"]:
|
||||
repo_id += f"_{mode}"
|
||||
data_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for available_mode in modes:
|
||||
if available_mode in ["image", "keypoints"]:
|
||||
repository_id += f"_{available_mode}"
|
||||
|
||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
main(data_dir, repo_id=repository_id, mode=available_mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id)
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from pprint import pformat
|
||||
|
||||
import torch
|
||||
|
||||
@@ -98,17 +96,17 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
# TODO(aliberts): add proper support for multi dataset
|
||||
# delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.dataset.video_backend,
|
||||
)
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
)
|
||||
# dataset = MultiLeRobotDataset(
|
||||
# cfg.dataset.repo_id,
|
||||
# # TODO(aliberts): add proper support for multi dataset
|
||||
# # delta_timestamps=delta_timestamps,
|
||||
# image_transforms=image_transforms,
|
||||
# video_backend=cfg.dataset.video_backend,
|
||||
# )
|
||||
# logging.info(
|
||||
# "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
# f"{pformat(dataset.repo_id_to_index, indent=2)}"
|
||||
# )
|
||||
|
||||
if cfg.dataset.use_imagenet_stats:
|
||||
for key in dataset.meta.camera_keys:
|
||||
|
||||
@@ -81,21 +81,21 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
def worker_thread_loop(task_queue: queue.Queue):
|
||||
while True:
|
||||
item = queue.get()
|
||||
item = task_queue.get()
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
task_queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
queue.task_done()
|
||||
task_queue.task_done()
|
||||
|
||||
|
||||
def worker_process(queue: queue.Queue, num_threads: int):
|
||||
def worker_process(task_queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
||||
t = threading.Thread(target=worker_thread_loop, args=(task_queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
|
||||
@@ -67,9 +67,8 @@ from lerobot.common.datasets.utils import (
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames,
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
get_safe_default_codec,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
@@ -88,6 +87,7 @@ class LeRobotDatasetMetadata:
|
||||
self.repo_id = repo_id
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.root = Path(root) if root is not None else HF_LEROBOT_HOME / repo_id
|
||||
self.stats = None
|
||||
|
||||
try:
|
||||
if force_cache_sync:
|
||||
@@ -103,10 +103,10 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
def load_metadata(self):
|
||||
self.info = load_info(self.root)
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
check_version_compatibility(self.repo_id, self.version, CODEBASE_VERSION)
|
||||
self.tasks, self.task_to_task_index = load_tasks(self.root)
|
||||
self.episodes = load_episodes(self.root)
|
||||
if self._version < packaging.version.parse("v2.1"):
|
||||
if self.version < packaging.version.parse("v2.1"):
|
||||
self.stats = load_stats(self.root)
|
||||
self.episodes_stats = backward_compatible_episodes_stats(self.stats, self.episodes)
|
||||
else:
|
||||
@@ -128,7 +128,7 @@ class LeRobotDatasetMetadata:
|
||||
)
|
||||
|
||||
@property
|
||||
def _version(self) -> packaging.version.Version:
|
||||
def version(self) -> packaging.version.Version:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return packaging.version.parse(self.info["codebase_version"])
|
||||
|
||||
@@ -322,8 +322,9 @@ class LeRobotDatasetMetadata:
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks."
|
||||
"Some cameras in your %s robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks.",
|
||||
robot.robot_type,
|
||||
)
|
||||
elif features is None:
|
||||
raise ValueError(
|
||||
@@ -463,8 +464,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. Defaults to torchcodec when available int the platform; otherwise, defaults to 'pyav'.
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
@@ -474,7 +475,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self.video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self.video_backend = video_backend if video_backend else "pyav"
|
||||
self.delta_indices = None
|
||||
|
||||
# Unused attributes
|
||||
@@ -487,7 +488,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.meta = LeRobotDatasetMetadata(
|
||||
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
if self.episodes is not None and self.meta._version >= packaging.version.parse("v2.1"):
|
||||
if self.episodes is not None and self.meta.version >= packaging.version.parse("v2.1"):
|
||||
episodes_stats = [self.meta.episodes_stats[ep_idx] for ep_idx in self.episodes]
|
||||
self.stats = aggregate_stats(episodes_stats)
|
||||
|
||||
@@ -519,7 +520,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self,
|
||||
branch: str | None = None,
|
||||
tags: list | None = None,
|
||||
license: str | None = "apache-2.0",
|
||||
dataset_license: str | None = "apache-2.0",
|
||||
tag_version: bool = True,
|
||||
push_videos: bool = True,
|
||||
private: bool = False,
|
||||
@@ -562,7 +563,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
if not hub_api.file_exists(self.repo_id, REPOCARD_NAME, repo_type="dataset", revision=branch):
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=dataset_license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -708,7 +709,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(video_path, query_ts, self.tolerance_s, self.video_backend)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
@@ -841,6 +844,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
"""
|
||||
episode_buffer = None
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
@@ -1028,7 +1032,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.episode_data_index = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
|
||||
@@ -1085,8 +1089,9 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
"keys %s of %s were disabled as they are not contained in all the other datasets.",
|
||||
extra_keys,
|
||||
repo_id,
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compre
|
||||
# rechunk recompress
|
||||
group.move(name, tmp_key)
|
||||
old_arr = group[tmp_key]
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||
source=old_arr,
|
||||
dest=group,
|
||||
name=name,
|
||||
@@ -192,7 +192,7 @@ class ReplayBuffer:
|
||||
else:
|
||||
root = zarr.group(store=store)
|
||||
# copy without recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists
|
||||
)
|
||||
data_group = root.create_group("data", overwrite=True)
|
||||
@@ -205,7 +205,7 @@ class ReplayBuffer:
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=src_store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
@@ -214,7 +214,7 @@ class ReplayBuffer:
|
||||
)
|
||||
else:
|
||||
# copy with recompression
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy(
|
||||
source=value,
|
||||
dest=data_group,
|
||||
name=key,
|
||||
@@ -275,7 +275,7 @@ class ReplayBuffer:
|
||||
compressors = {}
|
||||
if self.backend == "zarr":
|
||||
# recompression free copy
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path="/meta",
|
||||
@@ -297,7 +297,7 @@ class ReplayBuffer:
|
||||
if cks == value.chunks and cpr == value.compressor:
|
||||
# copy without recompression
|
||||
this_path = "/data/" + key
|
||||
n_copied, n_skipped, n_bytes_copied = zarr.copy_store(
|
||||
_n_copied, _n_skipped, _n_bytes_copied = zarr.copy_store(
|
||||
source=self.root.store,
|
||||
dest=store,
|
||||
source_path=this_path,
|
||||
|
||||
@@ -162,9 +162,9 @@ def download_raw(raw_dir: Path, repo_id: str):
|
||||
)
|
||||
raw_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging.info(f"Start downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
logging.info("Start downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||
snapshot_download(repo_id, repo_type="dataset", local_dir=raw_dir)
|
||||
logging.info(f"Finish downloading from huggingface.co/{user_id} for {dataset_id}")
|
||||
logging.info("Finish downloading from huggingface.co/%s for %s", user_id, dataset_id)
|
||||
|
||||
|
||||
def download_all_raw_datasets(data_dir: Path | None = None):
|
||||
|
||||
@@ -72,7 +72,7 @@ def check_format(raw_dir) -> bool:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 2
|
||||
else:
|
||||
assert data[f"/observations/images/{camera}"].ndim == 4
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
_, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided."
|
||||
|
||||
|
||||
@@ -103,6 +103,7 @@ def load_from_raw(
|
||||
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
if "/observations/effort" in ep:
|
||||
|
||||
@@ -96,6 +96,7 @@ def from_raw_to_lerobot_format(
|
||||
if fps is None:
|
||||
fps = 30
|
||||
|
||||
# TODO(Steven): Is this meant to call cam_png_format.load_from_raw?
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
@@ -42,7 +42,9 @@ def check_format(raw_dir) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
|
||||
def load_from_raw(
|
||||
raw_dir: Path, videos_dir: Path, fps: int, _video: bool, _episodes: list[int] | None = None
|
||||
):
|
||||
# Load data stream that will be used as reference for the timestamps synchronization
|
||||
reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
|
||||
if len(reference_files) == 0:
|
||||
|
||||
@@ -55,7 +55,7 @@ def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers
|
||||
|
||||
num_images = len(imgs_array)
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
_ = [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
|
||||
|
||||
|
||||
def get_default_encoding() -> dict:
|
||||
@@ -92,24 +92,23 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
# The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
# For instance, the following is a valid episode_index:
|
||||
# [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
#
|
||||
# Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
# ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
# {
|
||||
# "from": [0, 3, 7],
|
||||
# "to": [3, 7, 12]
|
||||
# }
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
idx = None
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
|
||||
@@ -23,6 +23,7 @@ from torchvision.transforms.v2 import Transform
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
|
||||
# TODO(Steven): Missing transform() implementation
|
||||
class RandomSubsetApply(Transform):
|
||||
"""Apply a random subset of N transformations from a list of transformations.
|
||||
|
||||
@@ -218,6 +219,7 @@ def make_transform_from_config(cfg: ImageTransformConfig):
|
||||
raise ValueError(f"Transform '{cfg.type}' is not valid.")
|
||||
|
||||
|
||||
# TODO(Steven): Missing transform() implementation
|
||||
class ImageTransforms(Transform):
|
||||
"""A class to compose image transforms based on configuration."""
|
||||
|
||||
|
||||
@@ -135,21 +135,21 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
|
||||
def embed_images(dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
ds_format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset = dataset.with_format(**ds_format)
|
||||
return dataset
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
with open(fpath) as f:
|
||||
with open(fpath, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
with open(fpath, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -300,7 +300,7 @@ def check_version_compatibility(
|
||||
if v_check.major < v_current.major and enforce_breaking_major:
|
||||
raise BackwardCompatibilityError(repo_id, v_check)
|
||||
elif v_check.minor < v_current.minor:
|
||||
logging.warning(V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
logging.warning("%s", V21_MESSAGE.format(repo_id=repo_id, version=v_check))
|
||||
|
||||
|
||||
def get_repo_versions(repo_id: str) -> list[packaging.version.Version]:
|
||||
@@ -348,7 +348,9 @@ def get_safe_version(repo_id: str, version: str | packaging.version.Version) ->
|
||||
if compatibles:
|
||||
return_version = max(compatibles)
|
||||
if return_version < target_version:
|
||||
logging.warning(f"Revision {version} for {repo_id} not found, using version v{return_version}")
|
||||
logging.warning(
|
||||
"Revision %s for %s not found, using version v%s", version, repo_id, return_version
|
||||
)
|
||||
return f"v{return_version}"
|
||||
|
||||
lower_major = [v for v in hub_versions if v.major < target_version.major]
|
||||
@@ -403,7 +405,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
if ft["dtype"] in ["image", "video"]:
|
||||
type = FeatureType.VISUAL
|
||||
feature_type = FeatureType.VISUAL
|
||||
if len(shape) != 3:
|
||||
raise ValueError(f"Number of dimensions of {key} != 3 (shape={shape})")
|
||||
|
||||
@@ -412,16 +414,16 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea
|
||||
if names[2] in ["channel", "channels"]: # (h, w, c) -> (c, h, w)
|
||||
shape = (shape[2], shape[0], shape[1])
|
||||
elif key == "observation.environment_state":
|
||||
type = FeatureType.ENV
|
||||
feature_type = FeatureType.ENV
|
||||
elif key.startswith("observation"):
|
||||
type = FeatureType.STATE
|
||||
feature_type = FeatureType.STATE
|
||||
elif key == "action":
|
||||
type = FeatureType.ACTION
|
||||
feature_type = FeatureType.ACTION
|
||||
else:
|
||||
continue
|
||||
|
||||
policy_features[key] = PolicyFeature(
|
||||
type=type,
|
||||
type=feature_type,
|
||||
shape=shape,
|
||||
)
|
||||
|
||||
|
||||
@@ -871,11 +871,11 @@ def batch_convert():
|
||||
try:
|
||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a") as file:
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
@@ -190,11 +190,11 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
|
||||
json_path = v2_dir / STATS_PATH
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w") as f:
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
||||
# Sanity check
|
||||
with open(json_path) as f:
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
stats_json = json.load(f)
|
||||
|
||||
stats_json = flatten_dict(stats_json)
|
||||
@@ -213,7 +213,7 @@ def get_features_from_hf_dataset(
|
||||
dtype = ft.dtype
|
||||
shape = (1,)
|
||||
names = None
|
||||
if isinstance(ft, datasets.Sequence):
|
||||
elif isinstance(ft, datasets.Sequence):
|
||||
assert isinstance(ft.feature, datasets.Value)
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
@@ -232,6 +232,8 @@ def get_features_from_hf_dataset(
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channels"]
|
||||
else:
|
||||
raise NotImplementedError(f"Feature type {ft._type} not supported.")
|
||||
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
@@ -358,9 +360,9 @@ def move_videos(
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
for dir in video_dirs:
|
||||
if (dir / video_file).is_file():
|
||||
video_path = dir / video_file
|
||||
for v_dir in video_dirs:
|
||||
if (v_dir / video_file).is_file():
|
||||
video_path = v_dir / video_file
|
||||
break
|
||||
|
||||
video_path.rename(work_dir / target_path)
|
||||
@@ -652,6 +654,7 @@ def main():
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = None
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def fix_dataset(repo_id: str) -> str:
|
||||
return f"{repo_id}: skipped (no diff)"
|
||||
|
||||
if diff_meta_parquet:
|
||||
logging.warning(f"In info.json not in parquet: {meta_features - parquet_features}")
|
||||
logging.warning("In info.json not in parquet: %s", meta_features - parquet_features)
|
||||
assert diff_meta_parquet == {"language_instruction"}
|
||||
lerobot_metadata.features.pop("language_instruction")
|
||||
write_info(lerobot_metadata.info, lerobot_metadata.root)
|
||||
@@ -79,7 +79,7 @@ def batch_fix():
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
logging.info(status)
|
||||
with open(logfile, "a") as file:
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ def batch_convert():
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
|
||||
with open(logfile, "a") as file:
|
||||
with open(logfile, "a", encoding="utf-8") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,9 @@ V21 = "v2.1"
|
||||
|
||||
|
||||
class SuppressWarnings:
|
||||
def __init__(self):
|
||||
self.previous_level = None
|
||||
|
||||
def __enter__(self):
|
||||
self.previous_level = logging.getLogger().getEffectiveLevel()
|
||||
logging.getLogger().setLevel(logging.ERROR)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
@@ -30,46 +29,6 @@ from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decodes video frames using the specified backend.
|
||||
|
||||
Args:
|
||||
video_path (Path): Path to the video file.
|
||||
timestamps (list[float]): List of timestamps to extract frames.
|
||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Decoded frames.
|
||||
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise ValueError(f"Unsupported video backend: {backend}")
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
@@ -124,7 +83,7 @@ def decode_video_frames_torchvision(
|
||||
for frame in reader:
|
||||
current_ts = frame["pts"]
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"frame loaded at timestamp={current_ts:.4f}")
|
||||
logging.info("frame loaded at timestamp=%.4f", current_ts)
|
||||
loaded_frames.append(frame["data"])
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
@@ -159,7 +118,7 @@ def decode_video_frames_torchvision(
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
logging.info("closest_ts=%s", closest_ts)
|
||||
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
@@ -168,81 +127,6 @@ def decode_video_frames_torchvision(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_video_frames_torchcodec(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
device: str = "cpu",
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||
|
||||
Note: Setting device="cuda" outside the main process, e.g. in data loader workers, will lead to CUDA initialization errors.
|
||||
|
||||
Note: Video benefits from inter-frame compression. Instead of storing every frame individually,
|
||||
the encoder stores a reference frame (or a key frame) and subsequent frames as differences relative to
|
||||
that key frame. As a consequence, to access a requested frame, we need to load the preceding key frame,
|
||||
and all subsequent frames until reaching the requested frame. The number of key frames in a video
|
||||
can be adjusted during encoding to take into account decoding time and video size in bytes.
|
||||
"""
|
||||
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
from torchcodec.decoders import VideoDecoder
|
||||
else:
|
||||
raise ImportError("torchcodec is required but not available.")
|
||||
|
||||
# initialize video decoder
|
||||
decoder = VideoDecoder(video_path, device=device, seek_mode="approximate")
|
||||
loaded_frames = []
|
||||
loaded_ts = []
|
||||
# get metadata for frame information
|
||||
metadata = decoder.metadata
|
||||
average_fps = metadata.average_fps
|
||||
|
||||
# convert timestamps to frame indices
|
||||
frame_indices = [round(ts * average_fps) for ts in timestamps]
|
||||
|
||||
# retrieve frames based on indices
|
||||
frames_batch = decoder.get_frames_at(indices=frame_indices)
|
||||
|
||||
for frame, pts in zip(frames_batch.data, frames_batch.pts_seconds, strict=False):
|
||||
loaded_frames.append(frame)
|
||||
loaded_ts.append(pts.item())
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"Frame loaded at timestamp={pts:.4f}")
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts = torch.tensor(loaded_ts)
|
||||
|
||||
# compute distances between each query timestamp and loaded timestamps
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
closest_ts = loaded_ts[argmin_]
|
||||
|
||||
if log_loaded_timestamps:
|
||||
logging.info(f"{closest_ts=}")
|
||||
|
||||
# convert to float32 in [0,1] range (channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
@@ -343,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
@@ -379,7 +265,9 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
|
||||
@@ -32,7 +32,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def gym_kwargs(self) -> dict:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
return "adam"
|
||||
|
||||
@abc.abstractmethod
|
||||
def build(self) -> torch.optim.Optimizer:
|
||||
def build(self, params: dict) -> torch.optim.Optimizer:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -140,7 +140,7 @@ class ACTConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
|
||||
@@ -119,7 +119,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
|
||||
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
|
||||
# we are ensembling over.
|
||||
@@ -147,8 +149,9 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = [batch[key] for key in self.config.image_features]
|
||||
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
@@ -219,6 +222,8 @@ class ACTTemporalEnsembler:
|
||||
self.chunk_size = chunk_size
|
||||
self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size))
|
||||
self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0)
|
||||
self.ensembled_actions = None
|
||||
self.ensembled_actions_count = None
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
@@ -410,10 +415,11 @@ class ACT(nn.Module):
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
if "observation.images" in batch:
|
||||
batch_size = batch["observation.images"][0].shape[0]
|
||||
else:
|
||||
batch_size = batch["observation.environment_state"].shape[0]
|
||||
batch_size = (
|
||||
batch["observation.images"]
|
||||
if "observation.images" in batch
|
||||
else batch["observation.environment_state"]
|
||||
).shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and "action" in batch:
|
||||
@@ -486,21 +492,20 @@ class ACT(nn.Module):
|
||||
all_cam_features = []
|
||||
all_cam_pos_embeds = []
|
||||
|
||||
# For a list of images, the H and W may vary but H*W is constant.
|
||||
for img in batch["observation.images"]:
|
||||
cam_features = self.backbone(img)["feature_map"]
|
||||
for cam_index in range(batch["observation.images"].shape[-4]):
|
||||
cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
|
||||
# TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
|
||||
# buffer
|
||||
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features)
|
||||
|
||||
# Rearrange features to (sequence, batch, dim).
|
||||
cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
|
||||
cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
|
||||
|
||||
cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w)
|
||||
all_cam_features.append(cam_features)
|
||||
all_cam_pos_embeds.append(cam_pos_embed)
|
||||
|
||||
encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
|
||||
encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))
|
||||
# Concatenate camera observation feature maps and positional embeddings along the width dimension,
|
||||
# and move to (sequence, batch, dim).
|
||||
all_cam_features = torch.cat(all_cam_features, axis=-1)
|
||||
encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
|
||||
all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
|
||||
encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
|
||||
|
||||
# Stack all tokens along the sequence dimension.
|
||||
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
|
||||
|
||||
@@ -162,7 +162,7 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
|
||||
@@ -170,6 +170,7 @@ def _make_noise_scheduler(name: str, **kwargs: dict) -> DDPMScheduler | DDIMSche
|
||||
raise ValueError(f"Unsupported noise scheduler type {name}")
|
||||
|
||||
|
||||
# TODO(Steven): Missing forward() implementation
|
||||
class DiffusionModel(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
@@ -203,6 +204,7 @@ class DiffusionModel(nn.Module):
|
||||
)
|
||||
|
||||
if config.num_inference_steps is None:
|
||||
# TODO(Steven): Consider type check?
|
||||
self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps
|
||||
else:
|
||||
self.num_inference_steps = config.num_inference_steps
|
||||
@@ -333,7 +335,7 @@ class DiffusionModel(nn.Module):
|
||||
# Sample a random noising timestep for each item in the batch.
|
||||
timesteps = torch.randint(
|
||||
low=0,
|
||||
high=self.noise_scheduler.config.num_train_timesteps,
|
||||
high=self.noise_scheduler.config.num_train_timesteps, # TODO(Steven): Consider type check?
|
||||
size=(trajectory.shape[0],),
|
||||
device=trajectory.device,
|
||||
).long()
|
||||
|
||||
@@ -69,12 +69,12 @@ def create_stats_buffers(
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
min_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max_norm = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
"max": nn.Parameter(max, requires_grad=False),
|
||||
"min": nn.Parameter(min_norm, requires_grad=False),
|
||||
"max": nn.Parameter(max_norm, requires_grad=False),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -170,12 +170,12 @@ class Normalize(nn.Module):
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = (batch[key] - mean) / (std + 1e-8)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
min_norm = buffer["min"]
|
||||
max_norm = buffer["max"]
|
||||
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||
# normalize to [0,1]
|
||||
batch[key] = (batch[key] - min) / (max - min + 1e-8)
|
||||
batch[key] = (batch[key] - min_norm) / (max_norm - min_norm + 1e-8)
|
||||
# normalize to [-1, 1]
|
||||
batch[key] = batch[key] * 2 - 1
|
||||
else:
|
||||
@@ -243,12 +243,12 @@ class Unnormalize(nn.Module):
|
||||
assert not torch.isinf(std).any(), _no_stats_error_str("std")
|
||||
batch[key] = batch[key] * std + mean
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = buffer["min"]
|
||||
max = buffer["max"]
|
||||
assert not torch.isinf(min).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max).any(), _no_stats_error_str("max")
|
||||
min_norm = buffer["min"]
|
||||
max_norm = buffer["max"]
|
||||
assert not torch.isinf(min_norm).any(), _no_stats_error_str("min")
|
||||
assert not torch.isinf(max_norm).any(), _no_stats_error_str("max")
|
||||
batch[key] = (batch[key] + 1) / 2
|
||||
batch[key] = batch[key] * (max - min) + min
|
||||
batch[key] = batch[key] * (max_norm - min_norm) + min_norm
|
||||
else:
|
||||
raise ValueError(norm_mode)
|
||||
return batch
|
||||
|
||||
@@ -91,7 +91,7 @@ class PI0Config(PreTrainedConfig):
|
||||
super().__post_init__()
|
||||
|
||||
# TODO(Steven): Validate device and amp? in all policy configs?
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Input validation (not exhaustive).
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
|
||||
|
||||
@@ -55,7 +55,7 @@ def main():
|
||||
with open(save_dir / "noise.pkl", "rb") as f:
|
||||
noise = pickle.load(f)
|
||||
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
|
||||
with open(ckpt_jax_dir / "assets/norm_stats.json", encoding="utf-8") as f:
|
||||
norm_stats = json.load(f)
|
||||
|
||||
# Override stats
|
||||
|
||||
@@ -318,7 +318,7 @@ def update_keys_with_prefix(d: dict, prefix: str) -> dict:
|
||||
return {f"{prefix}{key}": value for key, value in d.items()}
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
|
||||
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, _tokenizer_id: str, output_path: str):
|
||||
# Break down orbax ckpts - they are in OCDBT
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
|
||||
# process projection params
|
||||
@@ -432,6 +432,6 @@ if __name__ == "__main__":
|
||||
convert_pi0_checkpoint(
|
||||
checkpoint_dir=args.checkpoint_dir,
|
||||
precision=args.precision,
|
||||
tokenizer_id=args.tokenizer_hub_id,
|
||||
_tokenizer_id=args.tokenizer_hub_id,
|
||||
output_path=args.output_path,
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from packaging.version import Version
|
||||
|
||||
# TODO(Steven): Consider settings this a dependency constraint
|
||||
if Version(torch.__version__) > Version("2.5.0"):
|
||||
# Ffex attention is only available from torch 2.5 onwards
|
||||
from torch.nn.attention.flex_attention import (
|
||||
@@ -121,7 +122,7 @@ def flex_attention_forward(
|
||||
)
|
||||
|
||||
# mask is applied inside the kernel, ideally more efficiently than score_mod.
|
||||
attn_output, attention_weights = flex_attention(
|
||||
attn_output, _attention_weights = flex_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
|
||||
@@ -162,7 +162,7 @@ class TDMPCConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Input validation (not exhaustive).
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of gaussian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
|
||||
@@ -88,6 +88,9 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
for param in self.model_target.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self._queues = None
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -108,7 +111,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
self._prev_mean = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -514,6 +517,7 @@ class TDMPCPolicy(PreTrainedPolicy):
|
||||
update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum)
|
||||
|
||||
|
||||
# TODO(Steven): forward implementation missing
|
||||
class TDMPCTOLD(nn.Module):
|
||||
"""Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC."""
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ class VQBeTConfig(PreTrainedConfig):
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
"""Input validation (not exhaustive)."""
|
||||
# Input validation (not exhaustive).
|
||||
if not self.vision_backbone.startswith("resnet"):
|
||||
raise ValueError(
|
||||
f"`vision_backbone` must be one of the ResNet variants. Got {self.vision_backbone}."
|
||||
|
||||
@@ -70,6 +70,8 @@ class VQBeTPolicy(PreTrainedPolicy):
|
||||
|
||||
self.vqbet = VQBeTModel(config)
|
||||
|
||||
self._queues = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
@@ -535,7 +537,7 @@ class VQBeTHead(nn.Module):
|
||||
cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers
|
||||
)
|
||||
cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1)
|
||||
NT, G, choices = cbet_probs.shape
|
||||
NT, _G, choices = cbet_probs.shape
|
||||
sampled_centers = einops.rearrange(
|
||||
torch.multinomial(cbet_probs.view(-1, choices), num_samples=1),
|
||||
"(NT G) 1 -> NT G",
|
||||
@@ -578,7 +580,7 @@ class VQBeTHead(nn.Module):
|
||||
"decoded_action": decoded_action,
|
||||
}
|
||||
|
||||
def loss_fn(self, pred, target, **kwargs):
|
||||
def loss_fn(self, pred, target, **_kwargs):
|
||||
"""
|
||||
for given ground truth action values (target), and prediction (pred) this function calculates the overall loss.
|
||||
|
||||
@@ -605,7 +607,7 @@ class VQBeTHead(nn.Module):
|
||||
# Figure out the loss for the actions.
|
||||
# First, we need to find the closest cluster center for each ground truth action.
|
||||
with torch.no_grad():
|
||||
state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
_state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G
|
||||
|
||||
# Now we can compute the loss.
|
||||
|
||||
@@ -762,6 +764,7 @@ def _replace_submodules(
|
||||
return root_module
|
||||
|
||||
|
||||
# TODO(Steven): Missing implementation of forward, is it maybe vqvae_forward?
|
||||
class VqVae(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -876,13 +879,13 @@ class FocalLoss(nn.Module):
|
||||
self.gamma = gamma
|
||||
self.size_average = size_average
|
||||
|
||||
def forward(self, input, target):
|
||||
if len(input.shape) == 3:
|
||||
N, T, _ = input.shape
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
def forward(self, forward_input, target):
|
||||
if len(forward_input.shape) == 3:
|
||||
N, T, _ = forward_input.shape
|
||||
logpt = F.log_softmax(forward_input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(N, T, 1)).view(N, T)
|
||||
elif len(input.shape) == 2:
|
||||
logpt = F.log_softmax(input, dim=-1)
|
||||
elif len(forward_input.shape) == 2:
|
||||
logpt = F.log_softmax(forward_input, dim=-1)
|
||||
logpt = logpt.gather(-1, target.view(-1, 1)).view(-1)
|
||||
pt = logpt.exp()
|
||||
|
||||
|
||||
@@ -34,63 +34,58 @@ from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
"""
|
||||
This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||
# This file is part of a VQ-BeT that utilizes code from the following repositories:
|
||||
#
|
||||
# - Vector Quantize PyTorch code is licensed under the MIT License:
|
||||
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
#
|
||||
# - nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
# Original source: https://github.com/karpathy/nanoGPT
|
||||
#
|
||||
# We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
||||
|
||||
- Vector Quantize PyTorch code is licensed under the MIT License:
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- nanoGPT part is an adaptation of Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
Original source: https://github.com/karpathy/nanoGPT
|
||||
|
||||
We also made some changes to the original code to adapt it to our needs. The changes are described in the code below.
|
||||
"""
|
||||
|
||||
"""
|
||||
This is a part for nanoGPT that utilizes code from the following repository:
|
||||
|
||||
- Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
Original source: https://github.com/karpathy/nanoGPT
|
||||
|
||||
- The nanoGPT code is licensed under the MIT License:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2022 Andrej Karpathy
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
- We've made some changes to the original code to adapt it to our needs.
|
||||
|
||||
Changed variable names:
|
||||
- n_head -> gpt_n_head
|
||||
- n_embd -> gpt_hidden_dim
|
||||
- block_size -> gpt_block_size
|
||||
- n_layer -> gpt_n_layer
|
||||
|
||||
|
||||
class GPT(nn.Module):
|
||||
- removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
||||
- changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
||||
- in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
||||
|
||||
"""
|
||||
# This is a part for nanoGPT that utilizes code from the following repository:
|
||||
#
|
||||
# - Andrej Karpathy's nanoGPT implementation in PyTorch.
|
||||
# Original source: https://github.com/karpathy/nanoGPT
|
||||
#
|
||||
# - The nanoGPT code is licensed under the MIT License:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2022 Andrej Karpathy
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
# - We've made some changes to the original code to adapt it to our needs.
|
||||
#
|
||||
# Changed variable names:
|
||||
# - n_head -> gpt_n_head
|
||||
# - n_embd -> gpt_hidden_dim
|
||||
# - block_size -> gpt_block_size
|
||||
# - n_layer -> gpt_n_layer
|
||||
#
|
||||
#
|
||||
# class GPT(nn.Module):
|
||||
# - removed unused functions `def generate`, `def estimate_mfu`, and `def from_pretrained`
|
||||
# - changed the `configure_optimizers` to `def configure_parameters` and made it to return only the parameters of the model: we use an external optimizer in our training loop.
|
||||
# - in the function `forward`, we removed target loss calculation parts, since it will be calculated in the training loop (after passing through bin prediction and offset prediction heads).
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
@@ -200,9 +195,9 @@ class GPT(nn.Module):
|
||||
n_params = sum(p.numel() for p in self.parameters())
|
||||
print("number of parameters: {:.2f}M".format(n_params / 1e6))
|
||||
|
||||
def forward(self, input, targets=None):
|
||||
device = input.device
|
||||
b, t, d = input.size()
|
||||
def forward(self, forward_input):
|
||||
device = forward_input.device
|
||||
_, t, _ = forward_input.size()
|
||||
assert t <= self.config.gpt_block_size, (
|
||||
f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}"
|
||||
)
|
||||
@@ -211,7 +206,7 @@ class GPT(nn.Module):
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
|
||||
# forward the GPT model itself
|
||||
tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
tok_emb = self.transformer.wte(forward_input) # token embeddings of shape (b, t, gpt_hidden_dim)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim)
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
@@ -285,51 +280,48 @@ class GPT(nn.Module):
|
||||
return decay, no_decay
|
||||
|
||||
|
||||
"""
|
||||
This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||
|
||||
- Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||
Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
|
||||
- The vector-quantize-pytorch code is licensed under the MIT License:
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2020 Phil Wang
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
- We've made some changes to the original code to adapt it to our needs.
|
||||
|
||||
class ResidualVQ(nn.Module):
|
||||
- added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
||||
This enables the user to save an indicator whether the codebook is frozen or not.
|
||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||
This is to make the function name more descriptive.
|
||||
|
||||
class VectorQuantize(nn.Module):
|
||||
- removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
||||
These parameters are not used in the code.
|
||||
- changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||
This is to make the function name more descriptive.
|
||||
|
||||
"""
|
||||
# This file is a part for Residual Vector Quantization that utilizes code from the following repository:
|
||||
#
|
||||
# - Phil Wang's vector-quantize-pytorch implementation in PyTorch.
|
||||
# Original source: https://github.com/lucidrains/vector-quantize-pytorch
|
||||
#
|
||||
# - The vector-quantize-pytorch code is licensed under the MIT License:
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
#
|
||||
# - We've made some changes to the original code to adapt it to our needs.
|
||||
#
|
||||
# class ResidualVQ(nn.Module):
|
||||
# - added `self.register_buffer('freeze_codebook', torch.tensor(False))` to the __init__ method:
|
||||
# This enables the user to save an indicator whether the codebook is frozen or not.
|
||||
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||
# This is to make the function name more descriptive.
|
||||
#
|
||||
# class VectorQuantize(nn.Module):
|
||||
# - removed the `use_cosine_sim` and `layernorm_after_project_in` parameters from the __init__ method:
|
||||
# These parameters are not used in the code.
|
||||
# - changed the name of function `get_codes_from_indices` → `get_codebook_vector_from_indices`:
|
||||
# This is to make the function name more descriptive.
|
||||
|
||||
|
||||
class ResidualVQ(nn.Module):
|
||||
@@ -479,6 +471,9 @@ class ResidualVQ(nn.Module):
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and not return_loss
|
||||
|
||||
null_indices = None
|
||||
null_loss = None
|
||||
|
||||
# sample a layer index at which to dropout further residual quantization
|
||||
# also prepare null indices and loss
|
||||
|
||||
@@ -933,7 +928,7 @@ class VectorQuantize(nn.Module):
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
def noop(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def find_cameras(raise_when_empty=True, mock=False) -> list[dict]:
|
||||
connected to the computer.
|
||||
"""
|
||||
if mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
@@ -77,9 +77,9 @@ def save_image(img_array, serial_number, frame_index, images_dir):
|
||||
path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
logging.info(f"Saved image: {path}")
|
||||
logging.info("Saved image: %s", path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}")
|
||||
logging.error("Failed to save image for camera %s frame %s: %s", serial_number, frame_index, e)
|
||||
|
||||
|
||||
def save_images_from_cameras(
|
||||
@@ -100,7 +100,7 @@ def save_images_from_cameras(
|
||||
serial_numbers = [cam["serial_number"] for cam in camera_infos]
|
||||
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -114,7 +114,7 @@ def save_images_from_cameras(
|
||||
camera = IntelRealSenseCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -224,20 +224,9 @@ class IntelRealSenseCamera:
|
||||
self.serial_number = self.find_serial_number_from_name(config.name)
|
||||
else:
|
||||
self.serial_number = config.serial_number
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.use_depth = config.use_depth
|
||||
@@ -253,10 +242,11 @@ class IntelRealSenseCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(alibets): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -287,26 +277,22 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_pyrealsense2 as rs
|
||||
import tests.mock_pyrealsense2 as rs
|
||||
else:
|
||||
import pyrealsense2 as rs
|
||||
|
||||
config = rs.config()
|
||||
config.enable_device(str(self.serial_number))
|
||||
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
if self.fps and self.width and self.height:
|
||||
# TODO(rcadene): can we set rgb8 directly?
|
||||
config.enable_stream(
|
||||
rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps
|
||||
)
|
||||
config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.color)
|
||||
|
||||
if self.use_depth:
|
||||
if self.fps and self.capture_width and self.capture_height:
|
||||
config.enable_stream(
|
||||
rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps
|
||||
)
|
||||
if self.fps and self.width and self.height:
|
||||
config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps)
|
||||
else:
|
||||
config.enable_stream(rs.stream.depth)
|
||||
|
||||
@@ -344,18 +330,18 @@ class IntelRealSenseCamera:
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.capture_width is not None and self.capture_width != actual_width:
|
||||
if self.width is not None and self.width != actual_width:
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
f"Can't set {self.width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.capture_height is not None and self.capture_height != actual_height:
|
||||
if self.height is not None and self.height != actual_height:
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
f"Can't set {self.height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
@@ -375,7 +361,7 @@ class IntelRealSenseCamera:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -401,7 +387,7 @@ class IntelRealSenseCamera:
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
if h != self.height or w != self.width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -423,7 +409,7 @@ class IntelRealSenseCamera:
|
||||
depth_map = np.asanyarray(depth_frame.get_data())
|
||||
|
||||
h, w = depth_map.shape
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
if h != self.height or w != self.width:
|
||||
raise OSError(
|
||||
f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
@@ -461,7 +447,7 @@ class IntelRealSenseCamera:
|
||||
num_tries += 1
|
||||
time.sleep(1 / self.fps)
|
||||
if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()):
|
||||
raise Exception(
|
||||
raise TimeoutError(
|
||||
"The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called."
|
||||
)
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ from lerobot.common.utils.utils import capture_timestamp_utc
|
||||
MAX_OPENCV_INDEX = 60
|
||||
|
||||
|
||||
def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
def find_cameras(max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]:
|
||||
cameras = []
|
||||
if platform.system() == "Linux":
|
||||
print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports")
|
||||
@@ -80,7 +80,7 @@ def _find_cameras(
|
||||
possible_camera_ids: list[int | str], raise_when_empty=False, mock=False
|
||||
) -> list[int | str]:
|
||||
if mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -144,8 +144,8 @@ def save_images_from_cameras(
|
||||
camera = OpenCVCamera(config)
|
||||
camera.connect()
|
||||
print(
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, "
|
||||
f"height={camera.capture_height}, color_mode={camera.color_mode})"
|
||||
f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.width}, "
|
||||
f"height={camera.height}, color_mode={camera.color_mode})"
|
||||
)
|
||||
cameras.append(camera)
|
||||
|
||||
@@ -244,19 +244,9 @@ class OpenCVCamera:
|
||||
else:
|
||||
raise ValueError(f"Please check the provided camera_index: {self.camera_index}")
|
||||
|
||||
# Store the raw (capture) resolution from the config.
|
||||
self.capture_width = config.width
|
||||
self.capture_height = config.height
|
||||
|
||||
# If rotated by ±90, swap width and height.
|
||||
if config.rotation in [-90, 90]:
|
||||
self.width = config.height
|
||||
self.height = config.width
|
||||
else:
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
@@ -269,10 +259,11 @@ class OpenCVCamera:
|
||||
self.logs = {}
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# TODO(aliberts): Do we keep original width/height or do we define them after rotation?
|
||||
self.rotation = None
|
||||
if config.rotation == -90:
|
||||
self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
|
||||
@@ -286,7 +277,7 @@ class OpenCVCamera:
|
||||
raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.")
|
||||
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
@@ -334,10 +325,10 @@ class OpenCVCamera:
|
||||
|
||||
if self.fps is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FPS, self.fps)
|
||||
if self.capture_width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width)
|
||||
if self.capture_height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height)
|
||||
if self.width is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.width)
|
||||
if self.height is not None:
|
||||
self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.height)
|
||||
|
||||
actual_fps = self.camera.get(cv2.CAP_PROP_FPS)
|
||||
actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH)
|
||||
@@ -349,22 +340,19 @@ class OpenCVCamera:
|
||||
raise OSError(
|
||||
f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}."
|
||||
)
|
||||
if self.capture_width is not None and not math.isclose(
|
||||
self.capture_width, actual_width, rel_tol=1e-3
|
||||
):
|
||||
if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3):
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}."
|
||||
)
|
||||
if self.capture_height is not None and not math.isclose(
|
||||
self.capture_height, actual_height, rel_tol=1e-3
|
||||
):
|
||||
if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3):
|
||||
raise OSError(
|
||||
f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}."
|
||||
)
|
||||
|
||||
self.fps = round(actual_fps)
|
||||
self.capture_width = round(actual_width)
|
||||
self.capture_height = round(actual_height)
|
||||
self.width = round(actual_width)
|
||||
self.height = round(actual_height)
|
||||
|
||||
self.is_connected = True
|
||||
|
||||
def read(self, temporary_color_mode: str | None = None) -> np.ndarray:
|
||||
@@ -398,14 +386,14 @@ class OpenCVCamera:
|
||||
# so we convert the image color from BGR to RGB.
|
||||
if requested_color_mode == "rgb":
|
||||
if self.mock:
|
||||
import tests.cameras.mock_cv2 as cv2
|
||||
import tests.mock_cv2 as cv2
|
||||
else:
|
||||
import cv2
|
||||
|
||||
color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = color_image.shape
|
||||
if h != self.capture_height or w != self.capture_width:
|
||||
if h != self.height or w != self.width:
|
||||
raise OSError(
|
||||
f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead."
|
||||
)
|
||||
|
||||
@@ -169,7 +169,8 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
||||
return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes, mock=False):
|
||||
# TODO(Steven): Similar function in feetch.py, should be moved to a common place.
|
||||
def convert_to_bytes(value, byte, mock=False):
|
||||
if mock:
|
||||
return value
|
||||
|
||||
@@ -177,16 +178,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
||||
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
if byte == 1:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
elif byte == 2:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
elif byte == 4:
|
||||
data = [
|
||||
dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)),
|
||||
dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)),
|
||||
@@ -196,7 +197,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
f"{byte} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -228,9 +229,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
addr, byte = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
all_bytes.append(byte)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
@@ -332,7 +333,7 @@ class DynamixelMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -356,7 +357,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -576,6 +577,8 @@ class DynamixelMotorsBus:
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
@@ -596,10 +599,15 @@ class DynamixelMotorsBus:
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode '{calib_mode}'.")
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
"Auto-correct calibration of motor '%s' by shifting value by {abs(factor)} full turns, "
|
||||
"from '%s' to '%s'.",
|
||||
name,
|
||||
out_of_range_str,
|
||||
in_range_str,
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
@@ -646,7 +654,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -656,8 +664,8 @@ class DynamixelMotorsBus:
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
@@ -674,7 +682,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
value = group.getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
@@ -691,7 +699,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -709,13 +717,13 @@ class DynamixelMotorsBus:
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = dxl.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
@@ -733,7 +741,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
values = np.array(values)
|
||||
@@ -757,7 +765,7 @@ class DynamixelMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -767,10 +775,10 @@ class DynamixelMotorsBus:
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
group.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
@@ -793,7 +801,7 @@ class DynamixelMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_dynamixel_sdk as dxl
|
||||
import tests.mock_dynamixel_sdk as dxl
|
||||
else:
|
||||
import dynamixel_sdk as dxl
|
||||
|
||||
@@ -821,17 +829,17 @@ class DynamixelMotorsBus:
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = dxl.GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
|
||||
@@ -148,7 +148,7 @@ def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str
|
||||
return steps
|
||||
|
||||
|
||||
def convert_to_bytes(value, bytes, mock=False):
|
||||
def convert_to_bytes(value, byte, mock=False):
|
||||
if mock:
|
||||
return value
|
||||
|
||||
@@ -156,16 +156,16 @@ def convert_to_bytes(value, bytes, mock=False):
|
||||
|
||||
# Note: No need to convert back into unsigned int, since this byte preprocessing
|
||||
# already handles it for us.
|
||||
if bytes == 1:
|
||||
if byte == 1:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 2:
|
||||
elif byte == 2:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||
]
|
||||
elif bytes == 4:
|
||||
elif byte == 4:
|
||||
data = [
|
||||
scs.SCS_LOBYTE(scs.SCS_LOWORD(value)),
|
||||
scs.SCS_HIBYTE(scs.SCS_LOWORD(value)),
|
||||
@@ -175,7 +175,7 @@ def convert_to_bytes(value, bytes, mock=False):
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but "
|
||||
f"{bytes} is provided instead."
|
||||
f"{byte} is provided instead."
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -207,9 +207,9 @@ def assert_same_address(model_ctrl_table, motor_models, data_name):
|
||||
all_addr = []
|
||||
all_bytes = []
|
||||
for model in motor_models:
|
||||
addr, bytes = model_ctrl_table[model][data_name]
|
||||
addr, byte = model_ctrl_table[model][data_name]
|
||||
all_addr.append(addr)
|
||||
all_bytes.append(bytes)
|
||||
all_bytes.append(byte)
|
||||
|
||||
if len(set(all_addr)) != 1:
|
||||
raise NotImplementedError(
|
||||
@@ -313,7 +313,7 @@ class FeetechMotorsBus:
|
||||
)
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -337,7 +337,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def reconnect(self):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -557,6 +557,8 @@ class FeetechMotorsBus:
|
||||
# (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution
|
||||
low_factor = (start_pos - values[i]) / resolution
|
||||
upp_factor = (end_pos - values[i]) / resolution
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||
|
||||
if not in_range:
|
||||
# Get first integer between the two bounds
|
||||
@@ -577,10 +579,16 @@ class FeetechMotorsBus:
|
||||
elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR:
|
||||
out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %"
|
||||
else:
|
||||
raise ValueError(f"Unknown calibration mode {calib_mode}")
|
||||
|
||||
logging.warning(
|
||||
f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, "
|
||||
f"from '{out_of_range_str}' to '{in_range_str}'."
|
||||
"Auto-correct calibration of motor '%s' by shifting value by %s full turns, "
|
||||
"from '%s' to '%s'.",
|
||||
name,
|
||||
abs(factor),
|
||||
out_of_range_str,
|
||||
in_range_str,
|
||||
)
|
||||
|
||||
# A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096.
|
||||
@@ -664,7 +672,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -674,8 +682,8 @@ class FeetechMotorsBus:
|
||||
motor_ids = [motor_ids]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, self.motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx in motor_ids:
|
||||
group.addParam(idx)
|
||||
|
||||
@@ -692,7 +700,7 @@ class FeetechMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = group.getData(idx, addr, bytes)
|
||||
value = group.getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
if return_list:
|
||||
@@ -702,7 +710,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def read(self, data_name, motor_names: str | list[str] | None = None):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -727,7 +735,7 @@ class FeetechMotorsBus:
|
||||
models.append(model)
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
if data_name not in self.group_readers:
|
||||
@@ -737,7 +745,7 @@ class FeetechMotorsBus:
|
||||
|
||||
# create new group reader
|
||||
self.group_readers[group_key] = scs.GroupSyncRead(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
for idx in motor_ids:
|
||||
self.group_readers[group_key].addParam(idx)
|
||||
@@ -755,7 +763,7 @@ class FeetechMotorsBus:
|
||||
|
||||
values = []
|
||||
for idx in motor_ids:
|
||||
value = self.group_readers[group_key].getData(idx, addr, bytes)
|
||||
value = self.group_readers[group_key].getData(idx, addr, byte)
|
||||
values.append(value)
|
||||
|
||||
values = np.array(values)
|
||||
@@ -782,7 +790,7 @@ class FeetechMotorsBus:
|
||||
|
||||
def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY):
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -792,10 +800,10 @@ class FeetechMotorsBus:
|
||||
values = [values]
|
||||
|
||||
assert_same_address(self.model_ctrl_table, motor_models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes)
|
||||
addr, byte = self.model_ctrl_table[motor_models[0]][data_name]
|
||||
group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, byte)
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
group.addParam(idx, data)
|
||||
|
||||
for _ in range(num_retry):
|
||||
@@ -818,7 +826,7 @@ class FeetechMotorsBus:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
if self.mock:
|
||||
import tests.motors.mock_scservo_sdk as scs
|
||||
import tests.mock_scservo_sdk as scs
|
||||
else:
|
||||
import scservo_sdk as scs
|
||||
|
||||
@@ -846,17 +854,17 @@ class FeetechMotorsBus:
|
||||
values = values.tolist()
|
||||
|
||||
assert_same_address(self.model_ctrl_table, models, data_name)
|
||||
addr, bytes = self.model_ctrl_table[model][data_name]
|
||||
addr, byte = self.model_ctrl_table[model][data_name]
|
||||
group_key = get_group_sync_key(data_name, motor_names)
|
||||
|
||||
init_group = data_name not in self.group_readers
|
||||
if init_group:
|
||||
self.group_writers[group_key] = scs.GroupSyncWrite(
|
||||
self.port_handler, self.packet_handler, addr, bytes
|
||||
self.port_handler, self.packet_handler, addr, byte
|
||||
)
|
||||
|
||||
for idx, value in zip(motor_ids, values, strict=True):
|
||||
data = convert_to_bytes(value, bytes, self.mock)
|
||||
data = convert_to_bytes(value, byte, self.mock)
|
||||
if init_group:
|
||||
self.group_writers[group_key].addParam(idx, data)
|
||||
else:
|
||||
|
||||
@@ -95,6 +95,8 @@ def move_to_calibrate(
|
||||
while_move_hook=None,
|
||||
):
|
||||
initial_pos = arm.read("Present_Position", motor_name)
|
||||
p_present_pos = None
|
||||
n_present_pos = None
|
||||
|
||||
if positive_first:
|
||||
p_present_pos = move_until_block(
|
||||
@@ -196,7 +198,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex")
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80)
|
||||
|
||||
def in_between_move_hook():
|
||||
def in_between_move_hook_elbow():
|
||||
nonlocal arm, calib
|
||||
time.sleep(2)
|
||||
ef_pos = arm.read("Present_Position", "elbow_flex")
|
||||
@@ -207,14 +209,14 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
|
||||
print("Calibrate elbow_flex")
|
||||
calib["elbow_flex"] = move_to_calibrate(
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook
|
||||
arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook_elbow
|
||||
)
|
||||
calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024)
|
||||
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex")
|
||||
time.sleep(1)
|
||||
|
||||
def in_between_move_hook():
|
||||
def in_between_move_hook_shoulder():
|
||||
nonlocal arm, calib
|
||||
arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex")
|
||||
|
||||
@@ -224,7 +226,7 @@ def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: st
|
||||
"shoulder_lift",
|
||||
invert_drive_mode=True,
|
||||
positive_first=False,
|
||||
in_between_move_hook=in_between_move_hook,
|
||||
in_between_move_hook=in_between_move_hook_shoulder,
|
||||
)
|
||||
# add an 30 steps as offset to align with body
|
||||
calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50)
|
||||
|
||||
@@ -67,14 +67,14 @@ def calibrate_follower_arm(motors_bus, calib_dir_str):
|
||||
return
|
||||
|
||||
if calib_file.exists():
|
||||
with open(calib_file) as f:
|
||||
with open(calib_file, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
print(f"[INFO] Loaded calibration from {calib_file}")
|
||||
else:
|
||||
print("[INFO] Calibration file not found. Running manual calibration...")
|
||||
calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower")
|
||||
print(f"[INFO] Calibration complete. Saving to {calib_file}")
|
||||
with open(calib_file, "w") as f:
|
||||
with open(calib_file, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
try:
|
||||
motors_bus.set_calibration(calibration)
|
||||
|
||||
@@ -47,8 +47,10 @@ def ensure_safe_goal_position(
|
||||
if not torch.allclose(goal_pos, safe_goal_pos):
|
||||
logging.warning(
|
||||
"Relative goal position magnitude had to be clamped to be safe.\n"
|
||||
f" requested relative goal position target: {diff}\n"
|
||||
f" clamped relative goal position target: {safe_diff}"
|
||||
" requested relative goal position target: %s\n"
|
||||
" clamped relative goal position target: %s",
|
||||
diff,
|
||||
safe_diff,
|
||||
)
|
||||
|
||||
return safe_goal_pos
|
||||
@@ -245,6 +247,8 @@ class ManipulatorRobot:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss", "lekiwi"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
else:
|
||||
raise NotImplementedError(f"Robot type {self.robot_type} is not supported")
|
||||
|
||||
# We assume that at connection time, arms are in a rest position, and torque can
|
||||
# be safely disabled to run calibration and/or set robot preset configurations.
|
||||
@@ -302,7 +306,7 @@ class ManipulatorRobot:
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
with open(arm_calib_path, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
@@ -322,7 +326,7 @@ class ManipulatorRobot:
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
|
||||
@@ -262,14 +262,14 @@ class MobileManipulator:
|
||||
arm_calib_path = self.calibration_dir / f"{arm_id}.json"
|
||||
|
||||
if arm_calib_path.exists():
|
||||
with open(arm_calib_path) as f:
|
||||
with open(arm_calib_path, encoding="utf-8") as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(arm_calib_path, "w") as f:
|
||||
with open(arm_calib_path, "w", encoding="utf-8") as f:
|
||||
json.dump(calibration, f)
|
||||
|
||||
return calibration
|
||||
@@ -372,6 +372,7 @@ class MobileManipulator:
|
||||
|
||||
present_speed = self.last_present_speed
|
||||
|
||||
# TODO(Steven): [WARN] Plenty of general exceptions
|
||||
except Exception as e:
|
||||
print(f"[DEBUG] Error decoding video message: {e}")
|
||||
# If decode fails, fall back to old data
|
||||
|
||||
@@ -68,9 +68,9 @@ class TimeBenchmark(ContextDecorator):
|
||||
Block took approximately 10.00 milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, print=False):
|
||||
def __init__(self, print_time=False):
|
||||
self.local = threading.local()
|
||||
self.print_time = print
|
||||
self.print_time = print_time
|
||||
|
||||
def __enter__(self):
|
||||
self.local.start_time = time.perf_counter()
|
||||
|
||||
@@ -46,7 +46,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b
|
||||
else:
|
||||
# For packages other than "torch", don't attempt the fallback and set as not available
|
||||
package_exists = False
|
||||
logging.debug(f"Detected {pkg_name} version: {package_version}")
|
||||
logging.debug("Detected %s version: %s", {pkg_name}, package_version)
|
||||
if return_version:
|
||||
return package_exists, package_version
|
||||
else:
|
||||
|
||||
@@ -27,6 +27,8 @@ class AverageMeter:
|
||||
def __init__(self, name: str, fmt: str = ":f"):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.val = 0.0
|
||||
self.avg = 0.0
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
|
||||
@@ -69,7 +69,7 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
|
||||
case _:
|
||||
device = torch.device(try_device)
|
||||
if log:
|
||||
logging.warning(f"Using custom {try_device} device.")
|
||||
logging.warning("Using custom %s device.", try_device)
|
||||
|
||||
return device
|
||||
|
||||
|
||||
@@ -69,13 +69,7 @@ class WandBLogger:
|
||||
os.environ["WANDB_SILENT"] = "True"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = (
|
||||
cfg.wandb.run_id
|
||||
if cfg.wandb.run_id
|
||||
else get_wandb_run_id_from_filesystem(self.log_dir)
|
||||
if cfg.resume
|
||||
else None
|
||||
)
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.log_dir) if cfg.resume else None
|
||||
wandb.init(
|
||||
id=wandb_run_id,
|
||||
project=self.cfg.project,
|
||||
@@ -92,7 +86,7 @@ class WandBLogger:
|
||||
resume="must" if cfg.resume else None,
|
||||
)
|
||||
print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
logging.info("Track this run --> %s", colored(wandb.run.get_url(), "yellow", attrs=["bold"]))
|
||||
self._wandb = wandb
|
||||
|
||||
def log_policy(self, checkpoint_dir: Path):
|
||||
@@ -114,7 +108,7 @@ class WandBLogger:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
'WandB logging of key "%s" was ignored as its type is not handled by this wrapper.', k
|
||||
)
|
||||
continue
|
||||
self._wandb.log({f"{mode}/{k}": v}, step=step)
|
||||
|
||||
@@ -20,7 +20,6 @@ from lerobot.common import (
|
||||
policies, # noqa: F401
|
||||
)
|
||||
from lerobot.common.datasets.transforms import ImageTransformsConfig
|
||||
from lerobot.common.datasets.video_utils import get_safe_default_codec
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -36,7 +35,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = "pyav"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -47,7 +46,6 @@ class WandBConfig:
|
||||
project: str = "lerobot"
|
||||
entity: str | None = None
|
||||
notes: str | None = None
|
||||
run_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -11,9 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
import sys
|
||||
from argparse import ArgumentError
|
||||
from functools import wraps
|
||||
@@ -25,7 +23,6 @@ import draccus
|
||||
from lerobot.common.utils.utils import has_method
|
||||
|
||||
PATH_KEY = "path"
|
||||
PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path"
|
||||
draccus.set_config_type("json")
|
||||
|
||||
|
||||
@@ -61,86 +58,6 @@ def parse_arg(arg_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def parse_plugin_args(plugin_arg_suffix: str, args: Sequence[str]) -> dict:
|
||||
"""Parse plugin-related arguments from command-line arguments.
|
||||
|
||||
This function extracts arguments from command-line arguments that match a specified suffix pattern.
|
||||
It processes arguments in the format '--key=value' and returns them as a dictionary.
|
||||
|
||||
Args:
|
||||
plugin_arg_suffix (str): The suffix to identify plugin-related arguments.
|
||||
cli_args (Sequence[str]): A sequence of command-line arguments to parse.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the parsed plugin arguments where:
|
||||
- Keys are the argument names (with '--' prefix removed if present)
|
||||
- Values are the corresponding argument values
|
||||
|
||||
Example:
|
||||
>>> args = ['--env.discover_packages_path=my_package',
|
||||
... '--other_arg=value']
|
||||
>>> parse_plugin_args('discover_packages_path', args)
|
||||
{'env.discover_packages_path': 'my_package'}
|
||||
"""
|
||||
plugin_args = {}
|
||||
for arg in args:
|
||||
if "=" in arg and plugin_arg_suffix in arg:
|
||||
key, value = arg.split("=", 1)
|
||||
# Remove leading '--' if present
|
||||
if key.startswith("--"):
|
||||
key = key[2:]
|
||||
plugin_args[key] = value
|
||||
return plugin_args
|
||||
|
||||
|
||||
class PluginLoadError(Exception):
|
||||
"""Raised when a plugin fails to load."""
|
||||
|
||||
|
||||
def load_plugin(plugin_path: str) -> None:
|
||||
"""Load and initialize a plugin from a given Python package path.
|
||||
|
||||
This function attempts to load a plugin by importing its package and any submodules.
|
||||
Plugin registration is expected to happen during package initialization, i.e. when
|
||||
the package is imported the gym environment should be registered and the config classes
|
||||
registered with their parents using the `register_subclass` decorator.
|
||||
|
||||
Args:
|
||||
plugin_path (str): The Python package path to the plugin (e.g. "mypackage.plugins.myplugin")
|
||||
|
||||
Raises:
|
||||
PluginLoadError: If the plugin cannot be loaded due to import errors or if the package path is invalid.
|
||||
|
||||
Examples:
|
||||
>>> load_plugin("external_plugin.core") # Loads plugin from external package
|
||||
|
||||
Notes:
|
||||
- The plugin package should handle its own registration during import
|
||||
- All submodules in the plugin package will be imported
|
||||
- Implementation follows the plugin discovery pattern from Python packaging guidelines
|
||||
|
||||
See Also:
|
||||
https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/
|
||||
"""
|
||||
try:
|
||||
package_module = importlib.import_module(plugin_path, __package__)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
raise PluginLoadError(
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
def iter_namespace(ns_pkg):
|
||||
return pkgutil.iter_modules(ns_pkg.__path__, ns_pkg.__name__ + ".")
|
||||
|
||||
try:
|
||||
for _finder, pkg_name, _ispkg in iter_namespace(package_module):
|
||||
importlib.import_module(pkg_name)
|
||||
except ImportError as e:
|
||||
raise PluginLoadError(
|
||||
f"Failed to load plugin '{plugin_path}'. Verify the path and installation: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
def get_path_arg(field_name: str, args: Sequence[str] | None = None) -> str | None:
|
||||
return parse_arg(f"{field_name}.{PATH_KEY}", args)
|
||||
|
||||
@@ -188,13 +105,10 @@ def filter_path_args(fields_to_filter: str | list[str], args: Sequence[str] | No
|
||||
|
||||
def wrap(config_path: Path | None = None):
|
||||
"""
|
||||
HACK: Similar to draccus.wrap but does three additional things:
|
||||
HACK: Similar to draccus.wrap but does two additional things:
|
||||
- Will remove '.path' arguments from CLI in order to process them later on.
|
||||
- If a 'config_path' is passed and the main config class has a 'from_pretrained' method, will
|
||||
initialize it from there to allow to fetch configs from the hub directly
|
||||
- Will load plugins specified in the CLI arguments. These plugins will typically register
|
||||
their own subclasses of config classes, so that draccus can find the right class to instantiate
|
||||
from the CLI '.type' arguments
|
||||
"""
|
||||
|
||||
def wrapper_outer(fn):
|
||||
@@ -207,14 +121,6 @@ def wrap(config_path: Path | None = None):
|
||||
args = args[1:]
|
||||
else:
|
||||
cli_args = sys.argv[1:]
|
||||
plugin_args = parse_plugin_args(PLUGIN_DISCOVERY_SUFFIX, cli_args)
|
||||
for plugin_cli_arg, plugin_path in plugin_args.items():
|
||||
try:
|
||||
load_plugin(plugin_path)
|
||||
except PluginLoadError as e:
|
||||
# add the relevant CLI arg to the error message
|
||||
raise PluginLoadError(f"{e}\nFailed plugin CLI Arg: {plugin_cli_arg}") from e
|
||||
cli_args = filter_arg(plugin_cli_arg, cli_args)
|
||||
config_path_cli = parse_arg("config_path", cli_args)
|
||||
if has_method(argtype, "__get_path_fields__"):
|
||||
path_fields = argtype.__get_path_fields__()
|
||||
|
||||
@@ -64,13 +64,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
logging.warning("Device '%s' is not available. Switching to '%s'.", self.device, auto_device)
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
"Automatic Mixed Precision (amp) is not available on device '%s'. Deactivating AMP.",
|
||||
self.device,
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@@ -78,15 +79,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractproperty
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -128,7 +132,7 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with open(save_directory / CONFIG_NAME, "w", encoding="utf-8") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -79,9 +79,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
# The entire train config is already loaded, we just need to get the checkpoint dir
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
raise ValueError("A config_path is expected when resuming a run.")
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
@@ -125,7 +123,10 @@ class TrainPipelineConfig(HubMixin):
|
||||
return draccus.encode(self)
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / TRAIN_CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
with (
|
||||
open(save_directory / TRAIN_CONFIG_NAME, "w", encoding="utf-8") as f,
|
||||
draccus.config_type("json"),
|
||||
):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -90,6 +90,7 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des):
|
||||
print("Scanning all baudrates and motor indices")
|
||||
all_baudrates = set(series_baudrate_table.values())
|
||||
motor_index = -1 # Set the motor index to an out-of-range value.
|
||||
baudrate = None
|
||||
|
||||
for baudrate in all_baudrates:
|
||||
motor_bus.set_bus_baudrate(baudrate)
|
||||
|
||||
@@ -81,6 +81,7 @@ This might require a sudo permission to allow your terminal to monitor keyboard
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
"""
|
||||
|
||||
# TODO(Steven): This script should be updated to use the new robot API and the new dataset API.
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
@@ -59,7 +59,7 @@ np_version = np.__version__ if HAS_NP else "N/A"
|
||||
|
||||
torch_version = torch.__version__ if HAS_TORCH else "N/A"
|
||||
torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A"
|
||||
cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
cuda_version = torch.version.cuda if HAS_TORCH and torch.version.cuda is not None else "N/A"
|
||||
|
||||
|
||||
# TODO(aliberts): refactor into an actual command `lerobot env`
|
||||
|
||||
@@ -259,6 +259,10 @@ def eval_policy(
|
||||
threads = [] # for video saving threads
|
||||
n_episodes_rendered = 0 # for saving the correct number of videos
|
||||
|
||||
video_paths: list[str] = [] # max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = [] # max_episodes_rendered > 0
|
||||
episode_data: dict | None = None # return_episode_data == True
|
||||
|
||||
# Callback for visualization.
|
||||
def render_frame(env: gym.vector.VectorEnv):
|
||||
# noqa: B023
|
||||
@@ -271,19 +275,11 @@ def eval_policy(
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
if return_episode_data:
|
||||
episode_data: dict | None = None
|
||||
|
||||
# we dont want progress bar when we use slurm, since it clutters the logs
|
||||
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
|
||||
for batch_ix in progbar:
|
||||
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
|
||||
# step.
|
||||
if max_episodes_rendered > 0:
|
||||
ep_frames: list[np.ndarray] = []
|
||||
|
||||
if start_seed is None:
|
||||
seeds = None
|
||||
@@ -320,13 +316,19 @@ def eval_policy(
|
||||
else:
|
||||
all_seeds.append(None)
|
||||
|
||||
# FIXME: episode_data is either None or it doesn't exist
|
||||
if return_episode_data:
|
||||
if episode_data is None:
|
||||
start_data_index = 0
|
||||
elif isinstance(episode_data, dict):
|
||||
start_data_index = episode_data["index"][-1].item() + 1
|
||||
else:
|
||||
start_data_index = 0
|
||||
|
||||
this_episode_data = _compile_episode_data(
|
||||
rollout_data,
|
||||
done_indices,
|
||||
start_episode_index=batch_ix * env.num_envs,
|
||||
start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)),
|
||||
start_data_index=start_data_index,
|
||||
fps=env.unwrapped.metadata["render_fps"],
|
||||
)
|
||||
if episode_data is None:
|
||||
@@ -453,6 +455,7 @@ def _compile_episode_data(
|
||||
return data_dict
|
||||
|
||||
|
||||
# TODO(Steven): [WARN] Redefining built-in 'eval'
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
@@ -489,7 +492,7 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
print(info["aggregated"])
|
||||
|
||||
# Save info
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
|
||||
with open(Path(cfg.output_dir) / "eval_info.json", "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=2)
|
||||
|
||||
env.close()
|
||||
|
||||
@@ -53,6 +53,7 @@ import torch
|
||||
from huggingface_hub import HfApi
|
||||
from safetensors.torch import save_file
|
||||
|
||||
# TODO(Steven): #711 Broke this
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id
|
||||
@@ -89,7 +90,7 @@ def save_meta_data(
|
||||
|
||||
# save info
|
||||
info_path = meta_data_dir / "info.json"
|
||||
with open(str(info_path), "w") as f:
|
||||
with open(str(info_path), "w", encoding="utf-8") as f:
|
||||
json.dump(info, f, indent=4)
|
||||
|
||||
# save stats
|
||||
@@ -120,11 +121,11 @@ def push_dataset_card_to_hub(
|
||||
repo_id: str,
|
||||
revision: str | None,
|
||||
tags: list | None = None,
|
||||
license: str = "apache-2.0",
|
||||
dataset_license: str = "apache-2.0",
|
||||
**card_kwargs,
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
||||
card = create_lerobot_dataset_card(tags=tags, license=dataset_license, **card_kwargs)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
@@ -213,6 +214,7 @@ def push_dataset_to_hub(
|
||||
encoding,
|
||||
)
|
||||
|
||||
# TODO(Steven): This doesn't seem to exist, maybe it was removed/changed recently?
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
|
||||
@@ -155,12 +155,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
logging.info("cfg.env.task=%s", cfg.env.task)
|
||||
logging.info("cfg.steps=%s (%s)", cfg.steps, format_big_number(cfg.steps))
|
||||
logging.info("dataset.num_frames=%s (%s)", dataset.num_frames, format_big_number(dataset.num_frames))
|
||||
logging.info("dataset.num_episodes=%s", dataset.num_episodes)
|
||||
logging.info(
|
||||
"num_learnable_params=%s (%s)", num_learnable_params, format_big_number(num_learnable_params)
|
||||
)
|
||||
logging.info("num_total_params=%s (%s)", num_total_params, format_big_number(num_total_params))
|
||||
|
||||
# create dataloader for offline training
|
||||
if hasattr(cfg.policy, "drop_n_last_frames"):
|
||||
@@ -238,7 +240,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
logging.info("Checkpoint policy after step %s", step)
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
@@ -247,7 +249,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.env and is_eval_step:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
logging.info("Eval policy at step %s", step)
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
|
||||
|
||||
@@ -265,25 +265,13 @@ def main():
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance-s",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help=(
|
||||
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||
"If not given, defaults to 1e-4."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ def run_server(
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
str(dataset.meta._version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
str(dataset.meta.version) if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
@@ -358,7 +358,7 @@ def visualize_dataset_html(
|
||||
if force_override:
|
||||
shutil.rmtree(output_dir)
|
||||
else:
|
||||
logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
|
||||
logging.info("Output directory already exists. Loading from it: '%s'", {output_dir})
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -446,31 +446,15 @@ def main():
|
||||
help="Delete the output directory if it exists already.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tolerance-s",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help=(
|
||||
"Tolerance in seconds used to ensure data timestamps respect the dataset fps value"
|
||||
"This is argument passed to the constructor of LeRobotDataset and maps to its tolerance_s constructor argument"
|
||||
"If not given, defaults to 1e-4."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||
root = kwargs.pop("root")
|
||||
tolerance_s = kwargs.pop("tolerance_s")
|
||||
|
||||
dataset = None
|
||||
if repo_id:
|
||||
dataset = (
|
||||
LeRobotDataset(repo_id, root=root, tolerance_s=tolerance_s)
|
||||
if not load_from_hf_hub
|
||||
else get_dataset_info(repo_id)
|
||||
)
|
||||
dataset = LeRobotDataset(repo_id, root=root) if not load_from_hf_hub else get_dataset_info(repo_id)
|
||||
|
||||
visualize_dataset_html(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ dependencies = [
|
||||
"gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
"h5py>=3.10.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'",
|
||||
"hydra-core>=1.3.2",
|
||||
"imageio[ffmpeg]>=2.34.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"numba>=0.59.0",
|
||||
@@ -69,7 +70,6 @@ dependencies = [
|
||||
"rerun-sdk>=0.21.0",
|
||||
"termcolor>=2.4.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')",
|
||||
"torchvision>=0.21.0",
|
||||
"wandb>=0.16.3",
|
||||
"zarr>=2.17.0",
|
||||
@@ -103,7 +103,30 @@ requires-poetry = ">=2.1"
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
target-version = "py310"
|
||||
exclude = ["tests/artifacts/**/*.safetensors"]
|
||||
exclude = [
|
||||
"tests/data",
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".git-rewrite",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".pytype",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"]
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.envs.configs import EnvConfig
|
||||
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
|
||||
|
||||
|
||||
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
|
||||
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
|
||||
return f"""
|
||||
from dataclasses import dataclass
|
||||
from lerobot.common.envs.configs import {base_class}
|
||||
|
||||
@{base_class}.register_subclass("{plugin_name}")
|
||||
@dataclass
|
||||
class TestPluginConfig:
|
||||
value: int = 42
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def plugin_dir(tmp_path: Path) -> Generator[Path, None, None]:
|
||||
"""Creates a temporary plugin package structure."""
|
||||
plugin_pkg = tmp_path / "test_plugin"
|
||||
plugin_pkg.mkdir()
|
||||
(plugin_pkg / "__init__.py").touch()
|
||||
|
||||
with open(plugin_pkg / "my_plugin.py", "w") as f:
|
||||
f.write(create_plugin_code())
|
||||
|
||||
# Add tmp_path to Python path so we can import from it
|
||||
sys.path.insert(0, str(tmp_path))
|
||||
yield plugin_pkg
|
||||
sys.path.pop(0)
|
||||
|
||||
|
||||
def test_parse_plugin_args():
|
||||
cli_args = [
|
||||
"--env.type=test",
|
||||
"--model.discover_packages_path=some.package",
|
||||
"--env.discover_packages_path=other.package",
|
||||
]
|
||||
plugin_args = parse_plugin_args("discover_packages_path", cli_args)
|
||||
assert plugin_args == {
|
||||
"model.discover_packages_path": "some.package",
|
||||
"env.discover_packages_path": "other.package",
|
||||
}
|
||||
|
||||
|
||||
def test_load_plugin_success(plugin_dir: Path):
|
||||
# Import should work and register the plugin with the real EnvConfig
|
||||
load_plugin("test_plugin")
|
||||
|
||||
assert "test_env" in EnvConfig.get_known_choices()
|
||||
plugin_cls = EnvConfig.get_choice_class("test_env")
|
||||
plugin_instance = plugin_cls()
|
||||
assert plugin_instance.value == 42
|
||||
|
||||
|
||||
def test_load_plugin_failure():
|
||||
with pytest.raises(PluginLoadError) as exc_info:
|
||||
load_plugin("nonexistent_plugin")
|
||||
assert "Failed to load plugin 'nonexistent_plugin'" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_wrap_with_plugin(plugin_dir: Path):
|
||||
@dataclass
|
||||
class Config:
|
||||
env: EnvConfig
|
||||
|
||||
@wrap()
|
||||
def dummy_func(cfg: Config):
|
||||
return cfg
|
||||
|
||||
# Test loading plugin via CLI args
|
||||
sys.argv = [
|
||||
"dummy_script.py",
|
||||
"--env.discover_packages_path=test_plugin",
|
||||
"--env.type=test_env",
|
||||
]
|
||||
|
||||
cfg = dummy_func()
|
||||
assert isinstance(cfg, Config)
|
||||
assert isinstance(cfg.env, EnvConfig.get_choice_class("test_env"))
|
||||
assert cfg.env.value == 42
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user