Compare commits
141 Commits
user/miche...user/rcade
Commits (SHA1):
aa03a27f0a, a97c1cb1af, 272a9d9427, fc4df91883, 6366c7f46e, 31429e82d0, c79d7ed146, b0d05080e0, 2945dca65b, 691d39aa32,
56c01a2944, 49bdcc094c, 81fadfb569, 6ad84a6561, 03a44a56f3, 3b5af7eb38, 13a2091db3, 52b7bde587, 23f6c875b5, f56d769dfb,
36b9b60a0e, 93d9bf83c2, 37da50b573, c6ad495176, f43e5d07f5, 9ee8711504, 6203641710, fa7eb860ea, 1f13bda25b, acae4b49d2,
c72ad49c9b, 039beb20de, eda02fade5, 8546358bc5, a91b7c6163, 58495d5c86, beed4763f9, 4d15861872, f3630ad910, aed9f4036a,
757ea175d3, b69a132737, 7b159a6b22, 74270c8c91, a6762ec316, fde29e0167, 56e4603d5b, f6c90ca35f, c2d6fb6119, 95a4b59b5b,
16103cb8b7, e4ba084e25, ac79e8cb36, df2cb51364, 7a342db9c4, 6b2ec1ed77, 375abd3020, 293bdc7f67, 79d114cc1f, 2650872b76,
cd1509d805, 5ea7c78237, 443a9eec88, ab23a4fd27, 1267c3e955, c70b8d0abc, e1845d4dcc, e69f0c5059, ff84024ee9, ee51f54cb5,
fee5fa5c2e, 8d57093ce7, 4c22de20a6, 51e87f6f97, df3d2ec5df, e210d795de, 18ffa4248b, 8bcf81fa24, 615894d3fb, 450eae310b,
60865e8980, 0d77be90ee, 0098bd264e, 1aba80d93f, 07570f867f, b8bdbc1c5b, 7ae8d05326, a2a8538ac9, fb73cdb9a4, 9dca233d7e,
c3c0141738, c72dc23c43, 237a484be0, 6c2cb6e107, b46db7ea73, ee52b8b782, a805458c7e, e991a31061, c4c0a43de7, 299451af81,
9ebf8b88ec, c1232a01e2, 3b925c3dce, e46bdb9d30, 9316cf46ef, ac3798bd62, bce3dc3bfa, 1a51505ec6, e7355ba595, 91e8ce772b,
3a9f964429, beacb7e957, be64d54bd9, 354f37aed3, d0d8193625, 7242c57400, 3ee3739edc, ad3f112d16, 50a75ad3fe, c146ba936f,
110264000f, 9433ac52ec, da78bbfd16, 835ab5a81b, f96773de10, cbc51e1341, cf633344be, 8bd406e607, 3ea53124e0, 7f680886b0,
6d2bc11365, b417cebc4e, 3113038beb, 096824b5ff, 2d75b93ba0, 21ba4b5263, 028c17fd48, 07e113ce21, 1016a983a1, 17a1214e25,
ad115b6c27
2  .github/PULL_REQUEST_TEMPLATE.md  vendored

@@ -21,7 +21,7 @@ Provide a simple way for the reviewer to try out your changes.
Examples:
```bash
pytest -sx tests/test_stuff.py::test_something
DATA_DIR=tests/data pytest -sx tests/test_stuff.py::test_something
```
```bash
python lerobot/scripts/train.py --some.option=true

8  .github/workflows/nightly-tests.yml  vendored

@@ -7,8 +7,10 @@ on:
schedule:
- cron: "0 2 * * *"

# env:
env:
DATA_DIR: tests/data
# SLACK_API_TOKEN: ${{ secrets.SLACK_API_TOKEN }}

jobs:
run_all_tests_cpu:
name: CPU

@@ -28,9 +30,13 @@ jobs:
working-directory: /lerobot
steps:
- name: Tests
env:
DATA_DIR: tests/data
run: pytest -v --cov=./lerobot --disable-warnings tests

- name: Tests end-to-end
env:
DATA_DIR: tests/data
run: make test-end-to-end

4  .github/workflows/quality.yml  vendored

@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@v3

- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry

- name: Poetry check
run: poetry check

@@ -64,7 +64,7 @@ jobs:
uses: actions/checkout@v3

- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry

- name: Install poetry-relax
run: poetry self add poetry-relax

63  .github/workflows/test.yml  vendored

@@ -29,6 +29,7 @@ jobs:
name: Pytest
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4

@@ -69,6 +70,7 @@ jobs:
name: Pytest (minimal install)
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4

@@ -102,38 +104,39 @@ jobs:
&& rm -rf tests/outputs outputs

# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
# end-to-end:
# name: End-to-end
# runs-on: ubuntu-latest
# env:
# MUJOCO_GL: egl
# steps:
# - uses: actions/checkout@v4
# with:
# lfs: true # Ensure LFS files are pulled
end-to-end:
name: End-to-end
runs-on: ubuntu-latest
env:
DATA_DIR: tests/data
MUJOCO_GL: egl
steps:
- uses: actions/checkout@v4
with:
lfs: true # Ensure LFS files are pulled

# - name: Install apt dependencies
# # portaudio19-dev is needed to install pyaudio
# run: |
# sudo apt-get update && \
# sudo apt-get install -y libegl1-mesa-dev portaudio19-dev
- name: Install apt dependencies
# portaudio19-dev is needed to install pyaudio
run: |
sudo apt-get update && \
sudo apt-get install -y libegl1-mesa-dev portaudio19-dev

# - name: Install poetry
# run: |
# pipx install poetry && poetry config virtualenvs.in-project true
# echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH
- name: Install poetry
run: |
pipx install poetry && poetry config virtualenvs.in-project true
echo "${{ github.workspace }}/.venv/bin" >> $GITHUB_PATH

# - name: Set up Python 3.10
# uses: actions/setup-python@v5
# with:
# python-version: "3.10"
# cache: "poetry"
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "poetry"

# - name: Install poetry dependencies
# run: |
# poetry install --all-extras
- name: Install poetry dependencies
run: |
poetry install --all-extras

# - name: Test end-to-end
# run: |
# make test-end-to-end \
# && rm -rf outputs
- name: Test end-to-end
run: |
make test-end-to-end \
&& rm -rf outputs

@@ -3,7 +3,7 @@ default_language_version:
python: python3.10
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v4.6.0
hooks:
- id: check-added-large-files
- id: debug-statements

@@ -14,11 +14,11 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/asottile/pyupgrade
rev: v3.19.0
rev: v3.16.0
hooks:
- id: pyupgrade
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
rev: v0.5.2
hooks:
- id: ruff
args: [--fix]

@@ -32,6 +32,6 @@ repos:
- "--check"
- "--no-update"
- repo: https://github.com/gitleaks/gitleaks
rev: v8.21.2
rev: v8.18.4
hooks:
- id: gitleaks

@@ -267,7 +267,7 @@ We use `pytest` in order to run the tests. From the root of the
repository, here's how to run tests with `pytest` for the library:

```bash
python -m pytest -sv ./tests
DATA_DIR="tests/data" python -m pytest -sv ./tests
```

16  README.md

@@ -68,7 +68,7 @@

### Acknowledgment

- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.

@@ -153,12 +153,10 @@ python lerobot/scripts/visualize_dataset.py \
--episode-index 0
```

or from a dataset in a local folder with the `root` option and the `--local-files-only` (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
or from a dataset in a local folder with the root `DATA_DIR` environment variable (in the following case the dataset will be searched for in `./my_local_data_dir/lerobot/pusht`)
```bash
python lerobot/scripts/visualize_dataset.py \
DATA_DIR='./my_local_data_dir' python lerobot/scripts/visualize_dataset.py \
--repo-id lerobot/pusht \
--root ./my_local_data_dir \
--local-files-only 1 \
--episode-index 0
```

@@ -210,10 +208,12 @@ dataset attributes:

A `LeRobotDataset` is serialised using several widespread file formats for each of its parts, namely:
- hf_dataset stored using Hugging Face datasets library serialization to parquet
- videos are stored in mp4 format to save space
- metadata are stored in plain json/jsonl files
- videos are stored in mp4 format to save space or png files
- episode_data_index saved using `safetensor` tensor serialization format
- stats saved using `safetensor` tensor serialization format
- info are saved using JSON

Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can use the `local_files_only` argument and specify its location with the `root` argument if it's not in the default `~/.cache/huggingface/lerobot` location.
Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work on a local dataset, you can set the `DATA_DIR` environment variable to your root dataset folder as illustrated in the above section on dataset visualization.

### Evaluate a pretrained policy

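The README hunks above replace the `DATA_DIR` environment variable with the `root` and `local_files_only` options. As a minimal Python sketch of the new-style local loading described in that text (illustrative only: the exact constructor arguments and the expected local path layout may differ between lerobot versions):

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

# Assumption: `root` points at the folder holding the dataset on disk and
# `local_files_only=True` prevents any attempt to re-download it from the hub.
dataset = LeRobotDataset(
    "lerobot/pusht",
    root="./my_local_data_dir/lerobot/pusht",
    local_files_only=True,
)
print(dataset.fps, dataset.num_episodes)
```
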
@@ -21,7 +21,7 @@ How to decode videos?

## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.

@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie

Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.

Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),

However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.

@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.

One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.

@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).

See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.

Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`

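To make the sampling scenarios above concrete, here is a small illustrative sketch (not taken from the benchmark code) of how the requested timestamps could be built around a reference time `t`. Only `2_frames_4_space` and `6_frames` are named in the text above; the other two entries are assumptions added for comparison:

```python
def scenario_timestamps(t: float, fps: float, scenario: str) -> list[float]:
    """Timestamps a policy-like consumer would request from the video decoder."""
    if scenario == "1_frame":
        return [t]
    if scenario == "2_frames":
        return [t, t + 1 / fps]
    if scenario == "6_frames":
        return [t + i / fps for i in range(6)]
    if scenario == "2_frames_4_space":
        # 2 frames with 4 consecutive frames of spacing in between, i.e. [t, t + 5 / fps]
        return [t, t + 5 / fps]
    raise ValueError(f"unknown scenario: {scenario}")

print(scenario_timestamps(1.0, 30, "2_frames_4_space"))  # [1.0, 1.1666...]
```
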
@@ -1,31 +1,25 @@
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.

## A. Source the parts
## Source the parts

Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.

**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.

## B. Install LeRobot
## Install LeRobot

On your computer:

1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
# Linux:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
# Mac M-series:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
# Mac Intel:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```

2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
2. Restart shell or `source ~/.bashrc`

3. Create and activate a fresh conda environment for lerobot
```bash

@@ -42,30 +36,23 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
cd ~/lerobot && pip install -e ".[feetech]"
```

*For Linux only (not Mac)*: install extra dependencies for recording datasets:
For Linux only (not Mac), install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```

## C. Configure the motors
## Configure the motors

### 1. Find the USB ports associated to each arm
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.

Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.

#### a. Run the script to find ports

Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.

To find the port for each bus servo adapter, run the utility script:
**Find USB ports associated to your arms**
To find the correct ports for each arm, run the utility script twice:
```bash
python lerobot/scripts/find_motors_bus_port.py
```

#### b. Example outputs

Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
```
Finding all available ports for the MotorBus.

@@ -77,6 +64,7 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```

Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.

@@ -89,20 +77,13 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```

#### c. Troubleshooting
On Linux, you might need to give access to the USB ports by running:
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```

#### d. Update YAML file

Now that you have the ports, modify the *port* sections in `so100.yaml`

### 2. Configure the motors

#### a. Set IDs for all 12 motors
**Configure your motors**
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
```bash
python lerobot/scripts/configure_motor.py \

@@ -113,7 +94,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```

*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).

Then unplug your motor and plug the second motor and set its ID to 2.
```bash

@@ -127,25 +108,23 @@ python lerobot/scripts/configure_motor.py \

Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.

**Remove the gears of the 6 leader motors**
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.

#### b. Remove the gears of the 6 leader motors

Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.

#### c. Add motor horn to all 12 motors
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
**Add motor horn to the motors**
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.

## D. Assemble the arms
## Assemble the arms

Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.

## E. Calibrate
## Calibrate

Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.

#### a. Manual calibration of follower arm
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
**Manual calibration of follower arm**
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.

You will need to move the follower arm to these positions sequentially:

@@ -160,8 +139,8 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_follower
```

#### b. Manual calibration of leader arm
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
**Manual calibration of leader arm**
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:

| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|

@@ -174,7 +153,7 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_leader
```

## F. Teleoperate
## Teleoperate

**Simple teleop**
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):

@@ -186,14 +165,14 @@ python lerobot/scripts/control_robot.py teleoperate \
```

#### a. Teleop with displaying cameras
**Teleop with displaying cameras**
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/so100.yaml
```

## G. Record a dataset
## Record a dataset

Once you're familiar with teleoperation, you can record your first dataset with SO-100.

@@ -213,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--tags so100 tutorial \
--warmup-time-s 5 \

@@ -222,7 +202,7 @@ python lerobot/scripts/control_robot.py record \
--push-to-hub 1
```

## H. Visualize a dataset
## Visualize a dataset

If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash

@@ -232,25 +212,27 @@ echo ${HF_USER}/so100_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/so100_test
```

## I. Replay an episode
## Replay an episode

Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py replay \
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/so100_test \
--episode 0
```

## J. Train a policy
## Train a policy

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/so100_test \
policy=act_so100_real \
env=so100_real \

@@ -266,16 +248,18 @@ Let's explain it:
3. We provided an environment as argument with `env=so100_real`. This loads configurations from [`lerobot/configs/env/so100_real.yaml`](../lerobot/configs/env/so100_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.

Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.

## K. Evaluate your policy
## Evaluate your policy

You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/so100.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_so100_test \
--tags so100 tutorial eval \
--warmup-time-s 5 \

@@ -289,7 +273,7 @@ As you can see, it's almost the same command as previously used to record your t
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).

## L. More Information
## More

Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

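As a companion to the motor note in the tutorial above (positions range over 0 to 4096 for one full turn, with 2048 written as the mid-point by `configure_motor.py`), here is a small sketch of the step/degree conversion that note implies; the helper names are made up for illustration:

```python
STEPS_PER_TURN = 4096  # full mechanical range reported by the motor
MID_POSITION = 2048    # mid-range position written during configuration

def steps_to_degrees(steps: int) -> float:
    """Offset from the 2048 mid-point, in degrees (bounded by +/- 180)."""
    return (steps - MID_POSITION) * 360 / STEPS_PER_TURN

def degrees_to_steps(degrees: float) -> int:
    return MID_POSITION + round(degrees * STEPS_PER_TURN / 360)

print(steps_to_degrees(0))     # -180.0
print(steps_to_degrees(4096))  # 180.0
print(degrees_to_steps(-90))   # 1024
```
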
@@ -192,6 +192,7 @@ Record 2 episodes and upload your dataset to the hub:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--tags moss tutorial \
--warmup-time-s 5 \

@@ -211,6 +212,7 @@ echo ${HF_USER}/moss_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/moss_test
```

@@ -218,9 +220,10 @@ python lerobot/scripts/visualize_dataset_html.py \

Now try to replay the first episode on your robot:
```bash
python lerobot/scripts/control_robot.py replay \
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/moss_test \
--episode 0
```

@@ -229,7 +232,7 @@ python lerobot/scripts/control_robot.py replay \

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/moss_test \
policy=act_moss_real \
env=moss_real \

@@ -245,6 +248,7 @@ Let's explain it:
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.

Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.

@@ -255,6 +259,7 @@ You can use the `record` function from [`lerobot/scripts/control_robot.py`](../l
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/moss.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_moss_test \
--tags moss tutorial eval \
--warmup-time-s 5 \

@@ -1,83 +0,0 @@
# Training a HIL-SERL Reward Classifier with LeRobot

This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.

---

## Training Script Overview

LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:

1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)

2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.

3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).

4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.

---

## Configuring with Hydra

For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.

### Config File Setup

The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:

Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:

```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
```

## Typical logs and metrics

When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.

After that, you will see training log like this one:

```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```

or evaluation log like:

```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```

### Metrics Tracking with Weights & Biases (WandB)

If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:

- **Training Metrics**:
  - `train/accuracy`
  - `train/loss`
  - `train/dataloading_s`
- **Evaluation Metrics**:
  - `eval/accuracy`
  - `eval/loss`
  - `eval/eval_s`

#### Additional Features

You can also log sample predictions during evaluation. Each logged sample will include:

- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.

These logs can be useful for diagnosing and debugging performance issues.

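The deleted tutorial above describes the classifier as a lightweight head on top of a frozen, pretrained image encoder, returning a single probability in the binary case or logits in the multi-class case. A minimal PyTorch sketch of that idea, assuming a generic HuggingFace vision encoder; this is illustrative and not the actual model behind `train_hilserl_classifier.py`:

```python
import torch
from torch import nn
from transformers import AutoModel

class RewardClassifier(nn.Module):
    def __init__(self, encoder_name: str = "facebook/dinov2-small", num_classes: int = 1):
        super().__init__()
        # Frozen, pretrained image encoder (the encoder choice is an assumption).
        self.encoder = AutoModel.from_pretrained(encoder_name)
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.num_classes = num_classes
        self.head = nn.Linear(self.encoder.config.hidden_size, num_classes)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Use the first (CLS-like) token of the encoder output as the image embedding.
        features = self.encoder(pixel_values=pixel_values).last_hidden_state[:, 0]
        logits = self.head(features)
        if self.num_classes == 1:
            return torch.sigmoid(logits).squeeze(-1)  # single probability (binary case)
        return logits  # raw logits (multi-class case)
```
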
@@ -10,10 +10,10 @@ from torchvision.transforms import ToPILImage, v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

dataset_repo_id = "lerobot/aloha_static_screw_driver"
dataset_repo_id = "lerobot/aloha_static_tape"

# Create a LeRobotDataset with no transformations
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataset = LeRobotDataset(dataset_repo_id)
# This is equivalent to `dataset = LeRobotDataset(dataset_repo_id, image_transforms=None)`

# Get the index of the first observation in the first episode

@@ -28,13 +28,12 @@ transforms = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 1.5)),
v2.ColorJitter(contrast=(0.5, 1.5)),
v2.ColorJitter(hue=(-0.1, 0.1)),
v2.RandomAdjustSharpness(sharpness_factor=2, p=1),
]
)

# Create another LeRobotDataset with the defined transformations
transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms)
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)

# Get a frame from the transformed dataset
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]

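As a small usage sketch following the example file being diffed above (it reuses the arguments that appear in the hunk, such as `episodes=[0]` and `meta.camera_keys`, and assumes `episode_data_index` for the first-frame lookup), the original and transformed frames could be dumped side by side like this:

```python
from pathlib import Path

from torchvision.transforms import ToPILImage, v2

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

repo_id = "lerobot/aloha_static_tape"
transforms = v2.Compose([v2.ColorJitter(brightness=(0.5, 1.5), contrast=(0.5, 1.5))])

original = LeRobotDataset(repo_id, episodes=[0])
augmented = LeRobotDataset(repo_id, episodes=[0], image_transforms=transforms)

first_idx = original.episode_data_index["from"][0].item()  # first frame of the first episode
cam_key = original.meta.camera_keys[0]

out_dir = Path("outputs/image_transforms_example")
out_dir.mkdir(parents=True, exist_ok=True)
ToPILImage()(original[first_idx][cam_key]).save(out_dir / "original.png")
ToPILImage()(augmented[first_idx][cam_key]).save(out_dir / "transformed.png")
```
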
@@ -29,7 +29,7 @@ For a visual walkthrough of the assembly process, you can refer to [this video t

## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1

First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed).
First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands.

Using `pip`:
```bash

@@ -778,6 +778,7 @@ Now run this to record 2 episodes:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--tags tutorial \
--warmup-time-s 5 \

@@ -786,7 +787,7 @@ python lerobot/scripts/control_robot.py record \
--num-episodes 2
```

This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).
This will write your dataset locally to `{root}/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example).

You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot

@@ -839,6 +840,7 @@ In the coming months, we plan to release a foundational model for robotics. We a
You can visualize your dataset by running:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/koch_test
```

@@ -856,6 +858,7 @@ To replay the first episode of the dataset you just recorded, run the following
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/koch_test \
--episode 0
```

@@ -868,7 +871,7 @@ Your robot should replicate movements similar to those you recorded. For example

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/koch_test \
policy=act_koch_real \
env=koch_real \

@@ -915,6 +918,7 @@ env:
It should match your dataset (e.g. `fps: 30`) and your robot (e.g. `state_dim: 6` and `action_dim: 6`). We are still working on simplifying this in future versions of `lerobot`.
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you could use `device=mps` to train on Apple silicon.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.

For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md)

@@ -987,6 +991,7 @@ To this end, you can use the `record` function from [`lerobot/scripts/control_ro
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/koch.yaml \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_koch_test \
--tags tutorial eval \
--warmup-time-s 5 \

@@ -1005,6 +1010,7 @@ As you can see, it's almost the same command as previously used to record your t
You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument:
```bash
python lerobot/scripts/visualize_dataset.py \
--root data \
--repo-id ${HF_USER}/eval_koch_test
```

@@ -128,6 +128,7 @@ Record one episode:
python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--tags stretch tutorial \
--warmup-time-s 3 \

@@ -145,6 +146,7 @@ Now try to replay this episode (make sure the robot's initial position is the sa
python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/stretch.yaml \
--fps 20 \
--root data \
--repo-id ${HF_USER}/stretch_test \
--episode 0
```

@@ -56,7 +56,7 @@ python lerobot/scripts/control_robot.py teleoperate \
--robot-overrides max_relative_target=5
```

By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
By adding `--robot-overrides max_relative_target=5`, we override the default value for `max_relative_target` defined in `lerobot/configs/robot/aloha.yaml`. It is expected to be `5` to limit the magnitude of the movement for more safety, but the teloperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot-overrides max_relative_target=null` to the command line:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/aloha.yaml \

@@ -84,6 +84,7 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--tags aloha tutorial \
--warmup-time-s 5 \

@@ -103,6 +104,7 @@ echo ${HF_USER}/aloha_test
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
```bash
python lerobot/scripts/visualize_dataset_html.py \
--root data \
--repo-id ${HF_USER}/aloha_test
```

@@ -117,6 +119,7 @@ python lerobot/scripts/control_robot.py replay \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/aloha_test \
--episode 0
```

@@ -125,7 +128,7 @@ python lerobot/scripts/control_robot.py replay \

To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
python lerobot/scripts/train.py \
DATA_DIR=data python lerobot/scripts/train.py \
dataset_repo_id=${HF_USER}/aloha_test \
policy=act_aloha_real \
env=aloha_real \

@@ -141,6 +144,7 @@ Let's explain it:
3. We provided an environment as argument with `env=aloha_real`. This loads configurations from [`lerobot/configs/env/aloha_real.yaml`](../lerobot/configs/env/aloha_real.yaml). Note: this yaml defines 18 dimensions for the `state_dim` and `action_dim`, corresponding to 18 motors, not 14 motors as used in previous Aloha work. This is because, we include the `shoulder_shadow` and `elbow_shadow` motors for simplicity.
4. We provided `device=cuda` since we are training on a Nvidia GPU.
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.

Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`.

@@ -152,6 +156,7 @@ python lerobot/scripts/control_robot.py record \
--robot-path lerobot/configs/robot/aloha.yaml \
--robot-overrides max_relative_target=null \
--fps 30 \
--root data \
--repo-id ${HF_USER}/eval_act_aloha_test \
--tags aloha tutorial eval \
--warmup-time-s 5 \

@@ -63,7 +63,7 @@ def build_features(mode: str) -> dict:
    return features


def load_raw_dataset(zarr_path: Path):
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,

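Only the top of `load_raw_dataset` is visible in the hunk above; the `try:` wrapped around the import suggests the usual guarded-import pattern for an optional dependency. A generic sketch of that pattern under that assumption (the error message and the follow-up call are illustrative, not the file's actual code):

```python
from pathlib import Path

def load_raw_dataset(zarr_path: Path, load_images: bool = True):
    try:
        from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
            ReplayBuffer as DiffusionPolicyReplayBuffer,
        )
    except ModuleNotFoundError as e:
        # Assumed handling: surface a hint about the missing optional dependency.
        raise ModuleNotFoundError("`zarr`-based conversion dependencies are not installed.") from e

    # Hypothetical use of the replay buffer; the real function body is not shown in the diff.
    return DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
```
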
@@ -28,7 +28,7 @@ def safe_stop_image_writer(func):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            dataset = kwargs.get("dataset")
            dataset = kwargs.get("dataset", None)
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")

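The hunk above shows only the inner body of `safe_stop_image_writer`. A self-contained sketch of the decorator shape it implies; the wrapper scaffolding, the `stop()` call and the final `raise` are assumptions, while the `kwargs.get("dataset")` lookup and the print come from the hunk:

```python
import functools

def safe_stop_image_writer(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            # If the wrapped call received a dataset with a background image writer,
            # let that writer finish before propagating the error.
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                print("Waiting for image writer to terminate...")
                image_writer.stop()  # assumed method name on the writer object
            raise
    return wrapper
```
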
@@ -291,14 +291,14 @@ class LeRobotDatasetMetadata:
        obj.root.mkdir(parents=True, exist_ok=False)

        if robot is not None:
            features = {**(features or {}), **get_features_from_robot(robot)}
            features = get_features_from_robot(robot, use_videos)
            robot_type = robot.robot_type
            if not all(cam.fps == fps for cam in robot.cameras.values()):
                logging.warning(
                    f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
                    "In this case, frames from lower fps cameras will be repeated to fill in the blanks."
                )
        elif features is None:
        elif robot_type is None or features is None:
            raise ValueError(
                "Dataset features must either come from a Robot or explicitly passed upon creation."
            )

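The hunk above requires dataset features to come either from a `Robot` (via `get_features_from_robot`) or to be passed explicitly at creation time. A rough illustration of what an explicitly passed `features` mapping could look like; the exact schema is an assumption inferred from this hunk, not a verified API contract:

```python
# Hypothetical feature specification for a 6-motor arm with a single camera.
features = {
    "observation.state": {"dtype": "float32", "shape": (6,), "names": None},
    "action": {"dtype": "float32", "shape": (6,), "names": None},
    "observation.images.laptop": {
        "dtype": "video",
        "shape": (480, 640, 3),
        "names": ["height", "width", "channels"],
    },
}
# A creation call would then pass `features=features` instead of `robot=...`
# (left as a comment, since the constructor signature is not part of this hunk):
# meta = LeRobotDatasetMetadata.create(repo_id="user/my_dataset", fps=30, features=features)
```
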
639
lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml
Normal file
639
lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml
Normal file
@@ -0,0 +1,639 @@
|
||||
OPENX_DATASET_CONFIGS:
|
||||
fractal20220817_data:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 3
|
||||
|
||||
kuka:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- clip_function_input/base_pose_tool_reached
|
||||
- gripper_closed
|
||||
fps: 10
|
||||
|
||||
bridge_openx:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- EEF_state
|
||||
- gripper_state
|
||||
fps: 5
|
||||
|
||||
taco_play:
|
||||
image_obs_keys:
|
||||
- rgb_static
|
||||
- rgb_gripper
|
||||
depth_obs_keys:
|
||||
- depth_static
|
||||
- depth_gripper
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 15
|
||||
|
||||
jaco_play:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_wrist
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state_eef
|
||||
- state_gripper
|
||||
fps: 10
|
||||
|
||||
berkeley_cable_routing:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- top_image
|
||||
- wrist45_image
|
||||
- wrist225_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
roboturk:
|
||||
image_obs_keys:
|
||||
- front_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
nyu_door_opening_surprising_effectiveness:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
||||
viola:
|
||||
image_obs_keys:
|
||||
- agentview_rgb
|
||||
- eye_in_hand_rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_states
|
||||
- gripper_states
|
||||
fps: 20
|
||||
|
||||
berkeley_autolab_ur5:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- hand_image
|
||||
depth_obs_keys:
|
||||
- image_with_depth
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 5
|
||||
|
||||
toto:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 30
|
||||
|
||||
language_table:
|
||||
image_obs_keys:
|
||||
- rgb
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- effector_translation
|
||||
fps: 10
|
||||
|
||||
columbia_cairlab_pusht_real:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- robot_state
|
||||
fps: 10
|
||||
|
||||
stanford_kuka_multimodal_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- depth_image
|
||||
state_obs_keys:
|
||||
- ee_position
|
||||
- ee_orientation
|
||||
fps: 20
|
||||
|
||||
nyu_rot_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
io_ai_tech:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_fisheye
|
||||
- image_left_side
|
||||
- image_right_side
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 3
|
||||
|
||||
stanford_hydra_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 10
|
||||
|
||||
austin_buds_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 20
|
||||
|
||||
nyu_franka_play_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_additional_view
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- depth_additional_view
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
fps: 3
|
||||
|
||||
maniskill_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- depth
|
||||
- wrist_depth
|
||||
state_obs_keys:
|
||||
- tcp_pose
|
||||
- gripper_state
|
||||
fps: 20
|
||||
|
||||
furniture_bench_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- wrist_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- state
|
||||
fps: 10
|
||||
|
||||
cmu_franka_exploration_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- highres_image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 10
|
||||
|
||||
ucsd_kitchen_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- joint_state
|
||||
fps: 2
|
||||
|
||||
ucsd_pick_and_place_dataset_converted_externally_to_rlds:
|
||||
image_obs_keys:
|
||||
- image
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- eef_state
|
||||
- gripper_state
|
||||
fps: 3
|
||||
|
||||
spoc:
|
||||
image_obs_keys:
|
||||
- image
|
||||
- image_manipulation
|
||||
depth_obs_keys:
|
||||
- null
|
||||
state_obs_keys:
|
||||
- null
|
||||
fps: 3
|
||||
|
  austin_sailor_dataset_converted_externally_to_rlds:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 20

  austin_sirius_dataset_converted_externally_to_rlds:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 20

  bc_z:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - present/xyz
      - present/axis_angle
      - present/sensed_close
    fps: 10

  utokyo_pr2_opening_fridge_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 10

  utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 10

  utokyo_xarm_pick_and_place_converted_externally_to_rlds:
    image_obs_keys:
      - image
      - image2
      - hand_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - end_effector_pose
    fps: 10

  utokyo_xarm_bimanual_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - pose_r
    fps: 10

  robo_net:
    image_obs_keys:
      - image
      - image1
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 1

  robo_set:
    image_obs_keys:
      - image_left
      - image_right
      - image_wrist
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
      - state_velocity
    fps: 5

  berkeley_mvp_converted_externally_to_rlds:
    image_obs_keys:
      - hand_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - gripper
      - pose
      - joint_pos
    fps: 5

  berkeley_rpt_converted_externally_to_rlds:
    image_obs_keys:
      - hand_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - joint_pos
      - gripper
    fps: 30

  kaist_nonprehensile_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 10

  stanford_mask_vit_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state

  tokyo_u_lsmo_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 10

  dlr_sara_pour_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 10

  dlr_sara_grid_clamp_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 10

  dlr_edan_shared_control_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 5

  asu_table_top_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 12.5
  stanford_robocook_converted_externally_to_rlds:
    image_obs_keys:
      - image_1
      - image_2
    depth_obs_keys:
      - depth_1
      - depth_2
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 5

  imperialcollege_sawyer_wrist_cam:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 10

  iamlab_cmu_pickup_insert_converted_externally_to_rlds:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - joint_state
      - gripper_state
    fps: 20

  uiuc_d3field:
    image_obs_keys:
      - image_1
      - image_2
    depth_obs_keys:
      - depth_1
      - depth_2
    state_obs_keys:
      - null
    fps: 1

  utaustin_mutex:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 20

  berkeley_fanuc_manipulation:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - joint_state
      - gripper_state
    fps: 10

  cmu_playing_with_food:
    image_obs_keys:
      - image
      - finger_vision_1
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 10

  cmu_play_fusion:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 5

  cmu_stretch:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - eef_state
      - gripper_state
    fps: 10

  berkeley_gnm_recon:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
      - position
      - yaw
    fps: 3

  berkeley_gnm_cory_hall:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
      - position
      - yaw
    fps: 5

  berkeley_gnm_sac_son:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
      - position
      - yaw
    fps: 10

  droid:
    image_obs_keys:
      - exterior_image_1_left
      - exterior_image_2_left
      - wrist_image_left
    depth_obs_keys:
      - null
    state_obs_keys:
      - proprio
    fps: 15

  droid_100:
    image_obs_keys:
      - exterior_image_1_left
      - exterior_image_2_left
      - wrist_image_left
    depth_obs_keys:
      - null
    state_obs_keys:
      - proprio
    fps: 15

  fmb:
    image_obs_keys:
      - image_side_1
      - image_side_2
      - image_wrist_1
      - image_wrist_2
    depth_obs_keys:
      - image_side_1_depth
      - image_side_2_depth
      - image_wrist_1_depth
      - image_wrist_2_depth
    state_obs_keys:
      - proprio
    fps: 10

  dobbe:
    image_obs_keys:
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - proprio
    fps: 3.75

  usc_cloth_sim_converted_externally_to_rlds:
    image_obs_keys:
      - image
    depth_obs_keys:
      - null
    state_obs_keys:
      - null
    fps: 10

  plex_robosuite:
    image_obs_keys:
      - image
      - wrist_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 20

  conq_hose_manipulation:
    image_obs_keys:
      - frontleft_fisheye_image
      - frontright_fisheye_image
      - hand_color_image
    depth_obs_keys:
      - null
    state_obs_keys:
      - state
    fps: 30
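Every entry above follows the same schema: `image_obs_keys` lists the RGB camera streams to export, `depth_obs_keys` the depth streams (or `null` when there are none), `state_obs_keys` the proprioceptive features stacked into a single state vector (a `null` state key is zero-padded by the conversion script further down in this diff), and `fps` the frame rate used for timestamps. A minimal sketch of inspecting one entry, assuming it is run from the repository root so the relative path resolves:

```python
import yaml

with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
    configs = yaml.safe_load(f)["OPENX_DATASET_CONFIGS"]

droid_cfg = configs["droid"]
print(droid_cfg["image_obs_keys"])  # ['exterior_image_1_left', 'exterior_image_2_left', 'wrist_image_left']
print(droid_cfg["state_obs_keys"], droid_cfg["fps"])  # ['proprio'] 15
```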
106
lerobot/common/datasets/push_dataset_to_hub/openx/data_utils.py
Normal file
@@ -0,0 +1,106 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
    Octo: https://github.com/octo-models/octo/blob/main/octo/data/utils/data_utils.py

data_utils.py

Additional utils for data processing.
"""

from typing import Any, Dict, List

import tensorflow as tf


def binarize_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Converts gripper actions from continuous to binary values (0 and 1).

    We exploit that fact that most of the time, the gripper is fully open (near 1.0) or fully closed (near 0.0). As it
    transitions between the two, it sometimes passes through a few intermediate values. We relabel those intermediate
    values based on the state that is reached _after_ those intermediate values.

    In the edge case that the trajectory ends with an intermediate value, we give up on binarizing and relabel that
    chunk of intermediate values as the last action in the trajectory.

    The `scan_fn` implements the following logic:
        new_actions = np.empty_like(actions)
        carry = actions[-1]
        for i in reversed(range(actions.shape[0])):
            if in_between_mask[i]:
                carry = carry
            else:
                carry = float(open_mask[i])
            new_actions[i] = carry
    """
    open_mask, closed_mask = actions > 0.95, actions < 0.05
    in_between_mask = tf.logical_not(tf.logical_or(open_mask, closed_mask))
    is_open_float = tf.cast(open_mask, tf.float32)

    def scan_fn(carry, i):
        return tf.cond(in_between_mask[i], lambda: tf.cast(carry, tf.float32), lambda: is_open_float[i])

    return tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), actions[-1], reverse=True)


def invert_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    return 1 - actions


def rel2abs_gripper_actions(actions: tf.Tensor) -> tf.Tensor:
    """
    Converts relative gripper actions (+1 for closing, -1 for opening) to absolute actions (0 = closed; 1 = open).

    Assumes that the first relative gripper is not redundant (i.e. close when already closed)!
    """
    # Note =>> -1 for closing, 1 for opening, 0 for no change
    opening_mask, closing_mask = actions < -0.1, actions > 0.1
    thresholded_actions = tf.where(opening_mask, 1, tf.where(closing_mask, -1, 0))

    def scan_fn(carry, i):
        return tf.cond(thresholded_actions[i] == 0, lambda: carry, lambda: thresholded_actions[i])

    # If no relative grasp, assumes open for whole trajectory
    start = -1 * thresholded_actions[tf.argmax(thresholded_actions != 0, axis=0)]
    start = tf.cond(start == 0, lambda: 1, lambda: start)

    # Note =>> -1 for closed, 1 for open
    new_actions = tf.scan(scan_fn, tf.range(tf.shape(actions)[0]), start)
    new_actions = tf.cast(new_actions, tf.float32) / 2 + 0.5

    return new_actions


# === Bridge-V2 =>> Dataset-Specific Transform ===
def relabel_bridge_actions(traj: Dict[str, Any]) -> Dict[str, Any]:
    """Relabels actions to use reached proprioceptive state; discards last timestep (no-action)."""
    movement_actions = traj["observation"]["state"][1:, :6] - traj["observation"]["state"][:-1, :6]
    traj_truncated = tf.nest.map_structure(lambda x: x[:-1], traj)
    traj_truncated["action"] = tf.concat([movement_actions, traj["action"][:-1, -1:]], axis=1)

    return traj_truncated


# === RLDS Dataset Initialization Utilities ===
def pprint_data_mixture(dataset_kwargs_list: List[Dict[str, Any]], dataset_weights: List[int]) -> None:
    print("\n######################################################################################")
    print(f"# Loading the following {len(dataset_kwargs_list)} datasets (incl. sampling weight):{'': >24} #")
    for dataset_kwargs, weight in zip(dataset_kwargs_list, dataset_weights, strict=False):
        pad = 80 - len(dataset_kwargs["name"])
        print(f"# {dataset_kwargs['name']}: {weight:=>{pad}f} #")
    print("######################################################################################\n")
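To make the gripper relabeling concrete, a small usage sketch (the input values are made up; the printed result is what the reverse scan should produce, since the intermediate values are relabeled with the fully closed state reached afterwards):

```python
import tensorflow as tf

from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import binarize_gripper_actions

# Continuous gripper signal that starts open and closes partway through the trajectory.
actions = tf.constant([1.0, 0.9, 0.5, 0.1, 0.0])
print(binarize_gripper_actions(actions))  # approximately [1., 0., 0., 0., 0.]
```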
200
lerobot/common/datasets/push_dataset_to_hub/openx/droid_utils.py
Normal file
@@ -0,0 +1,200 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
    OpenVLA: https://github.com/openvla/openvla

Episode transforms for DROID dataset.
"""

from typing import Any, Dict

import tensorflow as tf
import tensorflow_graphics.geometry.transformation as tfg


def rmat_to_euler(rot_mat):
    return tfg.euler.from_rotation_matrix(rot_mat)


def euler_to_rmat(euler):
    return tfg.rotation_matrix_3d.from_euler(euler)


def invert_rmat(rot_mat):
    return tfg.rotation_matrix_3d.inverse(rot_mat)


def rotmat_to_rot6d(mat):
    """
    Converts rotation matrix to R6 rotation representation (first two rows in rotation matrix).
    Args:
        mat: rotation matrix

    Returns: 6d vector (first two rows of rotation matrix)

    """
    r6 = mat[..., :2, :]
    r6_0, r6_1 = r6[..., 0, :], r6[..., 1, :]
    r6_flat = tf.concat([r6_0, r6_1], axis=-1)
    return r6_flat


def velocity_act_to_wrist_frame(velocity, wrist_in_robot_frame):
    """
    Translates velocity actions (translation + rotation) from base frame of the robot to wrist frame.
    Args:
        velocity: 6d velocity action (3 x translation, 3 x rotation)
        wrist_in_robot_frame: 6d pose of the end-effector in robot base frame

    Returns: 9d velocity action in robot wrist frame (3 x translation, 6 x rotation as R6)

    """
    r_frame = euler_to_rmat(wrist_in_robot_frame[:, 3:6])
    r_frame_inv = invert_rmat(r_frame)

    # world to wrist: dT_pi = R^-1 dT_rbt
    vel_t = (r_frame_inv @ velocity[:, :3][..., None])[..., 0]

    # world to wrist: dR_pi = R^-1 dR_rbt R
    dr_ = euler_to_rmat(velocity[:, 3:6])
    dr_ = r_frame_inv @ (dr_ @ r_frame)
    dr_r6 = rotmat_to_rot6d(dr_)
    return tf.concat([vel_t, dr_r6], axis=-1)


def rand_swap_exterior_images(img1, img2):
    """
    Randomly swaps the two exterior images (for training with single exterior input).
    """
    return tf.cond(tf.random.uniform(shape=[]) > 0.5, lambda: (img1, img2), lambda: (img2, img1))


def droid_baseact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]

    trajectory["action"] = tf.concat(
        (
            dt,
            dr_,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_wristact_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *wrist* frame of the robot.
    """
    wrist_act = velocity_act_to_wrist_frame(
        trajectory["action_dict"]["cartesian_velocity"], trajectory["observation"]["cartesian_position"]
    )
    trajectory["action"] = tf.concat(
        (
            wrist_act,
            trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["exterior_image_1_left"], trajectory["observation"]["exterior_image_2_left"] = (
        rand_swap_exterior_images(
            trajectory["observation"]["exterior_image_1_left"],
            trajectory["observation"]["exterior_image_2_left"],
        )
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def droid_finetuning_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    DROID dataset transformation for actions expressed in *base* frame of the robot.
    """
    dt = trajectory["action_dict"]["cartesian_velocity"][:, :3]
    dr_ = trajectory["action_dict"]["cartesian_velocity"][:, 3:6]
    trajectory["action"] = tf.concat(
        (
            dt,
            dr_,
            1 - trajectory["action_dict"]["gripper_position"],
        ),
        axis=-1,
    )
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["cartesian_position"],
            trajectory["observation"]["gripper_position"],
        ),
        axis=-1,
    )
    return trajectory


def zero_action_filter(traj: Dict) -> bool:
    """
    Filters transitions whose actions are all-0 (only relative actions, no gripper action).
    Note: this filter is applied *after* action normalization, so need to compare to "normalized 0".
    """
    droid_q01 = tf.convert_to_tensor(
        [
            -0.7776297926902771,
            -0.5803514122962952,
            -0.5795090794563293,
            -0.6464047729969025,
            -0.7041108310222626,
            -0.8895104378461838,
        ]
    )
    droid_q99 = tf.convert_to_tensor(
        [
            0.7597932070493698,
            0.5726242214441299,
            0.7351000607013702,
            0.6705610305070877,
            0.6464948207139969,
            0.8897542208433151,
        ]
    )
    droid_norm_0_act = (
        2 * (tf.zeros_like(traj["action"][:, :6]) - droid_q01) / (droid_q99 - droid_q01 + 1e-8) - 1
    )

    return tf.reduce_any(tf.math.abs(traj["action"][:, :6] - droid_norm_0_act) > 1e-5)
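As a quick sanity check on the R6 rotation convention used above (a sketch, not part of the diff): the representation keeps the first two rows of the rotation matrix, so the identity rotation maps to `[1, 0, 0, 0, 1, 0]`.

```python
import tensorflow as tf

from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import rotmat_to_rot6d

print(rotmat_to_rot6d(tf.eye(3)))  # tf.Tensor([1. 0. 0. 0. 1. 0.], shape=(6,), dtype=float32)
```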
859
lerobot/common/datasets/push_dataset_to_hub/openx/transforms.py
Normal file
@@ -0,0 +1,859 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
NOTE(YL): Adapted from:
    OpenVLA: https://github.com/openvla/openvla
    Octo: https://github.com/octo-models/octo

transforms.py

Defines a registry of per-dataset standardization transforms for each dataset in Open-X Embodiment.

Transforms adopt the following structure:
    Input: Dictionary of *batched* features (i.e., has leading time dimension)
    Output: Dictionary `step` =>> {
        "observation": {
            <image_keys, depth_image_keys>
            State (in chosen state representation)
        },
        "action": Action (in chosen action representation),
        "language_instruction": str
    }
"""

from typing import Any, Dict

import tensorflow as tf

from lerobot.common.datasets.push_dataset_to_hub.openx.data_utils import (
    binarize_gripper_actions,
    invert_gripper_actions,
    rel2abs_gripper_actions,
    relabel_bridge_actions,
)


def droid_baseact_transform_fn():
    from lerobot.common.datasets.push_dataset_to_hub.openx.droid_utils import droid_baseact_transform

    return droid_baseact_transform


def bridge_openx_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to version of Bridge V2 in Open X-Embodiment mixture.

    Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
    """
    for key in trajectory:
        if key == "traj_metadata":
            continue
        elif key in ["observation", "action"]:
            for key2 in trajectory[key]:
                trajectory[key][key2] = trajectory[key][key2][1:]
        else:
            trajectory[key] = trajectory[key][1:]

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    trajectory = relabel_bridge_actions(trajectory)
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Applies to original version of Bridge V2 from the official project website.

    Note =>> In original Bridge V2 dataset, the first timestep has an all-zero action, so we remove it!
    """
    for key in trajectory:
        if key == "traj_metadata":
            continue
        elif key == "observation":
            for key2 in trajectory[key]:
                trajectory[key][key2] = trajectory[key][key2][1:]
        else:
            trajectory[key] = trajectory[key][1:]

    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory = relabel_bridge_actions(trajectory)
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def ppgm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        [
            trajectory["action"][:, :6],
            binarize_gripper_actions(trajectory["action"][:, -1])[:, None],
        ],
        axis=1,
    )
    trajectory["observation"]["EEF_state"] = trajectory["observation"]["cartesian_position"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["gripper_position"][:, -1:]
    return trajectory


def rt1_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    # decode compressed state
    eef_value = tf.io.decode_compressed(
        trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
        compression_type="ZLIB",
    )
    eef_value = tf.io.decode_raw(eef_value, tf.float32)
    trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
    gripper_value = tf.io.decode_compressed(
        trajectory["observation"]["gripper_closed"], compression_type="ZLIB"
    )
    gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
    trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def taco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state_eef"] = trajectory["observation"]["robot_obs"][:, :6]
    trajectory["observation"]["state_gripper"] = trajectory["observation"]["robot_obs"][:, 7:8]
    trajectory["action"] = trajectory["action"]["rel_actions_world"]

    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            tf.clip_by_value(trajectory["action"][:, -1:], 0, 1),
        ),
        axis=-1,
    )

    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def jaco_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state_eef"] = trajectory["observation"]["end_effector_cartesian_pos"][:, :6]
    trajectory["observation"]["state_gripper"] = trajectory["observation"]["end_effector_cartesian_pos"][
        :, -1:
    ]

    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            tf.zeros_like(trajectory["action"]["world_vector"]),
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def berkeley_cable_routing_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.zeros_like(trajectory["action"]["world_vector"][:, :1]),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def roboturk_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert absolute gripper action, +1 = open, 0 = close
    gripper_action = invert_gripper_actions(
        tf.clip_by_value(trajectory["action"]["gripper_closedness_action"], 0, 1)
    )

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action,
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    trajectory["language_embedding"] = trajectory["observation"]["natural_language_embedding"]
    return trajectory


def nyu_door_opening_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, 0]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def viola_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # make gripper action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"][:, None]
    gripper_action = tf.clip_by_value(gripper_action, 0, 1)
    gripper_action = invert_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action,
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def berkeley_autolab_ur5_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["robot_state"][:, 6:14]

    # make gripper action absolute action, +1 = open, 0 = close
    gripper_action = trajectory["action"]["gripper_closedness_action"]
    gripper_action = rel2abs_gripper_actions(gripper_action)

    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            gripper_action[:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def toto_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            tf.cast(trajectory["action"]["open_gripper"][:, None], tf.float32),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def language_table_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # default to "open" gripper
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"]),
            tf.ones_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )

    # decode language instruction
    instruction_bytes = trajectory["observation"]["instruction"]
    instruction_encoded = tf.strings.unicode_encode(instruction_bytes, output_encoding="UTF-8")
    # Remove trailing padding --> convert RaggedTensor to regular Tensor.
    trajectory["language_instruction"] = tf.strings.split(instruction_encoded, "\x00")[:, :1].to_tensor()[
        :, 0
    ]
    return trajectory


def pusht_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["world_vector"],
            trajectory["action"]["rotation_delta"],
            trajectory["action"]["gripper_closedness_action"][:, None],
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def stanford_kuka_multimodal_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["depth_image"] = trajectory["observation"]["depth_image"][..., 0]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tf.zeros_like(trajectory["action"][:, :3]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def nyu_rot_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][..., :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., -1:]
    trajectory["action"] = trajectory["action"][..., :7]
    return trajectory


def stanford_hydra_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(trajectory["action"][:, -1:]),
        ),
        axis=-1,
    )

    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :3],
            trajectory["observation"]["state"][:, 7:10],
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -3:-2]
    return trajectory


def austin_buds_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )

    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]
    return trajectory


def nyu_franka_play_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["depth"] = tf.cast(trajectory["observation"]["depth"][..., 0], tf.float32)
    trajectory["observation"]["depth_additional_view"] = tf.cast(
        trajectory["observation"]["depth_additional_view"][..., 0], tf.float32
    )
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, -6:]

    # clip gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, -8:-2],
            tf.clip_by_value(trajectory["action"][:, -2:-1], 0, 1),
        ),
        axis=-1,
    )
    return trajectory


def maniskill_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][..., 7:8]
    return trajectory


def furniture_bench_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["observation"]["state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :7],
            trajectory["observation"]["state"][:, -1:],
        ),
        axis=-1,
    )

    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )
    return trajectory


def cmu_franka_exploration_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def ucsd_kitchen_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def ucsd_pick_place_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tf.zeros_like(trajectory["action"][:, :3]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def austin_sailor_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )
    return trajectory


def austin_sirius_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )
    return trajectory


def bc_z_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"]["future/xyz_residual"][:, :3],
            trajectory["action"]["future/axis_angle_residual"][:, :3],
            invert_gripper_actions(tf.cast(trajectory["action"]["future/target_close"][:, :1], tf.float32)),
        ),
        axis=-1,
    )
    trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
    return trajectory


def tokyo_pr2_opening_fridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def tokyo_pr2_tabletop_manipulation_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def utokyo_xarm_bimanual_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., -7:]
    return trajectory


def robo_net_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :4],
            tf.zeros_like(trajectory["observation"]["state"][:, :2]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :4],
            tf.zeros_like(trajectory["action"][:, :2]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def berkeley_mvp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    """
    trajectory["observation"]["state"] = tf.concat((
        tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32),
        trajectory["observation"]["pose"],
        trajectory["observation"]["joint_pos"],),
        axis=-1,)
    """
    trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
    return trajectory


def berkeley_rpt_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["gripper"] = tf.cast(trajectory["observation"]["gripper"][:, None], tf.float32)
    return trajectory


def kaist_nonprehensible_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, -7:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory


def stanford_mask_vit_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["end_effector_pose"][:, :4],
            tf.zeros_like(trajectory["observation"]["end_effector_pose"][:, :2]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["end_effector_pose"][:, -1:]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :4],
            tf.zeros_like(trajectory["action"][:, :2]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def tokyo_lsmo_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def dlr_sara_grid_clamp_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :6]
    return trajectory


def dlr_edan_shared_control_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # invert gripper action, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(trajectory["action"][:, -1:]),
        ),
        axis=-1,
    )
    return trajectory


def asu_table_top_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["ground_truth_states"]["EE"]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def robocook_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    return trajectory


def imperial_wristcam_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def iamlab_pick_insert_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :7]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 7:8]
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            trajectory["action"][:, 7:8],
        ),
        axis=-1,
    )
    return trajectory


def uiuc_d3field_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory


def utaustin_mutex_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = trajectory["observation"]["state"][:, :8]

    # invert gripper action + clip, +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :6],
            invert_gripper_actions(tf.clip_by_value(trajectory["action"][:, -1:], 0, 1)),
        ),
        axis=-1,
    )
    return trajectory


def berkeley_fanuc_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["joint_state"] = trajectory["observation"]["state"][:, :6]
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, 6:7]

    # dataset does not store gripper actions, so use gripper state info, invert so +1 = open, 0 = close
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            invert_gripper_actions(trajectory["observation"]["gripper_state"]),
        ),
        axis=-1,
    )
    return trajectory


def cmu_playing_with_food_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    import tensorflow_graphics.geometry.transformation as tft

    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            tft.euler.from_quaternion(trajectory["action"][:, 3:7]),
            trajectory["action"][:, -1:],
        ),
        axis=-1,
    )
    return trajectory


def playfusion_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :3],
            trajectory["action"][:, -4:],
        ),
        axis=-1,
    )
    return trajectory


def cmu_stretch_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["eef_state"] = tf.concat(
        (
            trajectory["observation"]["state"][:, :3],
            tf.zeros_like(trajectory["observation"]["state"][:, :3]),
        ),
        axis=-1,
    )
    trajectory["observation"]["gripper_state"] = trajectory["observation"]["state"][:, -1:]
    trajectory["action"] = trajectory["action"][..., :-1]
    return trajectory


def gnm_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    trajectory["observation"]["state"] = tf.concat(
        (
            trajectory["observation"]["position"],
            tf.zeros_like(trajectory["observation"]["state"][:, :3]),
            trajectory["observation"]["yaw"],
        ),
        axis=-1,
    )
    trajectory["action"] = tf.concat(
        (
            trajectory["action"],
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"]),
            tf.zeros_like(trajectory["action"][:, :1]),
        ),
        axis=-1,
    )
    return trajectory


def fmb_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, ie has leading batch dimension
    trajectory["observation"]["proprio"] = tf.concat(
        (
            trajectory["observation"]["eef_pose"],
            trajectory["observation"]["state_gripper_pose"][..., None],
        ),
        axis=-1,
    )
    return trajectory


def dobbe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # every input feature is batched, ie has leading batch dimension
    trajectory["observation"]["proprio"] = trajectory["observation"]["state"]
    return trajectory


def robo_set_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    # gripper action is in -1...1 --> clip to 0...1, flip
    gripper_action = trajectory["action"][:, -1:]
    gripper_action = invert_gripper_actions(tf.clip_by_value(gripper_action, 0, 1))

    trajectory["action"] = tf.concat(
        (
            trajectory["action"][:, :7],
            gripper_action,
        ),
        axis=-1,
    )
    return trajectory


def identity_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
    return trajectory


# === Registry ===
OPENX_STANDARDIZATION_TRANSFORMS = {
    "bridge_openx": bridge_openx_dataset_transform,
    "bridge_orig": bridge_orig_dataset_transform,
    "bridge_dataset": bridge_orig_dataset_transform,
    "ppgm": ppgm_dataset_transform,
    "ppgm_static": ppgm_dataset_transform,
    "ppgm_wrist": ppgm_dataset_transform,
    "fractal20220817_data": rt1_dataset_transform,
    "kuka": kuka_dataset_transform,
    "taco_play": taco_play_dataset_transform,
    "jaco_play": jaco_play_dataset_transform,
    "berkeley_cable_routing": berkeley_cable_routing_dataset_transform,
    "roboturk": roboturk_dataset_transform,
    "nyu_door_opening_surprising_effectiveness": nyu_door_opening_dataset_transform,
    "viola": viola_dataset_transform,
    "berkeley_autolab_ur5": berkeley_autolab_ur5_dataset_transform,
    "toto": toto_dataset_transform,
    "language_table": language_table_dataset_transform,
    "columbia_cairlab_pusht_real": pusht_dataset_transform,
    "stanford_kuka_multimodal_dataset_converted_externally_to_rlds": stanford_kuka_multimodal_dataset_transform,
    "nyu_rot_dataset_converted_externally_to_rlds": nyu_rot_dataset_transform,
    "stanford_hydra_dataset_converted_externally_to_rlds": stanford_hydra_dataset_transform,
    "austin_buds_dataset_converted_externally_to_rlds": austin_buds_dataset_transform,
    "nyu_franka_play_dataset_converted_externally_to_rlds": nyu_franka_play_dataset_transform,
    "maniskill_dataset_converted_externally_to_rlds": maniskill_dataset_transform,
    "furniture_bench_dataset_converted_externally_to_rlds": furniture_bench_dataset_transform,
    "cmu_franka_exploration_dataset_converted_externally_to_rlds": cmu_franka_exploration_dataset_transform,
    "ucsd_kitchen_dataset_converted_externally_to_rlds": ucsd_kitchen_dataset_transform,
    "ucsd_pick_and_place_dataset_converted_externally_to_rlds": ucsd_pick_place_dataset_transform,
    "austin_sailor_dataset_converted_externally_to_rlds": austin_sailor_dataset_transform,
    "austin_sirius_dataset_converted_externally_to_rlds": austin_sirius_dataset_transform,
    "bc_z": bc_z_dataset_transform,
    "utokyo_pr2_opening_fridge_converted_externally_to_rlds": tokyo_pr2_opening_fridge_dataset_transform,
    "utokyo_pr2_tabletop_manipulation_converted_externally_to_rlds": tokyo_pr2_tabletop_manipulation_dataset_transform,
    "utokyo_xarm_pick_and_place_converted_externally_to_rlds": identity_transform,
    "utokyo_xarm_bimanual_converted_externally_to_rlds": utokyo_xarm_bimanual_dataset_transform,
    "robo_net": robo_net_dataset_transform,
    "berkeley_mvp_converted_externally_to_rlds": berkeley_mvp_dataset_transform,
    "berkeley_rpt_converted_externally_to_rlds": berkeley_rpt_dataset_transform,
    "kaist_nonprehensile_converted_externally_to_rlds": kaist_nonprehensible_dataset_transform,
    "stanford_mask_vit_converted_externally_to_rlds": stanford_mask_vit_dataset_transform,
    "tokyo_u_lsmo_converted_externally_to_rlds": tokyo_lsmo_dataset_transform,
    "dlr_sara_pour_converted_externally_to_rlds": identity_transform,
    "dlr_sara_grid_clamp_converted_externally_to_rlds": dlr_sara_grid_clamp_dataset_transform,
    "dlr_edan_shared_control_converted_externally_to_rlds": dlr_edan_shared_control_dataset_transform,
    "asu_table_top_converted_externally_to_rlds": asu_table_top_dataset_transform,
    "stanford_robocook_converted_externally_to_rlds": robocook_dataset_transform,
    "imperialcollege_sawyer_wrist_cam": imperial_wristcam_dataset_transform,
    "iamlab_cmu_pickup_insert_converted_externally_to_rlds": iamlab_pick_insert_dataset_transform,
    "uiuc_d3field": uiuc_d3field_dataset_transform,
    "utaustin_mutex": utaustin_mutex_dataset_transform,
    "berkeley_fanuc_manipulation": berkeley_fanuc_dataset_transform,
    "cmu_playing_with_food": cmu_playing_with_food_dataset_transform,
    "cmu_play_fusion": playfusion_dataset_transform,
    "cmu_stretch": cmu_stretch_dataset_transform,
    "berkeley_gnm_recon": gnm_dataset_transform,
    "berkeley_gnm_cory_hall": gnm_dataset_transform,
    "berkeley_gnm_sac_son": gnm_dataset_transform,
    "droid": droid_baseact_transform_fn(),
    "droid_100": droid_baseact_transform_fn(),  # first 100 episodes of droid
    "fmb": fmb_transform,
    "dobbe": dobbe_dataset_transform,
    "robo_set": robo_set_dataset_transform,
    "usc_cloth_sim_converted_externally_to_rlds": identity_transform,
    "plex_robosuite": identity_transform,
    "conq_hose_manipulation": identity_transform,
    "io_ai_tech": identity_transform,
    "spoc": identity_transform,
}
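This registry is what the conversion script further down in this diff uses to standardize each OpenX dataset before exporting episodes. A rough sketch of that lookup, where `raw_dataset` is assumed to be an already-loaded `tf.data.Dataset` of RLDS trajectories:

```python
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS

openx_dataset_name = "bridge_orig"
transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
standardized = raw_dataset.map(transform_fn)  # each trajectory now exposes a flat "action" tensor
```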
@@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
For all datasets in the RLDS format.
For https://github.com/google-deepmind/open_x_embodiment (OPENX) datasets.

NOTE: You need to install tensorflow and tensorflow_datasets before running this script.

Example:
    python lerobot/scripts/push_dataset_to_hub.py \
    --raw-dir /path/to/data/bridge_dataset/1.0.0/ \
    --repo-id your_hub/sampled_bridge_data_v2 \
    --raw-format rlds \
    --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
    --repo-id youliangtan/sampled_bridge_data_v2 \
    --raw-format openx_rlds.bridge_orig \
    --episodes 3 4 5 8 9

Exact dataset fps defined in openx/config.py, obtained from:
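Since this view interleaves the removed and added lines of the usage example, the invocation as it reads after this change is (assuming the generic `rlds` paths above are the lines being replaced):

```bash
python lerobot/scripts/push_dataset_to_hub.py \
    --raw-dir /hdd/tensorflow_datasets/bridge_dataset/1.0.0/ \
    --repo-id youliangtan/sampled_bridge_data_v2 \
    --raw-format openx_rlds.bridge_orig \
    --episodes 3 4 5 8 9
```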
@@ -38,10 +35,12 @@ import tensorflow as tf
import tensorflow_datasets as tfds
import torch
import tqdm
import yaml
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
from lerobot.common.datasets.push_dataset_to_hub.utils import (
    calculate_episode_data_index,
    concatenate_episodes,
@@ -53,6 +52,11 @@ from lerobot.common.datasets.utils import (
)
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames

with open("lerobot/common/datasets/push_dataset_to_hub/openx/configs.yaml") as f:
    _openx_list = yaml.safe_load(f)

OPENX_DATASET_CONFIGS = _openx_list["OPENX_DATASET_CONFIGS"]

np.set_printoptions(precision=2)
@@ -104,6 +108,7 @@ def load_from_raw(
    video: bool,
    episodes: list[int] | None = None,
    encoding: dict | None = None,
    openx_dataset_name: str | None = None,
):
    """
    Args:
@@ -131,17 +136,16 @@ def load_from_raw(
    # we will apply the standardization transform if the dataset_name is provided
    # if the dataset name is not provided and the goal is to convert any rlds formatted dataset
    # search for 'image' keys in the observations
    image_keys = []
    state_keys = []
    observation_info = dataset_info.features["steps"]["observation"]
    for key in observation_info:
        # check whether the key is for an image or a vector observation
        if len(observation_info[key].shape) == 3:
            # only adding uint8 images discards depth images
            if observation_info[key].dtype == tf.uint8:
                image_keys.append(key)
        else:
            state_keys.append(key)
    if openx_dataset_name is not None:
        print(" - applying standardization transform for dataset: ", openx_dataset_name)
        assert openx_dataset_name in OPENX_STANDARDIZATION_TRANSFORMS
        transform_fn = OPENX_STANDARDIZATION_TRANSFORMS[openx_dataset_name]
        dataset = dataset.map(transform_fn)

        image_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["image_obs_keys"]
    else:
        obs_keys = dataset_info.features["steps"]["observation"].keys()
        image_keys = [key for key in obs_keys if "image" in key]

    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
@@ -189,31 +193,50 @@ def load_from_raw(
|
||||
|
||||
num_frames = episode["action"].shape[0]
|
||||
|
||||
ep_dict = {}
|
||||
for key in state_keys:
|
||||
ep_dict[f"observation.{key}"] = tf_to_torch(episode["observation"][key])
|
||||
###########################################################
|
||||
# Handle the episodic data
|
||||
|
||||
ep_dict["action"] = tf_to_torch(episode["action"])
|
||||
ep_dict["next.reward"] = tf_to_torch(episode["reward"]).float()
|
||||
ep_dict["next.done"] = tf_to_torch(episode["is_last"])
|
||||
ep_dict["is_terminal"] = tf_to_torch(episode["is_terminal"])
|
||||
ep_dict["is_first"] = tf_to_torch(episode["is_first"])
|
||||
ep_dict["discount"] = tf_to_torch(episode["discount"])
|
||||
# last step of demonstration is considered done
|
||||
done = torch.zeros(num_frames, dtype=torch.bool)
|
||||
done[-1] = True
|
||||
ep_dict = {}
|
||||
langs = [] # TODO: might be located in "observation"
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
|
||||
# We will create the state observation tensor by stacking the state
|
||||
# obs keys defined in the openx/configs.py
|
||||
if openx_dataset_name is not None:
|
||||
state_obs_keys = OPENX_DATASET_CONFIGS[openx_dataset_name]["state_obs_keys"]
|
||||
# stack the state observations; if a key is missing, pad it with zeros
|
||||
states = []
|
||||
for key in state_obs_keys:
|
||||
if key in episode["observation"]:
|
||||
states.append(tf_to_torch(episode["observation"][key]))
|
||||
else:
|
||||
states.append(torch.zeros(num_frames, 1)) # pad with zeros
|
||||
states = torch.cat(states, dim=1)
|
||||
# assert states.shape == (num_frames, 8), f"states shape: {states.shape}"
|
||||
else:
|
||||
states = tf_to_torch(episode["observation"]["state"])
|
||||
|
||||
actions = tf_to_torch(episode["action"])
|
||||
rewards = tf_to_torch(episode["reward"]).float()
|
||||
|
||||
# If lang_key is present, convert the entire tensor at once
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
|
||||
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
|
||||
image_array_dict = {key: [] for key in image_keys}
|
||||
langs = [str(x) for x in episode[lang_key]]
|
||||
|
||||
for im_key in image_keys:
|
||||
imgs = episode["observation"][im_key]
|
||||
image_array_dict[im_key] = [tf_img_convert(img) for img in imgs]
|
||||
|
||||
# simple assertions
|
||||
for item in [states, actions, rewards, done]:
|
||||
assert len(item) == num_frames
|
||||
|
||||
###########################################################
|
||||
|
||||
# loop through all cameras
|
||||
for im_key in image_keys:
|
||||
img_key = f"observation.images.{im_key}"
|
||||
@@ -239,6 +262,17 @@ def load_from_raw(
|
||||
else:
|
||||
ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
|
||||
|
||||
if lang_key is not None:
|
||||
ep_dict["language_instruction"] = langs
|
||||
|
||||
ep_dict["observation.state"] = states
|
||||
ep_dict["action"] = actions
|
||||
ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
|
||||
ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
|
||||
ep_dict["frame_index"] = torch.arange(0, num_frames, 1)
|
||||
ep_dict["next.reward"] = rewards
|
||||
ep_dict["next.done"] = done
|
||||
|
||||
path_ep_dict = tmp_ep_dicts_dir.joinpath(
|
||||
"ep_dict_" + "0" * (10 - len(str(ep_idx))) + str(ep_idx) + ".pt"
|
||||
)
|
||||
@@ -256,28 +290,30 @@ def load_from_raw(
|
||||
def to_hf_dataset(data_dict, video) -> Dataset:
|
||||
features = {}
|
||||
|
||||
for key in data_dict:
|
||||
# check if vector state obs
|
||||
if key.startswith("observation.") and "observation.images." not in key:
|
||||
features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None))
|
||||
# check if image obs
|
||||
elif "observation.images." in key:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
keys = [key for key in data_dict if "observation.images." in key]
|
||||
for key in keys:
|
||||
if video:
|
||||
features[key] = VideoFrame()
|
||||
else:
|
||||
features[key] = Image()
|
||||
|
||||
features["observation.state"] = Sequence(
|
||||
length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.velocity" in data_dict:
|
||||
features["observation.velocity"] = Sequence(
|
||||
length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "observation.effort" in data_dict:
|
||||
features["observation.effort"] = Sequence(
|
||||
length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
if "language_instruction" in data_dict:
|
||||
features["language_instruction"] = Value(dtype="string", id=None)
|
||||
|
||||
features["action"] = Sequence(
|
||||
length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
|
||||
)
|
||||
|
||||
features["is_terminal"] = Value(dtype="bool", id=None)
|
||||
features["is_first"] = Value(dtype="bool", id=None)
|
||||
features["discount"] = Value(dtype="float32", id=None)
|
||||
|
||||
features["episode_index"] = Value(dtype="int64", id=None)
|
||||
features["frame_index"] = Value(dtype="int64", id=None)
|
||||
features["timestamp"] = Value(dtype="float32", id=None)
|
||||
@@ -297,8 +333,19 @@ def from_raw_to_lerobot_format(
|
||||
video: bool = True,
|
||||
episodes: list[int] | None = None,
|
||||
encoding: dict | None = None,
|
||||
openx_dataset_name: str | None = None,
|
||||
):
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding)
|
||||
"""This is a test impl for rlds conversion"""
|
||||
if openx_dataset_name is None:
|
||||
# set a default rlds frame rate if the dataset is not from openx
|
||||
fps = 30
|
||||
elif "fps" not in OPENX_DATASET_CONFIGS[openx_dataset_name]:
|
||||
raise ValueError(
|
||||
"fps for this dataset is not specified in openx/configs.py yet," "means it is not yet tested"
|
||||
)
|
||||
fps = OPENX_DATASET_CONFIGS[openx_dataset_name]["fps"]
|
||||
|
||||
data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, encoding, openx_dataset_name)
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
info = {
|
||||
|
||||
@@ -13,15 +13,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -479,7 +476,6 @@ def create_lerobot_dataset_card(
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
|
||||
if tags:
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
@@ -497,66 +493,8 @@ def create_lerobot_dataset_card(
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
template_str=card_template,
|
||||
template_path="./lerobot/common/datasets/card_template.md",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
>>> ns = IterableNamespace(data)
|
||||
>>> ns.name
|
||||
'Alice'
|
||||
>>> ns.details.age
|
||||
25
|
||||
>>> list(ns.keys())
|
||||
['name', 'details']
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, dict):
|
||||
setattr(self, key, IterableNamespace(value))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
@@ -159,11 +159,11 @@ DATASETS = {
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup": {
|
||||
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup_left": {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
|
||||
@@ -131,7 +131,8 @@ def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
# vcodec: str = "libsvtav1",
|
||||
vcodec: str = "libx264",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
|
||||
@@ -14,13 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from collections import deque
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from mani_skill.utils import common
|
||||
|
||||
|
||||
def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
@@ -34,10 +30,6 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
if cfg.env.name == "real_world":
|
||||
return
|
||||
|
||||
if "maniskill" in cfg.env.name:
|
||||
env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
|
||||
return env
|
||||
|
||||
package_name = f"gym_{cfg.env.name}"
|
||||
|
||||
try:
|
||||
@@ -64,86 +56,3 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
|
||||
)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
|
||||
"""Make ManiSkill3 gym environment"""
|
||||
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
|
||||
|
||||
env = gym.make(
|
||||
cfg.env.task,
|
||||
obs_mode=cfg.env.obs,
|
||||
control_mode=cfg.env.control_mode,
|
||||
render_mode=cfg.env.render_mode,
|
||||
sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
|
||||
num_envs=n_envs,
|
||||
)
|
||||
# cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
|
||||
env = ManiSkillVectorEnv(env, ignore_terminations=True)
|
||||
# state should have the size of 25
|
||||
# env = ConvertToLeRobotEnv(env, n_envs)
|
||||
# env = PixelWrapper(cfg, env, n_envs)
|
||||
env._max_episode_steps = env.max_episode_steps = 50 # gym_utils.find_max_episode_steps_value(env)
|
||||
env.unwrapped.metadata["render_fps"] = 20
|
||||
|
||||
return env
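As a rough, illustrative sketch of the configuration fields read by `make_maniskill_env` above (field names follow the code; the concrete values here are placeholders, not the project defaults):

```python
from omegaconf import OmegaConf

# Hypothetical config carrying only the fields referenced in make_maniskill_env / PixelWrapper.
cfg = OmegaConf.create(
    {
        "env": {
            "name": "maniskill_pushcube",       # any name containing "maniskill" routes to this factory
            "task": "PushCube-v1",              # gym id passed to gym.make
            "obs": "rgbd",                      # obs_mode
            "control_mode": "pd_ee_delta_pose",
            "render_mode": "rgb_array",
            "image_size": 64,                   # sensor width/height
            "render_size": 64,                  # used by PixelWrapper's observation space
        },
        "eval": {"batch_size": 2},              # default number of vectorized envs
    }
)
env = make_maniskill_env(cfg, n_envs=cfg.eval.batch_size)
```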
|
||||
|
||||
|
||||
class PixelWrapper(gym.Wrapper):
|
||||
"""
|
||||
Wrapper for pixel observations. Works with Maniskill vectorized environments
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, env, num_envs, num_frames=3):
|
||||
super().__init__(env)
|
||||
self.cfg = cfg
|
||||
self.env = env
|
||||
self.observation_space = gym.spaces.Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
self._frames = deque([], maxlen=num_frames)
|
||||
self._render_size = cfg.env.render_size
|
||||
|
||||
def _get_obs(self, obs):
|
||||
frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
|
||||
self._frames.append(frame)
|
||||
return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
|
||||
|
||||
def reset(self, seed):
|
||||
obs, info = self.env.reset() # (seed=seed)
|
||||
for _ in range(self._frames.maxlen):
|
||||
obs_frames = self._get_obs(obs)
|
||||
return obs_frames, info
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
|
||||
class ConvertToLeRobotEnv(gym.Wrapper):
|
||||
def __init__(self, env, num_envs):
|
||||
super().__init__(env)
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options={})
|
||||
return self._get_obs(obs), info
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
return self._get_obs(obs), reward, terminated, truncated, info
|
||||
def _get_obs(self, observation):
|
||||
sensor_data = observation.pop("sensor_data")
|
||||
del observation["sensor_param"]
|
||||
images = []
|
||||
for cam_data in sensor_data.values():
|
||||
images.append(cam_data["rgb"])
|
||||
|
||||
images = torch.concat(images, axis=-1)
|
||||
# flatten the rest of the data which should just be state data
|
||||
observation = common.flatten_state_dict(
|
||||
observation, use_torch=True, device=self.base_env.device
|
||||
)
|
||||
ret = dict()
|
||||
ret["state"] = observation
|
||||
ret["pixels"] = images
|
||||
return ret
|
||||
@@ -28,9 +28,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: merge all tensors from the "agent" and "extra" keys,
|
||||
# drop the "sensor_param" key from the observation,
|
||||
# and keep only the rgb images from "sensor_data"
|
||||
if "pixels" in observations:
|
||||
if isinstance(observations["pixels"], dict):
|
||||
imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
|
||||
@@ -53,8 +50,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
img /= 255
|
||||
|
||||
return_observations[imgkey] = img
|
||||
# obs state agent qpos and qvel
|
||||
# image
|
||||
|
||||
if "environment_state" in observations:
|
||||
return_observations["observation.environment_state"] = torch.from_numpy(
|
||||
@@ -65,38 +60,3 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
# requirement for "agent_pos"
|
||||
return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
|
||||
return return_observations
|
||||
|
||||
|
||||
def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
Args:
|
||||
observation: Dictionary of observation batches from a Gym vector environment.
|
||||
Returns:
|
||||
Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
|
||||
"""
|
||||
# map to expected inputs for the policy
|
||||
return_observations = {}
|
||||
# TODO: merge all tensors from the "agent" and "extra" keys,
|
||||
# drop the "sensor_param" key from the observation,
|
||||
# and keep only the rgb images from "sensor_data"
|
||||
q_pos = observations["agent"]["qpos"]
|
||||
q_vel = observations["agent"]["qvel"]
|
||||
tcp_pos = observations["extra"]["tcp_pose"]
|
||||
img = observations["sensor_data"]["base_camera"]["rgb"]
|
||||
|
||||
_, h, w, c = img.shape
|
||||
assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
|
||||
|
||||
# sanity check that images are uint8
|
||||
assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
|
||||
|
||||
# convert to channel first of type float32 in range [0,1]
|
||||
img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
|
||||
img = img.type(torch.float32)
|
||||
img /= 255
|
||||
|
||||
state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
|
||||
|
||||
return_observations["observation.image"] = img
|
||||
return_observations["observation.state"] = state
|
||||
return return_observations
|
||||
|
||||
@@ -25,7 +25,6 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -108,6 +107,8 @@ class Logger:
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
@@ -231,7 +232,7 @@ class Logger:
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
if not isinstance(v, (int, float, str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
|
||||
@@ -140,25 +140,25 @@ class ACTPolicy(
|
||||
batch = self.normalize_targets(batch)
|
||||
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
|
||||
|
||||
l1_loss = (
|
||||
F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
).mean()
|
||||
bsize = actions_hat.shape[0]
|
||||
l1_loss = F.l1_loss(batch["action"], actions_hat, reduction="none")
|
||||
l1_loss = l1_loss * ~batch["action_is_pad"].unsqueeze(-1)
|
||||
l1_loss = l1_loss.view(bsize, -1).mean(dim=1)
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
out_dict = {}
|
||||
out_dict["l1_loss"] = l1_loss
|
||||
if self.config.use_vae:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
# (See App. B of https://arxiv.org/abs/1312.6114 for more details).
|
||||
mean_kld = (
|
||||
(-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean()
|
||||
)
|
||||
loss_dict["kld_loss"] = mean_kld.item()
|
||||
loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight
|
||||
kld_loss = (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1)
|
||||
out_dict["loss"] = l1_loss + kld_loss * self.config.kl_weight
|
||||
else:
|
||||
loss_dict["loss"] = l1_loss
|
||||
out_dict["loss"] = l1_loss
|
||||
|
||||
return loss_dict
|
||||
out_dict["action"] = self.unnormalize_outputs({"action": actions_hat})["action"]
|
||||
return out_dict
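For reference, the KL term computed just above is the standard closed form for a diagonal Gaussian against a standard normal (a general VAE identity, not specific to this PR); the refactor only changes the reduction from a scalar `.mean()` to a per-sample value:

$$
D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \sigma^2)\,\|\,\mathcal{N}(0, I)\big)
= -\tfrac{1}{2}\sum_{d}\big(1 + \log\sigma_d^2 - \mu_d^2 - \sigma_d^2\big)
$$

with `mu_hat` playing the role of $\mu$ and `log_sigma_x2_hat` the role of $\log\sigma^2$, summed over the latent dimension.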
|
||||
|
||||
|
||||
class ACTTemporalEnsembler:
|
||||
|
||||
@@ -66,11 +66,6 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy
|
||||
|
||||
return VQBeTPolicy, VQBeTConfig
|
||||
elif name == "sac":
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
|
||||
return SACPolicy, SACConfig
|
||||
else:
|
||||
raise NotImplementedError(f"Policy with name {name} is not implemented.")
|
||||
|
||||
@@ -90,10 +85,10 @@ def make_policy(
|
||||
be provided when initializing a new policy, and must not be provided when loading a pretrained
|
||||
policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
|
||||
"""
|
||||
# if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
# raise ValueError(
|
||||
# "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
# )
|
||||
if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
|
||||
raise ValueError(
|
||||
"Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
|
||||
)
|
||||
|
||||
policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)
|
||||
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cpu"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
num_cameras: int = 2
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -1,151 +0,0 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"ClassifierOutput(logits={self.logits}, "
|
||||
f"probabilities={self.probabilities}, "
|
||||
f"hidden_states={self.hidden_states})"
|
||||
)
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
self.encoder = self.encoder.to(self.config.device)
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
self.classifier_head = self.classifier_head.to(self.config.device)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
processed = self.processor(
|
||||
images=x, # LeRobotDataset already provides proper tensor format
|
||||
return_tensors="pt",
|
||||
)
|
||||
processed = processed["pixel_values"].to(x.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, xs: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
|
||||
logits = self.classifier_head(encoder_outputs)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)
|
||||
|
||||
def predict_reward(self, x):
|
||||
if self.config.num_classes == 2:
|
||||
return (self.forward(x).probabilities > 0.5).float()
|
||||
else:
|
||||
return torch.argmax(self.forward(x).probabilities, dim=1)
|
||||
@@ -1,23 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
@@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [2],
|
||||
}
|
||||
)
|
||||
input_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": "mean_std",
|
||||
"observation.state": "min_max",
|
||||
"observation.environment_state": "min_max",
|
||||
}
|
||||
)
|
||||
output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
|
||||
output_normalization_params: dict[str, dict[str, list[float]]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": {"min": [-1, -1], "max": [1, 1]},
|
||||
}
|
||||
)
|
||||
camera_number: int = 1
|
||||
# Add type annotations for these fields:
|
||||
image_encoder_hidden_dim: int = 32
|
||||
shared_encoder: bool = False
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
critic_target_update_weight: float = 0.005
|
||||
utd_ratio: int = 1  # To enable utd_ratio, set it to a value greater than 1
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 256
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
)
|
||||
actor_network_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
)
|
||||
policy_kwargs: dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"use_tanh_squash": True,
|
||||
"log_std_min": -5,
|
||||
"log_std_max": 2,
|
||||
}
|
||||
)
|
||||
@@ -1,571 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from typing import Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
name = "sac"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig | None = None,
|
||||
dataset_stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
|
||||
output_normalization_params = {}
|
||||
for outer_key, inner_dict in config.output_normalization_params.items():
|
||||
output_normalization_params[outer_key] = {}
|
||||
for key, value in inner_dict.items():
|
||||
output_normalization_params[outer_key][key] = torch.tensor(value)
|
||||
|
||||
# HACK: This is hacky and should be removed
|
||||
dataset_stats = dataset_stats or output_normalization_params
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
if config.shared_encoder:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor: SACObservationEncoder = encoder_critic
|
||||
else:
|
||||
encoder_critic = SACObservationEncoder(config)
|
||||
encoder_actor = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
target_critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
target_critic_net = Critic(
|
||||
encoder=encoder_critic,
|
||||
network=MLP(
|
||||
input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
|
||||
**config.critic_network_kwargs,
|
||||
),
|
||||
device=device,
|
||||
)
|
||||
target_critic_nets.append(target_critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(
|
||||
critics=critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target = create_critic_ensemble(
|
||||
critics=target_critic_nets, num_critics=config.num_critics, device=device
|
||||
)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
self.actor = Policy(
|
||||
encoder=encoder_actor,
|
||||
network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
device=device,
|
||||
encoder_is_shared=config.shared_encoder,
|
||||
**config.policy_kwargs,
|
||||
)
|
||||
if config.target_entropy is None:
|
||||
config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2 # (-dim(A)/2)
|
||||
# TODO: Handle the case where the temperature parameter is fixed
|
||||
self.log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||
self.temperature = self.log_alpha.exp().item()
|
||||
|
||||
def reset(self):
|
||||
"""Reset the policy"""
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
actions, _, _ = self.actor(batch)
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
return actions
|
||||
|
||||
def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
|
||||
def critic_forward(
|
||||
self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = torch.stack([critic(observations, actions) for critic in critics])
|
||||
return q_values
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
|
||||
for target_param, param in zip(target_critic.parameters(), critic.parameters(), strict=False):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
def compute_loss_critic(self, observations, actions, rewards, next_observations, done) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations, actions=next_action_preds, use_target=True
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_forward(observations, actions, use_target=False)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# Compute the mean loss over the batch for each critic, then sum them to get the final loss
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(1)
|
||||
).sum()
|
||||
return critics_loss
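For reference, the target assembled in `compute_loss_critic` above follows the usual soft Bellman backup (a sketch of the standard SAC update, stated here only to annotate the code being shown):

$$
y = r + (1 - d)\,\gamma\,\Big(\min_i Q_{\bar\theta_i}(s', a') - \alpha \log \pi(a' \mid s')\Big),
\qquad a' \sim \pi(\cdot \mid s')
$$

where the entropy term is included only when `use_backup_entropy` is set, and the loss is the MSE to $y$ averaged over the batch and summed over the critic ensemble.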
|
||||
|
||||
def compute_loss_temperature(self, observations) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
|
||||
return temperature_loss
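The temperature update above matches the standard SAC objective for the entropy coefficient $\alpha = \exp(\log\alpha)$ (again a textbook identity, assuming `config.target_entropy` plays the role of the target entropy $\bar{\mathcal{H}}$):

$$
J(\alpha) = \mathbb{E}_{a \sim \pi}\big[-\alpha\,\big(\log \pi(a \mid s) + \bar{\mathcal{H}}\big)\big]
$$

so $\alpha$ increases when the policy's entropy drops below the target and decreases otherwise.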
|
||||
|
||||
def compute_loss_actor(self, observations) -> Tensor:
|
||||
temperature = self.log_alpha.exp().item()
|
||||
|
||||
actions_pi, log_probs, _ = self.actor(observations)
|
||||
|
||||
q_preds = self.critic_forward(observations, actions_pi, use_target=False)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
hidden_dims: list[int],
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = activate_final
|
||||
layers = []
|
||||
|
||||
# First layer uses input_dim
|
||||
layers.append(nn.Linear(input_dim, hidden_dims[0]))
|
||||
|
||||
# Add activation after first layer
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[0]))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
# Rest of the layers
|
||||
for i in range(1, len(hidden_dims)):
|
||||
layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))
|
||||
|
||||
if i + 1 < len(hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(hidden_dims[i]))
|
||||
layers.append(
|
||||
activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
|
||||
)
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
self.output_layer = nn.Linear(out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: dict[str, torch.Tensor],
|
||||
actions: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Move each tensor in observations to device
|
||||
observations = {k: v.to(self.device) for k, v in observations.items()}
|
||||
actions = actions.to(self.device)
|
||||
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
log_std_min: float = -5,
|
||||
log_std_max: float = 2,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
use_tanh_squash: bool = False,
|
||||
device: str = "cpu",
|
||||
encoder_is_shared: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.log_std_min = log_std_min
|
||||
self.log_std_max = log_std_max
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.use_tanh_squash = use_tanh_squash
|
||||
self.parameters_to_optimize = []
|
||||
|
||||
self.parameters_to_optimize += list(self.network.parameters())
|
||||
|
||||
if self.encoder is not None and not encoder_is_shared:
|
||||
self.parameters_to_optimize += list(self.encoder.parameters())
|
||||
# Find the last Linear layer's output dimension
|
||||
for layer in reversed(network.net):
|
||||
if isinstance(layer, nn.Linear):
|
||||
out_features = layer.out_features
|
||||
break
|
||||
# Mean layer
|
||||
self.mean_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
self.parameters_to_optimize += list(self.mean_layer.parameters())
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
self.std_layer = nn.Linear(out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
self.parameters_to_optimize += list(self.std_layer.parameters())
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
# Encode observations if encoder exists
|
||||
obs_enc = observations if self.encoder is None else self.encoder(observations)
|
||||
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
log_std = self.std_layer(outputs)
|
||||
assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"
|
||||
|
||||
if self.use_tanh_squash:
|
||||
log_std = torch.tanh(log_std)
|
||||
log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
|
||||
else:
|
||||
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
|
||||
else:
|
||||
log_std = self.fixed_std.expand_as(means)
|
||||
|
||||
# uses tanh activation function to squash the action to be in the range of [-1, 1]
|
||||
normal = torch.distributions.Normal(means, torch.exp(log_std))
|
||||
x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1))
|
||||
log_probs = normal.log_prob(x_t) # Base log probability before Tanh
|
||||
|
||||
if self.use_tanh_squash:
|
||||
actions = torch.tanh(x_t)
|
||||
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) # Adjust log-probs for Tanh
|
||||
else:
|
||||
actions = x_t # No Tanh; raw Gaussian sample
|
||||
|
||||
log_probs = log_probs.sum(-1) # Sum over action dimensions
|
||||
means = torch.tanh(means) if self.use_tanh_squash else means
|
||||
return actions, log_probs, means
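The log-probability adjustment above is the usual change-of-variables correction for the tanh squash (a standard identity, not specific to this code):

$$
\log \pi(a \mid s) = \log \mu(u \mid s) - \sum_{j} \log\big(1 - \tanh^2(u_j) + \epsilon\big),
\qquad a = \tanh(u)
$$

where $u$ is the pre-squash Gaussian sample and $\epsilon$ (here `1e-6`) guards against `log(0)`.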
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.inference_mode():
|
||||
return self.encoder(observations)
|
||||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=config.input_shapes["observation.image"][0],
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=5,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=config.image_encoder_hidden_dim,
|
||||
out_channels=config.image_encoder_hidden_dim,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.camera_number = config.camera_number
|
||||
self.aggregation_size: int = 0
|
||||
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
sequential=nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(
|
||||
in_features=np.prod(out_shape) * self.camera_number, out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim * self.camera_number
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
self.aggregation_size += config.latent_dim
|
||||
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
in_features=config.input_shapes["observation.environment_state"][0],
|
||||
out_features=config.latent_dim,
|
||||
),
|
||||
nn.LayerNorm(normalized_shape=config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
self.aggregation_size += config.latent_dim
|
||||
self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# Concatenate all images along the channel dimension.
|
||||
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
|
||||
for image_key in image_keys:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
# TODO(ke-wang): currently we average over all features; concatenating all features may be a better way
|
||||
# return torch.stack(feat, dim=0).mean(0)
|
||||
features = torch.cat(tensors=feat, dim=-1)
|
||||
features = self.aggregation_layer(features)
|
||||
|
||||
return features
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
"""Returns the dimension of the encoder output"""
|
||||
return self.config.latent_dim
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cpu") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}"
|
||||
return nn.ModuleList(critics).to(device)
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
90
lerobot/common/robot_devices/cameras/reachy2.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Wrapper for the Reachy 2 cameras exposed by the reachy2_sdk.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, replace
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from reachy2_sdk.media.camera import CameraView
|
||||
from reachy2_sdk.media.camera_manager import CameraManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReachyCameraConfig:
|
||||
fps: int | None = None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.color_mode not in ["rgb", "bgr"]:
|
||||
raise ValueError(
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
|
||||
class ReachyCamera:
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
name: str,
|
||||
image_type: str,
|
||||
config: ReachyCameraConfig | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if config is None:
|
||||
config = ReachyCameraConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
config = replace(config, **kwargs)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.fps = config.fps
|
||||
self.image_type = image_type
|
||||
self.name = name
|
||||
self.config = config
|
||||
self.cam_manager = None
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
def connect(self):
|
||||
if not self.is_connected:
|
||||
self.cam_manager = CameraManager(host=self.host, port=self.port)
|
||||
self.cam_manager.initialize_cameras() # FIXME: maybe we should not re-initialize
|
||||
self.is_connected = True
|
||||
|
||||
def read(self) -> np.ndarray:
|
||||
if not self.is_connected:
|
||||
self.connect()
|
||||
|
||||
frame = None
|
||||
|
||||
if self.name == "teleop" and hasattr(self.cam_manager, "teleop"):
|
||||
if self.image_type == "left":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.LEFT)
|
||||
elif self.image_type == "right":
|
||||
frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT)
|
||||
elif self.name == "depth" and hasattr(self.cam_manager, "depth"):
|
||||
if self.image_type == "depth":
|
||||
frame = self.cam_manager.depth.get_depth_frame()
|
||||
elif self.image_type == "rgb":
|
||||
frame = self.cam_manager.depth.get_frame()
|
||||
|
||||
if frame is None:
|
||||
return None
|
||||
|
||||
if self.config.color_mode == "rgb":
|
||||
img, timestamp = frame
|
||||
frame = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB), timestamp)
|
||||
|
||||
return frame
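A minimal usage sketch of `ReachyCamera` (the host and port below are placeholders for an actual Reachy2 setup):
```python
camera = ReachyCamera(host="192.168.0.100", port=50065, name="teleop", image_type="left", fps=30)
camera.connect()
frame = camera.read()  # None if no frame is available yet
if frame is not None:
    img, timestamp = frame  # img is an RGB numpy array by default (color_mode="rgb")
    print(img.shape, timestamp)
```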
|
||||
@@ -11,7 +11,6 @@ from copy import copy
|
||||
from functools import cache
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
@@ -47,7 +46,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
||||
log_dt("dt", dt_s)
|
||||
|
||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||
if not robot.robot_type.startswith("stretch"):
|
||||
if not robot.robot_type.lower().startswith(("stretch", "reachy")):
|
||||
for name in robot.leader_arms:
|
||||
key = f"read_leader_{name}_pos_dt_s"
|
||||
if key in robot.logs:
|
||||
@@ -121,22 +120,14 @@ def predict_action(observation, policy, device, use_amp):
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
@@ -161,13 +152,6 @@ def init_keyboard_listener(assign_rewards=False):
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -200,7 +184,7 @@ def init_policy(pretrained_policy_name_or_path, policy_overrides):
|
||||
def warmup_record(
|
||||
robot,
|
||||
events,
|
||||
enable_teleoperation,
|
||||
enable_teloperation,
|
||||
warmup_time_s,
|
||||
display_cameras,
|
||||
fps,
|
||||
@@ -211,7 +195,7 @@ def warmup_record(
|
||||
display_cameras=display_cameras,
|
||||
events=events,
|
||||
fps=fps,
|
||||
teleoperate=enable_teleoperation,
|
||||
teleoperate=enable_teloperation,
|
||||
)
|
||||
|
||||
|
||||
@@ -288,8 +272,6 @@ def control_loop(
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
@@ -319,8 +301,6 @@ def reset_environment(robot, events, reset_time_s):
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
@@ -333,14 +313,6 @@ def reset_environment(robot, events, reset_time_s):
|
||||
break
|
||||
|
||||
|
||||
def reset_follower_position(robot: Robot, target_position):
|
||||
current_position = robot.follower_arms["main"].read("Present_Position")
|
||||
trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # NOTE: 30 steps is an arbitrary choice
|
||||
for pose in trajectory:
|
||||
robot.send_action(pose)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
|
||||
robot.disconnect()
|
||||
|
||||
@@ -371,21 +343,17 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
) -> None:
|
||||
features_from_robot = get_features_from_robot(robot, use_videos)
|
||||
if extra_features is not None:
|
||||
features_from_robot.update(extra_features)
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, features_from_robot),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"])
|
||||
diff = DeepDiff(dataset_value, present_value)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
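For context, a small illustration (with made-up feature dicts) of how `DeepDiff` reports such a mismatch:
```python
from deepdiff import DeepDiff

dataset_features = {"observation.state": {"dtype": "float32", "shape": (6,)}}
robot_features = {"observation.state": {"dtype": "float32", "shape": (7,)}}

diff = DeepDiff(dataset_features, robot_features)
if diff:
    # e.g. {'values_changed': {"root['observation.state']['shape'][0]": {'new_value': 7, 'old_value': 6}}}
    print(diff)
```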
|
||||
|
||||
|
||||
lerobot/common/robot_devices/robots/reachy2.py (new file, 317 lines)
@@ -0,0 +1,317 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The Pollen Robotics team and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
|
||||
|
||||
REACHY_MOTORS = [
|
||||
"neck_yaw.pos",
|
||||
"neck_pitch.pos",
|
||||
"neck_roll.pos",
|
||||
"r_shoulder_pitch.pos",
|
||||
"r_shoulder_roll.pos",
|
||||
"r_elbow_yaw.pos",
|
||||
"r_elbow_pitch.pos",
|
||||
"r_wrist_roll.pos",
|
||||
"r_wrist_pitch.pos",
|
||||
"r_wrist_yaw.pos",
|
||||
"r_gripper.pos",
|
||||
"l_shoulder_pitch.pos",
|
||||
"l_shoulder_roll.pos",
|
||||
"l_elbow_yaw.pos",
|
||||
"l_elbow_pitch.pos",
|
||||
"l_wrist_roll.pos",
|
||||
"l_wrist_pitch.pos",
|
||||
"l_wrist_yaw.pos",
|
||||
"l_gripper.pos",
|
||||
"mobile_base.vx",
|
||||
"mobile_base.vy",
|
||||
"mobile_base.vtheta",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReachyRobotConfig:
|
||||
robot_type: str | None = "reachy2"
|
||||
cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
|
||||
ip_address: str | None = "172.17.135.207"
|
||||
# ip_address: str | None = "192.168.0.197"
|
||||
# ip_address: str | None = "localhost"
|
||||
|
||||
|
||||
class ReachyRobot:
|
||||
"""Wrapper of ReachySDK"""
|
||||
|
||||
def __init__(self, config: ReachyRobotConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = ReachyRobotConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
|
||||
self.robot_type = self.config.robot_type
|
||||
self.cameras = self.config.cameras
|
||||
self.has_camera = True
|
||||
self.num_cameras = len(self.cameras)
|
||||
self.is_connected = False
|
||||
self.teleop = None
|
||||
self.logs = {}
|
||||
self.reachy = None
|
||||
self.mobile_base_available = False
|
||||
|
||||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
motors = REACHY_MOTORS
|
||||
# if self.mobile_base_available:
|
||||
# motors += REACHY_MOBILE_BASE
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": motors,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": motors,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
def connect(self) -> None:
|
||||
self.reachy = ReachySDK(host=self.config.ip_address)
|
||||
print("Connecting to Reachy")
|
||||
self.reachy.connect()
|
||||
self.is_connected = self.reachy.is_connected
|
||||
if not self.is_connected:
|
||||
print(
|
||||
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
|
||||
)
|
||||
raise ConnectionError()
|
||||
# self.reachy.turn_on()
|
||||
print(self.cameras)
|
||||
if self.cameras is not None:
|
||||
for name in self.cameras:
|
||||
print(f"Connecting camera: {name}")
|
||||
self.cameras[name].connect()
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.mobile_base_available = self.reachy.mobile_base is not None
|
||||
|
||||
def run_calibration(self):
|
||||
pass
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not record_data:
|
||||
return
|
||||
action = {}
|
||||
action["neck_roll.pos"] = self.reachy.head.neck.roll.goal_position
|
||||
action["neck_pitch.pos"] = self.reachy.head.neck.pitch.goal_position
|
||||
action["neck_yaw.pos"] = self.reachy.head.neck.yaw.goal_position
|
||||
|
||||
action["r_shoulder_pitch.pos"] = self.reachy.r_arm.shoulder.pitch.goal_position
|
||||
action["r_shoulder_roll.pos"] = self.reachy.r_arm.shoulder.roll.goal_position
|
||||
action["r_elbow_yaw.pos"] = self.reachy.r_arm.elbow.yaw.goal_position
|
||||
action["r_elbow_pitch.pos"] = self.reachy.r_arm.elbow.pitch.goal_position
|
||||
action["r_wrist_roll.pos"] = self.reachy.r_arm.wrist.roll.goal_position
|
||||
action["r_wrist_pitch.pos"] = self.reachy.r_arm.wrist.pitch.goal_position
|
||||
action["r_wrist_yaw.pos"] = self.reachy.r_arm.wrist.yaw.goal_position
|
||||
action["r_gripper.pos"] = self.reachy.r_arm.gripper.opening
|
||||
|
||||
action["l_shoulder_pitch.pos"] = self.reachy.l_arm.shoulder.pitch.goal_position
|
||||
action["l_shoulder_roll.pos"] = self.reachy.l_arm.shoulder.roll.goal_position
|
||||
action["l_elbow_yaw.pos"] = self.reachy.l_arm.elbow.yaw.goal_position
|
||||
action["l_elbow_pitch.pos"] = self.reachy.l_arm.elbow.pitch.goal_position
|
||||
action["l_wrist_roll.pos"] = self.reachy.l_arm.wrist.roll.goal_position
|
||||
action["l_wrist_pitch.pos"] = self.reachy.l_arm.wrist.pitch.goal_position
|
||||
action["l_wrist_yaw.pos"] = self.reachy.l_arm.wrist.yaw.goal_position
|
||||
action["l_gripper.pos"] = self.reachy.l_arm.gripper.opening
|
||||
|
||||
if self.mobile_base_available:
|
||||
last_cmd_vel = self.reachy.mobile_base.last_cmd_vel
|
||||
action["mobile_base_x.vel"] = last_cmd_vel["x"]
|
||||
action["mobile_base_y.vel"] = last_cmd_vel["y"]
|
||||
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
|
||||
else:
|
||||
action["mobile_base_x.vel"] = 0
|
||||
action["mobile_base_y.vel"] = 0
|
||||
action["mobile_base_theta.vel"] = 0
|
||||
|
||||
dtype = self.motor_features["action"]["dtype"]
|
||||
action = np.array(list(action.values()), dtype=dtype)
|
||||
# action = torch.as_tensor(list(action.values()))
|
||||
|
||||
obs_dict = self.capture_observation()
|
||||
action_dict = {}
|
||||
action_dict["action"] = action
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def get_state(self) -> dict:
|
||||
# neck roll, pitch, yaw
|
||||
# r_shoulder_pitch, r_shoulder_roll, r_elbow_yaw, r_elbow_pitch, r_wrist_roll, r_wrist_pitch, r_wrist_yaw, r_gripper
|
||||
# l_shoulder_pitch, l_shoulder_roll, l_elbow_yaw, l_elbow_pitch, l_wrist_roll, l_wrist_pitch, l_wrist_yaw, l_gripper
|
||||
# mobile base x, y, theta
|
||||
if self.is_connected:
|
||||
if self.mobile_base_available:
|
||||
odometry = self.reachy.mobile_base.odometry
|
||||
else:
|
||||
odometry = {"x": 0, "y": 0, "theta": 0, "vx": 0, "vy": 0, "vtheta": 0}
|
||||
return {
|
||||
"neck_yaw.pos": self.reachy.head.neck.yaw.present_position,
|
||||
"neck_pitch.pos": self.reachy.head.neck.pitch.present_position,
|
||||
"neck_roll.pos": self.reachy.head.neck.roll.present_position,
|
||||
"r_shoulder_pitch.pos": self.reachy.r_arm.shoulder.pitch.present_position,
|
||||
"r_shoulder_roll.pos": self.reachy.r_arm.shoulder.roll.present_position,
|
||||
"r_elbow_yaw.pos": self.reachy.r_arm.elbow.yaw.present_position,
|
||||
"r_elbow_pitch.pos": self.reachy.r_arm.elbow.pitch.present_position,
|
||||
"r_wrist_roll.pos": self.reachy.r_arm.wrist.roll.present_position,
|
||||
"r_wrist_pitch.pos": self.reachy.r_arm.wrist.pitch.present_position,
|
||||
"r_wrist_yaw.pos": self.reachy.r_arm.wrist.yaw.present_position,
|
||||
"r_gripper.pos": self.reachy.r_arm.gripper.present_position,
|
||||
"l_shoulder_pitch.pos": self.reachy.l_arm.shoulder.pitch.present_position,
|
||||
"l_shoulder_roll.pos": self.reachy.l_arm.shoulder.roll.present_position,
|
||||
"l_elbow_yaw.pos": self.reachy.l_arm.elbow.yaw.present_position,
|
||||
"l_elbow_pitch.pos": self.reachy.l_arm.elbow.pitch.present_position,
|
||||
"l_wrist_roll.pos": self.reachy.l_arm.wrist.roll.present_position,
|
||||
"l_wrist_pitch.pos": self.reachy.l_arm.wrist.pitch.present_position,
|
||||
"l_wrist_yaw.pos": self.reachy.l_arm.wrist.yaw.present_position,
|
||||
"l_gripper.pos": self.reachy.l_arm.gripper.present_position,
|
||||
"mobile_base.vx": odometry["vx"],
|
||||
"mobile_base.vy": odometry["vy"],
|
||||
"mobile_base.vtheta": odometry["vtheta"],
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
if self.is_connected:
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
dtype = self.motor_features["observation.state"]["dtype"]
|
||||
state = np.array(list(state.values()), dtype=dtype)
|
||||
# state = torch.as_tensor(list(state.values()))
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
# before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
||||
# print(f'name: {name} img: {images[name]}')
|
||||
if images[name] is not None:
|
||||
# images[name] = copy(images[name][0]) # seems like I need to copy?
|
||||
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict
|
||||
else:
|
||||
return {}
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
self.reachy.head.neck.yaw.goal_position = float(action[0])
|
||||
self.reachy.head.neck.pitch.goal_position = float(action[1])
|
||||
self.reachy.head.neck.roll.goal_position = float(action[2])
|
||||
|
||||
self.reachy.r_arm.shoulder.pitch.goal_position = float(action[3])
|
||||
self.reachy.r_arm.shoulder.roll.goal_position = float(action[4])
|
||||
self.reachy.r_arm.elbow.yaw.goal_position = float(action[5])
|
||||
self.reachy.r_arm.elbow.pitch.goal_position = float(action[6])
|
||||
self.reachy.r_arm.wrist.roll.goal_position = float(action[7])
|
||||
self.reachy.r_arm.wrist.pitch.goal_position = float(action[8])
|
||||
self.reachy.r_arm.wrist.yaw.goal_position = float(action[9])
|
||||
self.reachy.r_arm.gripper.set_opening(float(action[10]))
|
||||
|
||||
self.reachy.l_arm.shoulder.pitch.goal_position = float(action[11])
|
||||
self.reachy.l_arm.shoulder.roll.goal_position = float(action[12])
|
||||
self.reachy.l_arm.elbow.yaw.goal_position = float(action[13])
|
||||
self.reachy.l_arm.elbow.pitch.goal_position = float(action[14])
|
||||
self.reachy.l_arm.wrist.roll.goal_position = float(action[15])
|
||||
self.reachy.l_arm.wrist.pitch.goal_position = float(action[16])
|
||||
self.reachy.l_arm.wrist.yaw.goal_position = float(action[17])
|
||||
self.reachy.l_arm.gripper.set_opening(float(action[18]))
|
||||
|
||||
s = time.time()
|
||||
self.reachy.send_goal_positions(check_positions=False)
|
||||
print("send_goal_positions", time.time() - s)
|
||||
|
||||
if self.mobile_base_available:
|
||||
self.reachy.mobile_base.set_goal_speed(action[19], action[20], action[21])
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
# TODO: what shape is the action tensor?
|
||||
# 7 dofs per arm (x2)
|
||||
# 1 dof per gripper (x2)
|
||||
# 3 dofs for the neck
|
||||
# 3 dofs for the mobile base (x, y, theta)
|
||||
# 7+7+1+1+3+3 = 22
|
||||
return action
|
||||
|
||||
def print_logs(self) -> None:
|
||||
pass
|
||||
|
||||
def disconnect(self) -> None:
|
||||
print("Disconnecting")
|
||||
self.is_connected = False
|
||||
print("Turn off")
|
||||
# self.reachy.turn_off_smoothly()
|
||||
# self.reachy.turn_off()
|
||||
print("\t turn off done")
|
||||
self.reachy.disconnect()
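A minimal usage sketch of `ReachyRobot` (the IP address and camera host below are placeholders; `send_action` is left commented out since it would move a real robot):
```python
camera = ReachyCamera(host="192.168.0.100", port=50065, name="teleop", image_type="left",
                      fps=30, width=960, height=720)
robot = ReachyRobot(ip_address="192.168.0.100", cameras={"head_left": camera})
robot.connect()

obs = robot.capture_observation()
print(obs["observation.state"].shape)  # (22,), one value per entry of REACHY_MOTORS

# action = torch.zeros(len(REACHY_MOTORS))  # one target per motor/base entry
# robot.send_action(action)                 # would command the real robot

robot.disconnect()
```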
|
||||
lerobot/common/robot_devices/robots/reachy2_manipulator.py (new file, 342 lines)
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The Pollen Robotics team and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from lerobot.common.robot_devices.cameras.reachy2 import ReachyCamera
|
||||
from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus
|
||||
from reachy2_sdk import ReachySDK
|
||||
|
||||
REACHY_MOTORS = [
|
||||
"neck_yaw.pos",
|
||||
"neck_pitch.pos",
|
||||
"neck_roll.pos",
|
||||
"r_shoulder_pitch.pos",
|
||||
"r_shoulder_roll.pos",
|
||||
"r_elbow_yaw.pos",
|
||||
"r_elbow_pitch.pos",
|
||||
"r_wrist_roll.pos",
|
||||
"r_wrist_pitch.pos",
|
||||
"r_wrist_yaw.pos",
|
||||
"r_gripper.pos",
|
||||
"l_shoulder_pitch.pos",
|
||||
"l_shoulder_roll.pos",
|
||||
"l_elbow_yaw.pos",
|
||||
"l_elbow_pitch.pos",
|
||||
"l_wrist_roll.pos",
|
||||
"l_wrist_pitch.pos",
|
||||
"l_wrist_yaw.pos",
|
||||
"l_gripper.pos",
|
||||
"mobile_base.vx",
|
||||
"mobile_base.vy",
|
||||
"mobile_base.vtheta",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReachyManipulatorRobotConfig:
|
||||
robot_type: str | None = "reachy2"
|
||||
cameras: dict[str, ReachyCamera] = field(default_factory=lambda: {})
|
||||
ip_address: str | None = "172.17.135.207"
|
||||
# ip_address: str | None = "192.168.0.197"
|
||||
# ip_address: str | None = "localhost"
|
||||
|
||||
|
||||
class ReachyManipulatorRobot:
|
||||
"""Wrapper of ReachySDK"""
|
||||
|
||||
def __init__(self, config: ReachyManipulatorRobotConfig | None = None, **kwargs):
|
||||
if config is None:
|
||||
config = ReachyManipulatorRobotConfig()
|
||||
|
||||
# Overwrite config arguments using kwargs
|
||||
self.config = replace(config, **kwargs)
|
||||
|
||||
self.robot_type = self.config.robot_type
|
||||
self.cameras = self.config.cameras
|
||||
self.has_camera = True
|
||||
self.num_cameras = len(self.cameras)
|
||||
self.is_connected = False
|
||||
self.teleop = None
|
||||
self.logs = {}
|
||||
self.reachy = None
|
||||
self.mobile_base_available = False
|
||||
|
||||
self.state_keys = None
|
||||
self.action_keys = None
|
||||
|
||||
self.leader_arm = FeetechMotorsBus(config.leader_arm.port, config.leader_arm.motors)
|
||||
self.leader_calib_dir = config.leader_arm.calibration_dir
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
motors = REACHY_MOTORS
|
||||
# if self.mobile_base_available:
|
||||
# motors += REACHY_MOBILE_BASE
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": motors,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": motors,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
def connect(self) -> None:
|
||||
self.reachy = ReachySDK(host=self.config.ip_address)
|
||||
print("Connecting to Reachy")
|
||||
self.reachy.connect()
|
||||
self.is_connected = self.reachy.is_connected
|
||||
if not self.is_connected:
|
||||
print(
|
||||
f"Cannot connect to Reachy at address {self.config.ip_address}. Maybe a connection already exists."
|
||||
)
|
||||
raise ConnectionError()
|
||||
# self.reachy.turn_on()
|
||||
print(self.cameras)
|
||||
if self.cameras is not None:
|
||||
for name in self.cameras:
|
||||
print(f"Connecting camera: {name}")
|
||||
self.cameras[name].connect()
|
||||
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||
|
||||
if not self.is_connected:
|
||||
print("Could not connect to the cameras, check that all cameras are plugged-in.")
|
||||
raise ConnectionError()
|
||||
|
||||
self.mobile_base_available = self.reachy.mobile_base is not None
|
||||
|
||||
print("Connecting to leader arm")
|
||||
self.leader_arm.connect()
|
||||
|
||||
with open(self.leader_calib_dir) as f:
|
||||
self.leader_arm.calibration = json.load(f)
|
||||
self.leader_arm.set_calibration(self.leader_arm.calibration)
|
||||
self.leader_arm.apply_calibration()
|
||||
|
||||
def run_calibration(self):
|
||||
pass
|
||||
|
||||
def teleop_step(
|
||||
self, record_data=False
|
||||
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
if not record_data:
|
||||
return
|
||||
|
||||
|
||||
# Get the leader arm position
|
||||
before_lread_t = time.perf_counter()
leader_pos = torch.from_numpy(self.leader_arm.read("Present_Position"))
self.logs["read_leader_main_pos_dt_s"] = time.perf_counter() - before_lread_t
|
||||
|
||||
|
||||
# TODO: compute leader arm forward kinematics (FK)
|
||||
# TODO: send task-space poses
|
||||
action = {}
|
||||
action["neck_roll.pos"] = self.reachy.head.neck.roll.goal_position
|
||||
action["neck_pitch.pos"] = self.reachy.head.neck.pitch.goal_position
|
||||
action["neck_yaw.pos"] = self.reachy.head.neck.yaw.goal_position
|
||||
|
||||
action["r_shoulder_pitch.pos"] = self.reachy.r_arm.shoulder.pitch.goal_position
|
||||
action["r_shoulder_roll.pos"] = self.reachy.r_arm.shoulder.roll.goal_position
|
||||
action["r_elbow_yaw.pos"] = self.reachy.r_arm.elbow.yaw.goal_position
|
||||
action["r_elbow_pitch.pos"] = self.reachy.r_arm.elbow.pitch.goal_position
|
||||
action["r_wrist_roll.pos"] = self.reachy.r_arm.wrist.roll.goal_position
|
||||
action["r_wrist_pitch.pos"] = self.reachy.r_arm.wrist.pitch.goal_position
|
||||
action["r_wrist_yaw.pos"] = self.reachy.r_arm.wrist.yaw.goal_position
|
||||
action["r_gripper.pos"] = self.reachy.r_arm.gripper.opening
|
||||
|
||||
action["l_shoulder_pitch.pos"] = self.reachy.l_arm.shoulder.pitch.goal_position
|
||||
action["l_shoulder_roll.pos"] = self.reachy.l_arm.shoulder.roll.goal_position
|
||||
action["l_elbow_yaw.pos"] = self.reachy.l_arm.elbow.yaw.goal_position
|
||||
action["l_elbow_pitch.pos"] = self.reachy.l_arm.elbow.pitch.goal_position
|
||||
action["l_wrist_roll.pos"] = self.reachy.l_arm.wrist.roll.goal_position
|
||||
action["l_wrist_pitch.pos"] = self.reachy.l_arm.wrist.pitch.goal_position
|
||||
action["l_wrist_yaw.pos"] = self.reachy.l_arm.wrist.yaw.goal_position
|
||||
action["l_gripper.pos"] = self.reachy.l_arm.gripper.opening
|
||||
|
||||
if self.mobile_base_available:
|
||||
last_cmd_vel = self.reachy.mobile_base.last_cmd_vel
|
||||
action["mobile_base_x.vel"] = last_cmd_vel["x"]
|
||||
action["mobile_base_y.vel"] = last_cmd_vel["y"]
|
||||
action["mobile_base_theta.vel"] = last_cmd_vel["theta"]
|
||||
else:
|
||||
action["mobile_base_x.vel"] = 0
|
||||
action["mobile_base_y.vel"] = 0
|
||||
action["mobile_base_theta.vel"] = 0
|
||||
|
||||
dtype = self.motor_features["action"]["dtype"]
|
||||
action = np.array(list(action.values()), dtype=dtype)
|
||||
# action = torch.as_tensor(list(action.values()))
|
||||
|
||||
obs_dict = self.capture_observation()
|
||||
action_dict = {}
|
||||
action_dict["action"] = action
|
||||
|
||||
return obs_dict, action_dict
|
||||
|
||||
def get_state(self) -> dict:
|
||||
# neck roll, pitch, yaw
|
||||
# r_shoulder_pitch, r_shoulder_roll, r_elbow_yaw, r_elbow_pitch, r_wrist_roll, r_wrist_pitch, r_wrist_yaw, r_gripper
|
||||
# l_shoulder_pitch, l_shoulder_roll, l_elbow_yaw, l_elbow_pitch, l_wrist_roll, l_wrist_pitch, l_wrist_yaw, l_gripper
|
||||
# mobile base x, y, theta
|
||||
if self.is_connected:
|
||||
if self.mobile_base_available:
|
||||
odometry = self.reachy.mobile_base.odometry
|
||||
else:
|
||||
odometry = {"x": 0, "y": 0, "theta": 0, "vx": 0, "vy": 0, "vtheta": 0}
|
||||
return {
|
||||
"neck_yaw.pos": self.reachy.head.neck.yaw.present_position,
|
||||
"neck_pitch.pos": self.reachy.head.neck.pitch.present_position,
|
||||
"neck_roll.pos": self.reachy.head.neck.roll.present_position,
|
||||
"r_shoulder_pitch.pos": self.reachy.r_arm.shoulder.pitch.present_position,
|
||||
"r_shoulder_roll.pos": self.reachy.r_arm.shoulder.roll.present_position,
|
||||
"r_elbow_yaw.pos": self.reachy.r_arm.elbow.yaw.present_position,
|
||||
"r_elbow_pitch.pos": self.reachy.r_arm.elbow.pitch.present_position,
|
||||
"r_wrist_roll.pos": self.reachy.r_arm.wrist.roll.present_position,
|
||||
"r_wrist_pitch.pos": self.reachy.r_arm.wrist.pitch.present_position,
|
||||
"r_wrist_yaw.pos": self.reachy.r_arm.wrist.yaw.present_position,
|
||||
"r_gripper.pos": self.reachy.r_arm.gripper.present_position,
|
||||
"l_shoulder_pitch.pos": self.reachy.l_arm.shoulder.pitch.present_position,
|
||||
"l_shoulder_roll.pos": self.reachy.l_arm.shoulder.roll.present_position,
|
||||
"l_elbow_yaw.pos": self.reachy.l_arm.elbow.yaw.present_position,
|
||||
"l_elbow_pitch.pos": self.reachy.l_arm.elbow.pitch.present_position,
|
||||
"l_wrist_roll.pos": self.reachy.l_arm.wrist.roll.present_position,
|
||||
"l_wrist_pitch.pos": self.reachy.l_arm.wrist.pitch.present_position,
|
||||
"l_wrist_yaw.pos": self.reachy.l_arm.wrist.yaw.present_position,
|
||||
"l_gripper.pos": self.reachy.l_arm.gripper.present_position,
|
||||
"mobile_base.vx": odometry["vx"],
|
||||
"mobile_base.vy": odometry["vy"],
|
||||
"mobile_base.vtheta": odometry["vtheta"],
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
def capture_observation(self) -> dict:
|
||||
if self.is_connected:
|
||||
before_read_t = time.perf_counter()
|
||||
state = self.get_state()
|
||||
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||
|
||||
if self.state_keys is None:
|
||||
self.state_keys = list(state)
|
||||
|
||||
dtype = self.motor_features["observation.state"]["dtype"]
|
||||
state = np.array(list(state.values()), dtype=dtype)
|
||||
# state = torch.as_tensor(list(state.values()))
|
||||
|
||||
# Capture images from cameras
|
||||
images = {}
|
||||
for name in self.cameras:
|
||||
# before_camread_t = time.perf_counter()
|
||||
images[name] = self.cameras[name].read() # Reachy cameras read() is not blocking?
|
||||
# print(f'name: {name} img: {images[name]}')
|
||||
if images[name] is not None:
|
||||
# images[name] = copy(images[name][0]) # seems like I need to copy?
|
||||
images[name] = torch.from_numpy(copy(images[name][0])) # seems like I need to copy?
|
||||
self.logs[f"read_camera_{name}_dt_s"] = images[name][1] # full timestamp, TODO dt
|
||||
|
||||
# Populate output dictionaries
|
||||
obs_dict = {}
|
||||
obs_dict["observation.state"] = state
|
||||
for name in self.cameras:
|
||||
obs_dict[f"observation.images.{name}"] = images[name]
|
||||
|
||||
return obs_dict
|
||||
else:
|
||||
return {}
|
||||
|
||||
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||
if not self.is_connected:
|
||||
raise ConnectionError()
|
||||
|
||||
self.reachy.head.neck.yaw.goal_position = float(action[0])
|
||||
self.reachy.head.neck.pitch.goal_position = float(action[1])
|
||||
self.reachy.head.neck.roll.goal_position = float(action[2])
|
||||
|
||||
self.reachy.r_arm.shoulder.pitch.goal_position = float(action[3])
|
||||
self.reachy.r_arm.shoulder.roll.goal_position = float(action[4])
|
||||
self.reachy.r_arm.elbow.yaw.goal_position = float(action[5])
|
||||
self.reachy.r_arm.elbow.pitch.goal_position = float(action[6])
|
||||
self.reachy.r_arm.wrist.roll.goal_position = float(action[7])
|
||||
self.reachy.r_arm.wrist.pitch.goal_position = float(action[8])
|
||||
self.reachy.r_arm.wrist.yaw.goal_position = float(action[9])
|
||||
self.reachy.r_arm.gripper.set_opening(float(action[10]))
|
||||
|
||||
self.reachy.l_arm.shoulder.pitch.goal_position = float(action[11])
|
||||
self.reachy.l_arm.shoulder.roll.goal_position = float(action[12])
|
||||
self.reachy.l_arm.elbow.yaw.goal_position = float(action[13])
|
||||
self.reachy.l_arm.elbow.pitch.goal_position = float(action[14])
|
||||
self.reachy.l_arm.wrist.roll.goal_position = float(action[15])
|
||||
self.reachy.l_arm.wrist.pitch.goal_position = float(action[16])
|
||||
self.reachy.l_arm.wrist.yaw.goal_position = float(action[17])
|
||||
self.reachy.l_arm.gripper.set_opening(float(action[18]))
|
||||
|
||||
s = time.time()
|
||||
self.reachy.send_goal_positions(check_positions=False)
|
||||
print("send_goal_positions", time.time() - s)
|
||||
|
||||
if self.mobile_base_available:
|
||||
self.reachy.mobile_base.set_goal_speed(action[19], action[20], action[21])
|
||||
self.reachy.mobile_base.send_speed_command()
|
||||
|
||||
# TODO: what shape is the action tensor?
|
||||
# 7 dofs per arm (x2)
|
||||
# 1 dof per gripper (x2)
|
||||
# 3 dofs for the neck
|
||||
# 3 dofs for the mobile base (x, y, theta)
|
||||
# 7+7+1+1+3+3 = 22
|
||||
return action
|
||||
|
||||
def print_logs(self) -> None:
|
||||
pass
|
||||
|
||||
def disconnect(self) -> None:
|
||||
print("Disconnecting")
|
||||
self.is_connected = False
|
||||
print("Turn off")
|
||||
# self.reachy.turn_off_smoothly()
|
||||
# self.reachy.turn_off()
|
||||
print("\t turn off done")
|
||||
self.reachy.disconnect()
|
||||
@@ -1,48 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: aractingi/pick_place_lego_cube_1
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
env:
|
||||
name: "classifier"
|
||||
task: "binary_classification"
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 5
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
grad_clip_norm: 10
|
||||
use_amp: true
|
||||
log_freq: 1
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_keys: ["observation.images.top", "observation.images.wrist"]
|
||||
label_key: "next.reward"
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier/pick_place_lego_cube_1"
|
||||
model_name: "facebook/convnext-base-224"
|
||||
model_type: "cnn"
|
||||
num_cameras: 2 # Has to be len(training.image_keys)
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
project: "classifier-training"
|
||||
job_name: "classifier_training_0"
|
||||
disable_artifact: false
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "outputs/classifier"
|
||||
@@ -1,97 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
# env=pusht \
|
||||
# env.gym.obs_type=environment_state_agent_pos \
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: null
|
||||
|
||||
|
||||
training:
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
# batch_size: 256
|
||||
batch_size: 512
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 2500
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 1000000
|
||||
online_buffer_seed_size: 0
|
||||
online_step_before_learning: 5000
|
||||
do_online_rollout_async: false
|
||||
policy_update_freq: 1
|
||||
|
||||
# delta_timestamps:
|
||||
# observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
# action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
# next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 1
|
||||
n_action_steps: 1
|
||||
|
||||
shared_encoder: true
|
||||
input_shapes:
|
||||
# # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.state: ["${env.state_dim}"]
|
||||
observation.image: [3, 64, 64]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: null
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
output_normalization_params:
|
||||
action:
|
||||
min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
|
||||
max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: 32
|
||||
# discount: 0.99
|
||||
discount: 0.80
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
# critic_target_update_weight: 0.005
|
||||
critic_target_update_weight: 0.01
|
||||
utd_ratio: 1
|
||||
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -1,89 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
# Train with:
|
||||
#
|
||||
# python lerobot/scripts/train.py \
|
||||
# env=pusht \
|
||||
# +dataset=lerobot/pusht_keypoints
|
||||
|
||||
seed: 1
|
||||
dataset_repo_id: lerobot/pusht_keypoints
|
||||
|
||||
training:
|
||||
offline_steps: 0
|
||||
|
||||
# Offline training dataloader
|
||||
num_workers: 4
|
||||
|
||||
batch_size: 128
|
||||
grad_clip_norm: 10.0
|
||||
lr: 3e-4
|
||||
|
||||
eval_freq: 50000
|
||||
log_freq: 500
|
||||
save_freq: 50000
|
||||
|
||||
online_steps: 1000000
|
||||
online_rollout_n_episodes: 10
|
||||
online_rollout_batch_size: 10
|
||||
online_steps_between_rollouts: 1000
|
||||
online_sampling_ratio: 1.0
|
||||
online_env_seed: 10000
|
||||
online_buffer_capacity: 40000
|
||||
online_buffer_seed_size: 0
|
||||
do_online_rollout_async: false
|
||||
|
||||
delta_timestamps:
|
||||
observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
|
||||
action: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
next.reward: "[i / ${fps} for i in range(${policy.horizon})]"
|
||||
|
||||
policy:
|
||||
name: sac
|
||||
|
||||
pretrained_model_path:
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: 1
|
||||
horizon: 5
|
||||
n_action_steps: 5
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.environment_state: [16]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.environment_state: min_max
|
||||
observation.state: min_max
|
||||
output_normalization_modes:
|
||||
action: min_max
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
# image_encoder_hidden_dim: 32
|
||||
discount: 0.99
|
||||
temperature_init: 1.0
|
||||
num_critics: 2
|
||||
num_subsample_critics: null
|
||||
critic_lr: 3e-4
|
||||
actor_lr: 3e-4
|
||||
temperature_lr: 3e-4
|
||||
critic_target_update_weight: 0.005
|
||||
utd_ratio: 2
|
||||
|
||||
|
||||
# # Loss coefficients.
|
||||
# reward_coeff: 0.5
|
||||
# expectile_weight: 0.9
|
||||
# value_coeff: 0.1
|
||||
# consistency_coeff: 20.0
|
||||
# advantage_scaling: 3.0
|
||||
# pi_coeff: 0.5
|
||||
# temporal_decay_coeff: 0.5
|
||||
# # Target model.
|
||||
# target_model_momentum: 0.995
|
||||
@@ -10,7 +10,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem58760430441
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
@@ -23,7 +23,7 @@ leader_arms:
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem585A0083391
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
|
||||
lerobot/configs/robot/reachy2.yaml (new file, 50 lines)
@@ -0,0 +1,50 @@
|
||||
# [Reachy2 from Pollen Robotics](https://www.pollen-robotics.com)
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[reachy2]"`
|
||||
# With poetry: `poetry install --sync --extras "reachy2"`
|
||||
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.reachy2.ReachyRobot
|
||||
robot_type: reachy2
|
||||
|
||||
cameras:
|
||||
head_left:
|
||||
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
name: teleop
|
||||
host: 172.17.134.85
|
||||
# host: 192.168.0.197
|
||||
# host: localhost
|
||||
port: 50065
|
||||
fps: 30
|
||||
width: 960
|
||||
height: 720
|
||||
image_type: left
|
||||
# head_right:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: teleop
|
||||
# host: 172.17.135.207
|
||||
# port: 50065
|
||||
# image_type: right
|
||||
# fps: 30
|
||||
# width: 960
|
||||
# height: 720
|
||||
# torso_rgb:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: depth
|
||||
# host: 172.17.135.207
|
||||
# # host: localhost
|
||||
# port: 50065
|
||||
# image_type: rgb
|
||||
# fps: 30
|
||||
# width: 1280
|
||||
# height: 720
|
||||
# torso_depth:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: depth
|
||||
# host: 172.17.135.207
|
||||
# port: 50065
|
||||
# image_type: depth
|
||||
# fps: 30
|
||||
# width: 1280
|
||||
# height: 720
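As a usage sketch, this config can be instantiated the same way `control_robot.py` builds robots elsewhere in this diff (the path points at this file; overrides are optional):
```python
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config

robot_cfg = init_hydra_config("lerobot/configs/robot/reachy2.yaml")
robot = make_robot(robot_cfg)
robot.connect()
```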
|
||||
lerobot/configs/robot/reachy2_manipulator.yaml (new file, 65 lines)
@@ -0,0 +1,65 @@
|
||||
# [Reachy2 from Pollen Robotics](https://www.pollen-robotics.com)
|
||||
|
||||
# Requires installing extras packages
|
||||
# With pip: `pip install -e ".[reachy2]"`
|
||||
# With poetry: `poetry install --sync --extras "reachy2"`
|
||||
|
||||
|
||||
_target_: lerobot.common.robot_devices.robots.reachy2_manipulator.ReachyManipulatorRobot
|
||||
robot_type: reachy2
|
||||
|
||||
|
||||
leader_arm:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem585A0077581
|
||||
calibration_dir: .cache/calibration/so100
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
shoulder_lift: [2, "sts3215"]
|
||||
elbow_flex: [3, "sts3215"]
|
||||
wrist_flex: [4, "sts3215"]
|
||||
wrist_roll: [5, "sts3215"]
|
||||
gripper: [6, "sts3215"]
|
||||
|
||||
cameras:
|
||||
head_left:
|
||||
_target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
name: teleop
|
||||
host: 172.17.134.85
|
||||
# host: 192.168.0.197
|
||||
# host: localhost
|
||||
port: 50065
|
||||
fps: 30
|
||||
width: 960
|
||||
height: 720
|
||||
image_type: left
|
||||
# head_right:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: teleop
|
||||
# host: 172.17.135.207
|
||||
# port: 50065
|
||||
# image_type: right
|
||||
# fps: 30
|
||||
# width: 960
|
||||
# height: 720
|
||||
# torso_rgb:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: depth
|
||||
# host: 172.17.135.207
|
||||
# # host: localhost
|
||||
# port: 50065
|
||||
# image_type: rgb
|
||||
# fps: 30
|
||||
# width: 1280
|
||||
# height: 720
|
||||
# torso_depth:
|
||||
# _target_: lerobot.common.robot_devices.cameras.reachy2.ReachyCamera
|
||||
# name: depth
|
||||
# host: 172.17.135.207
|
||||
# port: 50065
|
||||
# image_type: depth
|
||||
# fps: 30
|
||||
# width: 1280
|
||||
# height: 720
|
||||
@@ -18,7 +18,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem58760433331
|
||||
port: /dev/tty.usbmodem585A0077581
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
|
||||
@@ -29,6 +29,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--num-episodes 1 \
|
||||
--run-compute-stats 0
|
||||
@@ -37,6 +38,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
- Visualize dataset:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode-index 0
|
||||
```
|
||||
@@ -45,6 +47,7 @@ python lerobot/scripts/visualize_dataset.py \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py replay \
|
||||
--fps 30 \
|
||||
--root tmp/data \
|
||||
--repo-id $USER/koch_test \
|
||||
--episode 0
|
||||
```
|
||||
@@ -54,6 +57,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/koch_pick_place_lego \
|
||||
--num-episodes 50 \
|
||||
--warmup-time-s 2 \
|
||||
@@ -68,12 +72,12 @@ python lerobot/scripts/control_robot.py record \
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require a sudo permission to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--resume 1`.
|
||||
If the dataset you want to extend is not on the hub, you also need to add `--local-files-only 1`.
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
To avoid resuming by deleting the dataset, use `--force-override 1`.
|
||||
|
||||
- Train on this dataset with the ACT policy:
|
||||
```bash
|
||||
python lerobot/scripts/train.py \
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
policy=act_koch_real \
|
||||
env=koch_real \
|
||||
dataset_repo_id=$USER/koch_pick_place_lego \
|
||||
@@ -84,6 +88,7 @@ python lerobot/scripts/train.py \
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id $USER/eval_act_koch_real \
|
||||
--num-episodes 10 \
|
||||
--warmup-time-s 2 \
|
||||
@@ -109,7 +114,6 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
log_control_info,
|
||||
record_episode,
|
||||
reset_environment,
|
||||
reset_follower_position,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
@@ -192,7 +196,6 @@ def record(
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
assign_rewards: bool = False,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
@@ -206,7 +209,6 @@ def record(
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
reset_follower: bool = False,
|
||||
resume: bool = False,
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
local_files_only: bool = False,
|
||||
@@ -217,9 +219,6 @@ def record(
|
||||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
@@ -248,7 +247,7 @@ def record(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
@@ -260,16 +259,13 @@ def record(
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
features=extra_features,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
if reset_follower:
|
||||
initial_position = robot.follower_arms["main"].read("Present_Position")
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
# 2. give times to the robot devices to connect and start synchronizing,
|
||||
@@ -309,11 +305,9 @@ def record(
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
if reset_follower:
|
||||
reset_follower_position(robot, initial_position)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
@@ -352,7 +346,7 @@ def replay(
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = False,
|
||||
local_files_only: bool = True,
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
@@ -435,7 +429,7 @@ if __name__ == "__main__":
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--repo-id",
|
||||
@@ -444,10 +438,10 @@ if __name__ == "__main__":
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--local-files-only",
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--warmup-time-s",
|
||||
@@ -480,12 +474,12 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
# parser_record.add_argument(
|
||||
# "--tags",
|
||||
# type=str,
|
||||
# nargs="*",
|
||||
# help="Add tags to your dataset on the hub.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
@@ -507,12 +501,6 @@ if __name__ == "__main__":
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
@@ -528,18 +516,6 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--reset-follower",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Resets the follower to the initial position during while reseting the evironment, this is to avoid having the follower start at an awkward position in the next episode",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
@@ -549,7 +525,7 @@ if __name__ == "__main__":
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--repo-id",
|
||||
@@ -557,12 +533,6 @@ if __name__ == "__main__":
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,567 +0,0 @@
|
||||
"""
|
||||
Utilities to control a robot in simulation.
|
||||
|
||||
Useful to record a dataset, replay a recorded episode and record an evaluation dataset.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
|
||||
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency.
|
||||
You can modify this value depending on how fast your simulation can run:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--fps 30 \
|
||||
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||
--sim-config lerobot/configs/env/your_sim_config.yaml
|
||||
```
|
||||
|
||||
- Record one episode in order to test replay:
|
||||
```bash
|
||||
python lerobot/scripts/control_sim_robot.py record \
|
||||
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||
--fps 30 \
|
||||
--repo-id $USER/robot_sim_test \
|
||||
--num-episodes 1 \
|
||||
--run-compute-stats 0
|
||||
```
|
||||
|
||||
Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub.
|
||||
|
||||
- Visualize dataset:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset.py \
|
||||
--repo-id $USER/robot_sim_test \
|
||||
--episode-index 0
|
||||
```
|
||||
|
||||
- Replay a sequence of test episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_sim_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||
--fps 30 \
|
||||
--repo-id $USER/robot_sim_test \
|
||||
--episode 0
|
||||
```
|
||||
Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection.
|
||||
|
||||
- Record a full dataset in order to train a policy,
|
||||
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_sim_robot.py record \
|
||||
--robot-path lerobot/configs/robot/your_robot_config.yaml \
|
||||
--sim-config lerobot/configs/env/your_sim_config.yaml \
|
||||
--fps 30 \
|
||||
--repo-id $USER/robot_sim_test \
|
||||
--num-episodes 50 \
|
||||
--episode-time-s 30 \
|
||||
```
|
||||
|
||||
**NOTE**: You can use your keyboard to control data recording flow.
|
||||
- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment.
- Tap right arrow key '->' to early exit while resetting the environment and go to recording the next episode.
|
||||
- Tap left arrow key '<-' to early exit and re-record the current episode.
|
||||
- Tap escape key 'esc' to stop the data recording.
|
||||
This might require sudo permissions to allow your terminal to monitor keyboard events.
|
||||
|
||||
**NOTE**: You can resume/continue data recording by running the same data recording command twice.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
init_keyboard_listener,
|
||||
init_policy,
|
||||
is_headless,
|
||||
log_control_info,
|
||||
predict_action,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"seed": {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"timestamp": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
def none_or_int(value):
|
||||
if value == "None":
|
||||
return None
|
||||
return int(value)
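# Illustrative note (hypothetical CLI values, not from the original script): argparse hands the raw
# string to this helper, so `--fps None` yields Python None while `--fps 30` yields the int 30.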
|
||||
|
||||
|
||||
def init_sim_calibration(robot, cfg):
|
||||
# Constants necessary for transforming the joint pos of the real robot to the sim
|
||||
# depending on the robot description used in that sim.
|
||||
start_pos = np.array(robot.leader_arms.main.calibration["start_pos"])
|
||||
axis_directions = np.array(cfg.get("axis_directions", [1]))
|
||||
offsets = np.array(cfg.get("offsets", [0])) * np.pi
|
||||
|
||||
return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets}
|
||||
|
||||
|
||||
def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets):
|
||||
"""Counts - starting position -> radians -> align axes -> offset"""
|
||||
return axis_directions * (real_positions - start_pos) * 2.0 * np.pi / 4096 + offsets
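# Worked example (hypothetical calibration values, for illustration only): with start_pos=2048,
# axis_directions=1 and offsets=0, a raw reading of 3072 counts maps to
# 1 * (3072 - 2048) * 2 * pi / 4096 ≈ 1.57 rad, i.e. a quarter turn past the calibrated start position.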
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
|
||||
def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None):
|
||||
env = env()
|
||||
env.reset()
|
||||
start_teleop_t = time.perf_counter()
|
||||
while True:
|
||||
leader_pos = robot.leader_arms.main.read("Present_Position")
|
||||
action = process_action_fn(leader_pos)
|
||||
env.step(np.expand_dims(action, 0))
|
||||
if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s:
|
||||
print("Teleoperation processes finished.")
|
||||
break
|
||||
|
||||
|
||||
def record(
|
||||
env,
|
||||
robot: Robot,
|
||||
process_action_from_leader,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
task: str,
|
||||
fps: int | None = None,
|
||||
tags: list[str] | None = None,
|
||||
pretrained_policy_name_or_path: str = None,
|
||||
policy_overrides: bool | None = None,
|
||||
episode_time_s: int = 30,
|
||||
num_episodes: int = 50,
|
||||
video: bool = True,
|
||||
push_to_hub: bool = True,
|
||||
num_image_writer_processes: int = 0,
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = False,
|
||||
play_sounds: bool = True,
|
||||
resume: bool = False,
|
||||
local_files_only: bool = False,
|
||||
run_compute_stats: bool = True,
|
||||
assign_rewards: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
# Load pretrained policy
|
||||
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
policy = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
|
||||
if fps is None:
|
||||
fps = policy_fps
|
||||
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
|
||||
|
||||
if policy is None and process_action_from_leader is None:
|
||||
raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")
|
||||
|
||||
# initialize listener before sim env
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
# create sim env
|
||||
env = env()
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
num_cameras = sum([1 if "image" in key else 0 for key in env.observation_space])
|
||||
|
||||
# get image keys
|
||||
image_keys = [key for key in env.observation_space if "image" in key]
|
||||
state_keys_dict = env_cfg.state_keys
|
||||
|
||||
if resume:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
dataset.start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * num_cameras,
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
else:
|
||||
features = DEFAULT_FEATURES
|
||||
# add image keys to features
|
||||
for key in image_keys:
|
||||
shape = env.observation_space[key].shape
|
||||
if not key.startswith("observation.image."):
|
||||
key = "observation.image." + key
|
||||
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
|
||||
|
||||
for key, obs_key in state_keys_dict.items():
|
||||
features[key] = {
|
||||
"dtype": "float32",
|
||||
"names": None,
|
||||
"shape": env.observation_space[obs_key].shape,
|
||||
}
|
||||
|
||||
features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
|
||||
features = {**features, **(extra_features or {})}
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
features=features,
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * num_cameras,
|
||||
)
|
||||
|
||||
recorded_episodes = 0
|
||||
while True:
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
|
||||
if events is None:
|
||||
events = {"exit_early": False}
|
||||
|
||||
if episode_time_s is None:
|
||||
episode_time_s = float("inf")
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
seed = np.random.randint(0, 1e5)
|
||||
observation, info = env.reset(seed=seed)
|
||||
|
||||
while timestamp < episode_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
if policy is not None:
|
||||
action = predict_action(observation, policy, device, use_amp)
|
||||
else:
|
||||
leader_pos = robot.leader_arms.main.read("Present_Position")
|
||||
action = process_action_from_leader(leader_pos)
|
||||
|
||||
observation, reward, terminated, _, info = env.step(action)
|
||||
|
||||
success = info.get("is_success", False)
|
||||
env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps)
|
||||
|
||||
frame = {
|
||||
"action": torch.from_numpy(action),
|
||||
"next.reward": reward,
|
||||
"next.success": success,
|
||||
"seed": seed,
|
||||
"timestamp": env_timestamp,
|
||||
}
|
||||
|
||||
# Overwrite environment reward with manually assigned reward
|
||||
if assign_rewards:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
|
||||
# Should success always be false to match what we do in control_utils?
|
||||
frame["next.success"] = False
|
||||
|
||||
for key in image_keys:
|
||||
if not key.startswith("observation.image"):
|
||||
frame["observation.image." + key] = observation[key]
|
||||
else:
|
||||
frame[key] = observation[key]
|
||||
|
||||
for key, obs_key in state_keys_dict.items():
|
||||
frame[key] = torch.from_numpy(observation[obs_key])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
if fps is not None:
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
log_control_info(robot, dt_s, fps=fps)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"] or terminated:
|
||||
events["exit_early"] = False
|
||||
break
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.save_episode(task=task)
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"] or recorded_episodes >= num_episodes:
|
||||
break
|
||||
else:
|
||||
logging.info("Waiting for a few seconds before starting next episode recording...")
|
||||
busy_wait(3)
|
||||
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(tags=tags)
|
||||
|
||||
log_say("Exiting", play_sounds)
|
||||
return dataset
|
||||
|
||||
|
||||
def replay(
|
||||
env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True
|
||||
):
|
||||
env = env()
|
||||
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
items = dataset.hf_dataset.select_columns("action")
|
||||
seeds = dataset.hf_dataset.select_columns("seed")["seed"]
|
||||
|
||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
env.reset(seed=seeds[from_idx].item())
|
||||
logging.info("Replaying episode")
|
||||
log_say("Replaying episode", play_sounds=True)
|
||||
for idx in range(from_idx, to_idx):
|
||||
start_episode_t = time.perf_counter()
|
||||
action = items[idx]["action"]
|
||||
env.step(action.unsqueeze(0).numpy())
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
subparsers = parser.add_subparsers(dest="mode", required=True)
|
||||
|
||||
# Set common options for all the subparsers
|
||||
base_parser = argparse.ArgumentParser(add_help=False)
|
||||
base_parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
|
||||
base_parser.add_argument(
|
||||
"--sim-config",
|
||||
help="Path to a yaml config you want to use for initializing a sim environment based on gym ",
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("teleoperate", parents=[base_parser])
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--episode-time-s",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Number of seconds for data recording for each episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
required=True,
|
||||
help="A description of the task preformed during recording that can be used as a language instruction.",
|
||||
)
|
||||
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
|
||||
parser_record.add_argument(
|
||||
"--run-compute-stats",
|
||||
type=int,
|
||||
default=1,
|
||||
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--push-to-hub",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
|
||||
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
|
||||
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
|
||||
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-threads-per-camera",
|
||||
type=int,
|
||||
default=4,
|
||||
help=(
|
||||
"Number of threads writing the frames as png images on disk, per camera. "
|
||||
"Too much threads might cause unstable teleoperation fps due to main thread being blocked. "
|
||||
"Not enough threads might cause low camera fps."
|
||||
),
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--display-cameras",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Visualize image observations with opencv.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Resume recording on an existing dataset.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.",
|
||||
)
|
||||
parser_replay.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot/test",
|
||||
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
|
||||
)
|
||||
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
init_logging()
|
||||
|
||||
control_mode = args.mode
|
||||
robot_path = args.robot_path
|
||||
env_config_path = args.sim_config
|
||||
kwargs = vars(args)
|
||||
del kwargs["mode"]
|
||||
del kwargs["robot_path"]
|
||||
del kwargs["sim_config"]
|
||||
|
||||
# make gym env
|
||||
env_cfg = init_hydra_config(env_config_path)
|
||||
importlib.import_module(f"gym_{env_cfg.env.name}")
|
||||
|
||||
def env_constructor():
|
||||
return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)
|
||||
|
||||
robot = None
|
||||
process_leader_actions_fn = None
|
||||
|
||||
if control_mode in ["teleoperate", "record"]:
|
||||
# make robot
|
||||
robot_overrides = ["~cameras", "~follower_arms"]
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
calib_kwgs = init_sim_calibration(robot, env_cfg.calibration)
|
||||
|
||||
def process_leader_actions_fn(action):
|
||||
return real_positions_to_sim(action, **calib_kwgs)
|
||||
|
||||
robot.leader_arms.main.calibration = None
|
||||
|
||||
if control_mode == "teleoperate":
|
||||
teleoperate(env_constructor, robot, process_leader_actions_fn)
|
||||
|
||||
elif control_mode == "record":
|
||||
record(env_constructor, robot, process_leader_actions_fn, **kwargs)
|
||||
|
||||
elif control_mode == "replay":
|
||||
replay(env_constructor, **kwargs)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay."
|
||||
)
|
||||
|
||||
if robot and robot.is_connected:
|
||||
# Disconnect manually to avoid a "Core dump" during process
|
||||
# termination due to camera threads not properly exiting.
|
||||
robot.disconnect()
|
||||
@@ -1,433 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
This script supports performing human interventions during rollouts.
|
||||
Human interventions allow the user to take control of the robot from the policy
|
||||
and correct its behavior. It is specifically designed for reinforcement learning
|
||||
experiments and HIL-SERL (human-in-the-loop reinforcement learning) methods.
|
||||
|
||||
### How to Use
|
||||
|
||||
To rollout a policy on the robot:
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-name-or-path path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
|
||||
If you trained a reward classifier on your task, you can also evaluate it using this script.
|
||||
You can annotate the rollouts with a pre-trained reward classifier by running:
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--pretrained-policy-name-or-path path/to/pretrained_model \
|
||||
--policy-config path/to/policy/config.yaml \
|
||||
--reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
|
||||
--reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
|
||||
--display-cameras 1
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position, predict_action
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
log_say,
|
||||
)
|
||||
|
||||
|
||||
def get_classifier(pretrained_path, config_path):
|
||||
if pretrained_path is None or config_path is None:
|
||||
return
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def rollout(
|
||||
robot: Robot,
|
||||
policy: Policy,
|
||||
reward_classifier,
|
||||
fps: int,
|
||||
control_time_s: float = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
device: str = "cpu"
|
||||
) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
This function executes a rollout using the provided policy and robot interface,
|
||||
simulating batched interactions for a fixed control duration.
|
||||
|
||||
The returned dictionary contains rollout statistics, which can be used for analysis and debugging.
|
||||
|
||||
Args:
|
||||
"robot": The robot interface for interacting with the real robot hardware.
|
||||
"policy": The policy to execute. Must be a PyTorch `nn.Module` object.
|
||||
"reward_classifier": A module to classify rewards during the rollout.
|
||||
"fps": The control frequency at which the policy is executed.
|
||||
"control_time_s": The total control duration of the rollout in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": If True, displays camera streams during the rollout.
|
||||
"device": The device to use for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
Dictionary of the statistics collected during the rollout.
|
||||
"""
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
if policy is not None:
|
||||
policy.reset()
|
||||
|
||||
# NOTE: sorting to make sure the key sequence is the same during training and testing.
|
||||
observation = robot.capture_observation()
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
image_keys.sort()
|
||||
|
||||
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
|
||||
indices_from_policy = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
init_pos = robot.follower_arms["main"].read("Present_Position")
|
||||
timestamp = 0.0
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
# Apply the next action.
|
||||
while events["pause_policy"] and not events["human_intervention_step"]:
|
||||
busy_wait(0.5)
|
||||
|
||||
if events["human_intervention_step"]:
|
||||
# take over the robot's actions
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
action = action["action"] # teleop step returns torch tensors but in a dict
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
# TODO (michel-aractingi) in place temporarily for testing purposes
|
||||
if policy is None:
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
indices_from_policy.append(False)
|
||||
else:
|
||||
action = predict_action(observation, policy, device, use_amp)
|
||||
indices_from_policy.append(True)
|
||||
|
||||
robot.send_action(action)
|
||||
observation = robot.capture_observation()
|
||||
|
||||
|
||||
images = []
|
||||
for key in image_keys:
|
||||
if display_cameras:
|
||||
cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
images.append(observation[key].to(device))
|
||||
|
||||
reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
|
||||
|
||||
# TODO send data through the server as soon as you have it
|
||||
|
||||
all_rewards.append(reward)
|
||||
all_actions.append(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
break
|
||||
|
||||
reset_follower_position(robot, target_position=init_pos)
|
||||
|
||||
dones = torch.tensor([False] * len(all_actions))
|
||||
dones[-1] = True
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
|
||||
listener.stop()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module,
|
||||
fps: float,
|
||||
n_episodes: int,
|
||||
control_time_s: int = 20,
|
||||
use_amp: bool = True,
|
||||
display_cameras: bool = False,
|
||||
reward_classifier_pretrained_path: str | None = None,
|
||||
reward_classifier_config_file: str | None = None,
|
||||
device: str | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Evaluate a policy on a real robot by running multiple episodes and collecting metrics.
|
||||
|
||||
This function executes rollouts of the specified policy on the robot, computes metrics
|
||||
for the rollouts, and optionally evaluates a reward classifier if provided.
|
||||
|
||||
Args:
|
||||
"robot": The robot interface used to interact with the real robot hardware.
|
||||
"policy": The policy to be evaluated. Must be a PyTorch neural network module.
|
||||
"fps": Frames per second (control frequency) for running the policy.
|
||||
"n_episodes": The number of episodes to evaluate the policy.
|
||||
"control_time_s": The max duration for each episode in seconds.
|
||||
"use_amp": Whether to use automatic mixed precision (AMP) for policy evaluation.
|
||||
"display_cameras": Whether to display camera streams during rollouts.
|
||||
"reward_classifier_pretrained_path": Path to the pretrained reward classifier.
|
||||
If provided, the reward classifier will be evaluated during rollouts.
|
||||
"reward_classifier_config_file": Path to the configuration file for the reward classifier.
|
||||
Required if `reward_classifier_pretrained_path` is provided.
|
||||
"device": The device for computations (e.g., "cpu", "cuda" or "mps").
|
||||
|
||||
Returns:
|
||||
"dict": A dictionary containing the following rollout metrics and data:
|
||||
- "metrics": Evaluation metrics such as cumulative rewards, success rates, etc.
|
||||
- "rollout_data": Detailed data from the rollouts, including observations, actions, rewards, and done flags.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
# policy.eval()
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file).to(device)
|
||||
|
||||
device = get_device_from_parameters(policy) if device is None else device
|
||||
|
||||
for _ in progbar:
|
||||
rollout_data = rollout(
|
||||
robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras, device
|
||||
)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
}
|
||||
for i, (sum_reward, max_reward) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
}
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
"""
|
||||
Initialize a keyboard listener for controlling the recording and human intervention process.
|
||||
|
||||
Keyboard controls: (Note that this might require sudo permissions to monitor keyboard events)
|
||||
- Right Arrow Key ('->'): Stops the current recording and exits early, useful for ending an episode
and moving on to the next episode recording.
|
||||
- Left Arrow Key ('<-'): Re-records the current episode, allowing the user to start over.
|
||||
- Space Bar: Controls the human intervention process in three steps:
|
||||
1. First press pauses the policy and prompts the user to position the leader similar to the follower.
|
||||
2. Second press initiates human interventions, allowing teleop control of the robot.
|
||||
3. Third press resumes the policy rollout.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["pause_policy"] = False
|
||||
events["human_intervention_step"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
# check if first space press then pause the policy for the user to get ready
|
||||
# if second space press then the user is ready to start intervention
|
||||
if not events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
elif events["pause_policy"] and not events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
elif events["human_intervention_step"]:
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
print("Space key pressed. Human intervention ending, policy resumes control.")
|
||||
log_say("Policy resuming.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the pretrained classifier weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-config-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a yaml config file that is necessary to build the reward classifier model.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
eval_policy(
|
||||
robot,
|
||||
None,
|
||||
fps=40,
|
||||
n_episodes=2,
|
||||
control_time_s=100,
|
||||
display_cameras=args.display_cameras,
|
||||
reward_classifier_config_file=args.reward_classifier_config_file,
|
||||
reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
|
||||
)
|
||||
@@ -66,7 +66,7 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
|
||||
from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "aloha_hdf5":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
|
||||
elif raw_format in ["rlds", "openx"]:
|
||||
elif "openx_rlds" in raw_format:
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format
|
||||
elif raw_format == "dora_parquet":
|
||||
from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
|
||||
@@ -204,14 +204,24 @@ def push_dataset_to_hub(
|
||||
# convert dataset from original raw format to LeRobot format
|
||||
from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
|
||||
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
|
||||
raw_dir,
|
||||
videos_dir,
|
||||
fps,
|
||||
video,
|
||||
episodes,
|
||||
encoding,
|
||||
)
|
||||
fmt_kwgs = {
|
||||
"raw_dir": raw_dir,
|
||||
"videos_dir": videos_dir,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"episodes": episodes,
|
||||
"encoding": encoding,
|
||||
}
|
||||
|
||||
if "openx_rlds." in raw_format:
|
||||
# Support for official OXE dataset name inside `raw_format`.
|
||||
# For instance, `raw_format="oxe_rlds"` uses the default formating (TODO what does that mean?),
|
||||
# and `raw_format="oxe_rlds.bridge_orig"` uses the brdige_orig formating
|
||||
_, openx_dataset_name = raw_format.split(".")
|
||||
print(f"Converting dataset [{openx_dataset_name}] from 'openx_rlds' to LeRobot format.")
|
||||
fmt_kwgs["openx_dataset_name"] = openx_dataset_name
|
||||
|
||||
hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(**fmt_kwgs)
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
@@ -280,7 +290,7 @@ def main():
|
||||
"--raw-format",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `rlds`, `openx`).",
|
||||
help="Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`, `openx_rlds`).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
|
||||
21
lerobot/scripts/test_reachy.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import time
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
control_mode = "test"
|
||||
robot_path = "lerobot/configs/robot/reachy2.yaml"
|
||||
robot_overrides = None
|
||||
|
||||
robot_cfg = init_hydra_config(robot_path, robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
print(robot.is_connected)
|
||||
# print(robot.get_state())
|
||||
print(robot.capture_observation())
|
||||
time.sleep(5)
|
||||
robot.disconnect()
|
||||
@@ -93,17 +93,6 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "sac":
|
||||
optimizer = torch.optim.Adam(
|
||||
[
|
||||
{"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
|
||||
{"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
|
||||
{"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
|
||||
]
|
||||
)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
@@ -322,11 +311,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
|
||||
logging.info("make_dataset")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
# TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment
|
||||
# i.e., pusht
|
||||
if "task_index" in offline_dataset.hf_dataset[0]:
|
||||
offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])
|
||||
|
||||
if isinstance(offline_dataset, MultiLeRobotDataset):
|
||||
logging.info(
|
||||
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
|
||||
|
||||
@@ -1,320 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger): # noqa I001
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Creates a weighted sampler to handle class imbalance
|
||||
|
||||
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
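# Sketch of the effect (hypothetical counts): with 90 negative and 10 positive labels,
# class_weights = [1/90, 1/10], so each positive sample is ~9x more likely to be drawn and
# a batch is roughly class-balanced in expectation.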
|
||||
|
||||
|
||||
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
|
||||
# Check if the device supports AMP
|
||||
# MPS does not support AMP properly yet, so AMP is only enabled on CUDA and CPU devices.
|
||||
return cfg.training.use_amp and device.type in ("cuda", "cpu")
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Backward pass with gradient scaling if AMP enabled
|
||||
optimizer.zero_grad()
|
||||
if cfg.training.use_amp:
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
current_acc = 100 * correct / total
|
||||
train_info = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": current_acc,
|
||||
"dataloading_s": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
logger.log_dict(train_info, step + batch_idx, mode="train")
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
|
||||
# Validation loop with metric tracking and sample logging
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_start_time = time.perf_counter()
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
|
||||
):
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < num_samples_to_log:
|
||||
for i in range(min(num_samples_to_log - len(samples), len(images))):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
samples.append(
|
||||
{
|
||||
"image": wandb.Image(images[i].cpu()),
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
"confidence": confidence,
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
"accuracy": accuracy,
|
||||
"eval_s": time.perf_counter() - batch_start_time,
|
||||
"eval/prediction_samples": wandb.Table(
|
||||
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
|
||||
columns=["Image", "True Label", "Predicted", "Confidence"],
|
||||
)
|
||||
if logger._cfg.wandb.enable
|
||||
else None,
|
||||
}
|
||||
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
out_dir = Path(cfg.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
sampler = create_balanced_sampler(train_dataset, cfg)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.eval.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.training.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Resume training if requested
|
||||
step = 0
|
||||
best_val_acc = 0
|
||||
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Load and validate checkpoint configuration
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
|
||||
# Initialize model and training components
|
||||
model = get_model(cfg=cfg, logger=logger).to(device)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
|
||||
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, None)
|
||||
|
||||
# Training loop with validation and checkpointing
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
val_acc, eval_info = validate(
|
||||
model,
|
||||
val_loader,
|
||||
criterion,
|
||||
device,
|
||||
logger,
|
||||
cfg,
|
||||
)
|
||||
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier="best",
|
||||
)
|
||||
|
||||
# Periodic checkpointing
|
||||
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier=f"{epoch+1:06d}",
|
||||
)
|
||||
|
||||
step += len(train_loader)
|
||||
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
@@ -1,586 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import functools
|
||||
from pprint import pformat
|
||||
import random
|
||||
from typing import Optional, Sequence, TypedDict, Callable
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.envs.factory import make_env, make_maniskill_env
|
||||
from lerobot.common.envs.utils import preprocess_observation, preprocess_maniskill_observation
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.policies.utils import get_device_from_parameters
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.eval import eval_policy
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy):
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
# We wrap policy log temperature in list because this is a torch tensor and not a nn.Module
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: float
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: bool
|
||||
complementary_info: dict[str, torch.Tensor] = None
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
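# Illustrative shapes (assumed, since the batching code is not shown here): a BatchTransition of
# batch size B typically holds action of shape (B, action_dim), reward and done of shape (B,),
# and each entry of state / next_state of shape (B, ...) matching the corresponding observation key.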
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
"""
|
||||
Perform a per-image random crop over a batch of images in a vectorized way.
|
||||
Each image in the batch gets its own independent random crop location.
|
||||
"""
|
||||
B, C, H, W = images.shape
|
||||
crop_h, crop_w = output_size
|
||||
|
||||
if crop_h > H or crop_w > W:
|
||||
raise ValueError(
|
||||
f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
|
||||
)
|
||||
|
||||
tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
|
||||
lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
|
||||
|
||||
rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
|
||||
cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
|
||||
|
||||
rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w)
|
||||
cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w)
|
||||
|
||||
images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C)
|
||||
|
||||
# Gather pixels
|
||||
cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
|
||||
# cropped_hwcn => (B, crop_h, crop_w, C)
|
||||
|
||||
cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w)
|
||||
return cropped
|
||||
|
||||
|
||||
def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
"""Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
|
||||
_, _, h, w = images.shape
|
||||
images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
|
||||
return random_crop_vectorized(images=images, output_size=(h, w))
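A quick sanity check of the DrQ-style augmentation above, with arbitrary image dimensions: replicate padding followed by a random crop back to the original size keeps the output shape equal to the input shape.

```python
import torch

imgs = torch.rand(8, 3, 84, 84)  # (B, C, H, W), sizes chosen arbitrarily
shifted = random_shift(imgs, pad=4)
assert shifted.shape == imgs.shape  # each image is shifted by at most `pad` pixels
```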
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
image_augmentation_function: Optional[Callable] = None,
|
||||
use_drq: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
capacity (int): Maximum number of transitions to store in the buffer.
|
||||
device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
|
||||
state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
|
||||
image_augmentation_function (Optional[Callable]): A function that takes a batch of images
|
||||
and returns a batch of augmented images. If None, a default augmentation function is used.
|
||||
use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer.
|
||||
"""
|
||||
self.capacity = capacity
|
||||
self.device = device
|
||||
self.memory: list[Transition] = []
|
||||
self.position = 0
|
||||
|
||||
# If no state_keys provided, default to an empty list
|
||||
# (you can handle this differently if needed)
|
||||
self.state_keys = state_keys if state_keys is not None else []
|
||||
        if image_augmentation_function is None:
            image_augmentation_function = functools.partial(random_shift, pad=4)
        self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq
|
||||
|
||||
def add(
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
reward: float,
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
complementary_info: Optional[dict[str, torch.Tensor]] = None,
|
||||
):
|
||||
"""Saves a transition."""
|
||||
if len(self.memory) < self.capacity:
|
||||
self.memory.append(None)
|
||||
|
||||
# Create and store the Transition
|
||||
self.memory[self.position] = Transition(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
self.position: int = (self.position + 1) % self.capacity
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
lerobot_dataset: LeRobotDataset,
|
||||
device: str = "cuda:0",
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> "ReplayBuffer":
|
||||
"""
|
||||
Convert a LeRobotDataset into a ReplayBuffer.
|
||||
|
||||
Args:
|
||||
lerobot_dataset (LeRobotDataset): The dataset to convert.
|
||||
            device (str): The device where the tensors will be moved. Defaults to "cuda:0".
|
||||
state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
ReplayBuffer: The replay buffer with offline dataset transitions.
|
||||
"""
|
||||
# We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from
|
||||
# a replay buffer than from a lerobot dataset.
|
||||
replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Fill the replay buffer with the lerobot dataset transitions
|
||||
for data in list_transition:
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=data["action"],
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
)
|
||||
return replay_buffer
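An illustrative call, assuming the dataset is reachable locally or on the Hub; the repo id and state keys are placeholders, and converting a large dataset frame by frame can take a while.

```python
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

offline_dataset = LeRobotDataset("lerobot/pusht")  # placeholder repo id
offline_buffer = ReplayBuffer.from_lerobot_dataset(
    offline_dataset,
    device="cpu",
    state_keys=["observation.state"],  # placeholder key list
)
offline_batch = offline_buffer.sample(batch_size=32)
```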
|
||||
|
||||
@staticmethod
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Optional[Sequence[str]] = None,
|
||||
) -> list[Transition]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
The dataset to convert. Each item in the dataset is expected to have
|
||||
at least the following keys:
|
||||
{
|
||||
"action": ...
|
||||
"next.reward": ...
|
||||
"next.done": ...
|
||||
"episode_index": ...
|
||||
}
|
||||
plus whatever your 'state_keys' specify.
|
||||
|
||||
state_keys (Optional[Sequence[str]]):
|
||||
The dataset keys to include in 'state' and 'next_state'. Their names
|
||||
will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
)
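A minimal in-memory usage sketch on CPU with a single low-dimensional state key (names and shapes are arbitrary); keys starting with "observation.image" would additionally go through the DrQ augmentation when sampled.

```python
import torch

buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["observation.state"])
for _ in range(10):
    buffer.add(
        state={"observation.state": torch.randn(1, 4)},
        action=torch.randn(1, 2),
        reward=1.0,
        next_state={"observation.state": torch.randn(1, 4)},
        done=False,
    )
batch = buffer.sample(batch_size=8)
print(batch["action"].shape)  # torch.Size([8, 2])
```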
|
||||
|
||||
|
||||
def concatenate_batch_transitions(
|
||||
left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
|
||||
) -> BatchTransition:
|
||||
"""NOTE: Be careful it change the left_batch_transitions in place"""
|
||||
left_batch_transitions["state"] = {
|
||||
key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
|
||||
for key in left_batch_transitions["state"]
|
||||
}
|
||||
left_batch_transitions["action"] = torch.cat(
|
||||
[left_batch_transitions["action"], right_batch_transition["action"]], dim=0
|
||||
)
|
||||
left_batch_transitions["reward"] = torch.cat(
|
||||
[left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
|
||||
)
|
||||
left_batch_transitions["next_state"] = {
|
||||
key: torch.cat(
|
||||
[left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
|
||||
)
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
)
|
||||
return left_batch_transitions
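Reusing the toy buffer from the sketch above, this shows how an online and an offline batch with the same keys are merged; note again that the first argument is mutated in place.

```python
online_batch = buffer.sample(batch_size=4)
offline_batch = buffer.sample(batch_size=4)  # stands in for an offline buffer sample
merged = concatenate_batch_transitions(online_batch, offline_batch)
print(merged["reward"].shape)  # torch.Size([8])
```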
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
# Create an env dedicated to online episodes collection from policy rollout.
|
||||
# online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
|
||||
    # NOTE: Off-policy algorithms are sample-efficient enough to use a single environment
|
||||
logging.info("make_env online")
|
||||
# online_env = make_env(cfg, n_envs=1)
|
||||
    # TODO: Remove the import of maniskill and unify with make_env
|
||||
online_env = make_maniskill_env(cfg, n_envs=1)
|
||||
if cfg.training.eval_freq > 0:
|
||||
logging.info("make_env eval")
|
||||
# eval_env = make_env(cfg, n_envs=1)
|
||||
        # TODO: Remove the import of maniskill and unify with make_env
|
||||
eval_env = make_maniskill_env(cfg, n_envs=1)
|
||||
|
||||
# TODO: Add a way to resume training
|
||||
|
||||
# log metrics to terminal and wandb
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
# Check device is available
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
    # TODO: At some point we should just need a make_sac_policy helper
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
        # Hack: if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
device=device,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
|
||||
# TODO: Handle resume
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
obs, info = online_env.reset()
|
||||
|
||||
# HACK for maniskill
|
||||
# obs = preprocess_observation(obs)
|
||||
obs = preprocess_maniskill_observation(obs)
|
||||
obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
|
||||
batch_size = cfg.training.batch_size
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
|
||||
)
|
||||
        batch_size: int = batch_size // 2  # We will sample from both replay buffers
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
# NOTE: At some point we should use a wrapper to handle the observation
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
action = policy.select_action(batch=obs)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
|
||||
else:
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
# HACK
|
||||
action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)
|
||||
|
||||
# HACK: For maniskill
|
||||
# next_obs = preprocess_observation(next_obs)
|
||||
next_obs = preprocess_maniskill_observation(next_obs)
|
||||
        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in next_obs}
|
||||
sum_reward_episode += float(reward[0])
|
||||
        # Because we are using a single environment,
        # we can index the first element to check whether the episode has ended.
|
||||
if done[0] or truncated[0]:
|
||||
logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
|
||||
sum_reward_episode = 0
|
||||
# HACK: This is for maniskill
|
||||
logging.info(
|
||||
f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
|
||||
)
|
||||
logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)
|
||||
|
||||
replay_buffer.add(
|
||||
state=obs,
|
||||
action=action,
|
||||
reward=float(reward[0]),
|
||||
next_state=next_obs,
|
||||
done=done[0],
|
||||
)
|
||||
obs = next_obs
|
||||
|
||||
if interaction_step < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
if interaction_step % cfg.training.policy_update_freq == 0:
|
||||
# TD3 Trick
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
loss_actor = policy.compute_loss_actor(observations=observations)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
if interaction_step % cfg.training.log_freq == 0:
|
||||
logger.log_dict(training_infos, interaction_step, mode="train")
|
||||
|
||||
policy.update_target_networks()
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
|
||||
from hydra import compose, initialize
|
||||
|
||||
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
||||
initialize(config_path=config_path)
|
||||
cfg = compose(config_name=config_name)
|
||||
train(cfg, out_dir=out_dir, job_name=job_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
@@ -207,17 +207,11 @@ def main():
|
||||
required=True,
|
||||
help="Episode to visualize.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for the dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
@@ -274,11 +268,11 @@ def main():
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
local_files_only = kwargs.pop("local_files_only")
|
||||
# root = kwargs.pop("root")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
@@ -53,82 +53,51 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from io import StringIO
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
from flask import Flask, redirect, render_template, request, url_for
|
||||
import torch
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import IterableNamespace
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config, init_logging
|
||||
from lerobot.scripts.eval import get_pretrained_policy_path
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, episode_index):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def run_server(
|
||||
dataset: LeRobotDataset | IterableNamespace | None,
|
||||
episodes: list[int] | None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int],
|
||||
host: str,
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
template_folder: Path,
|
||||
has_policy=False,
|
||||
):
|
||||
app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
    def homepage(dataset=dataset):
|
||||
if dataset:
|
||||
dataset_namespace, dataset_name = dataset.repo_id.split("/")
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=0,
|
||||
)
|
||||
)
|
||||
|
||||
dataset_param, episode_param = None, None
|
||||
all_params = request.args
|
||||
if "dataset" in all_params:
|
||||
dataset_param = all_params["dataset"]
|
||||
if "episode" in all_params:
|
||||
episode_param = int(all_params["episode"])
|
||||
|
||||
if dataset_param:
|
||||
dataset_namespace, dataset_name = dataset_param.split("/")
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=episode_param if episode_param is not None else 0,
|
||||
)
|
||||
)
|
||||
|
||||
featured_datasets = [
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/columbia_cairlab_pusht_real",
|
||||
"lerobot/taco_play",
|
||||
]
|
||||
return render_template(
|
||||
"visualize_dataset_homepage.html",
|
||||
featured_datasets=featured_datasets,
|
||||
lerobot_datasets=available_datasets,
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
|
||||
def show_first_episode(dataset_namespace, dataset_name):
|
||||
first_episode_id = 0
|
||||
def index():
|
||||
# home page redirects to the first episode page
|
||||
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
|
||||
first_episode_id = episodes[0]
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
@@ -139,158 +108,166 @@ def run_server(
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
|
||||
repo_id = f"{dataset_namespace}/{dataset_name}"
|
||||
try:
|
||||
if dataset is None:
|
||||
dataset = get_dataset_info(repo_id)
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
major_version = int(match.group(1))
|
||||
if major_version < 2:
|
||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||
|
||||
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||
"num_samples": dataset.num_frames
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_frames,
|
||||
"num_episodes": dataset.num_episodes
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_episodes,
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_frames,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
video_paths = [
|
||||
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
|
||||
]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
else:
|
||||
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
|
||||
videos_info = [
|
||||
{
|
||||
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
+ dataset.video_path.format(
|
||||
episode_chunk=int(episode_id) // dataset.chunks_size,
|
||||
video_key=video_key,
|
||||
episode_index=episode_id,
|
||||
),
|
||||
"filename": video_key,
|
||||
}
|
||||
for video_key in video_keys
|
||||
]
|
||||
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
|
||||
|
||||
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
|
||||
tasks = filtered_tasks_jsonl[0]["tasks"]
|
||||
|
||||
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(
|
||||
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
|
||||
)
|
||||
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
"visualize_dataset_template.html",
|
||||
episode_id=episode_id,
|
||||
episodes=episodes,
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
episode_data_csv_str=episode_data_csv_str,
|
||||
columns=columns,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=has_policy,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
|
||||
|
||||
def run_inference(
|
||||
dataset, episode_index, policy, policy_method="select_action", num_workers=4, batch_size=32, device="mps"
|
||||
):
|
||||
if policy_method not in ["select_action", "forward"]:
|
||||
raise ValueError(
|
||||
f"`policy_method` is expected to be 'select_action' or 'forward', but '{policy_method}' is provided instead."
|
||||
)
|
||||
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
# When using `select_action`, we set batch size 1 so that we feed 1 frame at a time, in a continuous fashion.
|
||||
batch_size=1 if policy_method == "select_action" else batch_size,
|
||||
sampler=episode_sampler,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
warned_ndim_eq_0 = False
|
||||
warned_ndim_gt_2 = False
|
||||
|
||||
logging.info("Running inference")
|
||||
inference_results = {}
|
||||
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
|
||||
batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
|
||||
with torch.inference_mode():
|
||||
if policy_method == "select_action":
|
||||
gt_action = batch.pop("action")
|
||||
output_dict = {"action": policy.select_action(batch)}
|
||||
batch["action"] = gt_action
|
||||
elif policy_method == "forward":
|
||||
output_dict = policy.forward(batch)
|
||||
# TODO(rcadene): Save and display all predicted actions at a given timestamp
|
||||
# Save predicted action for the next timestamp only
|
||||
output_dict["action"] = output_dict["action"][:, 0, :]
|
||||
|
||||
for key in output_dict:
|
||||
if output_dict[key].ndim == 0:
|
||||
if not warned_ndim_eq_0:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a scalar instead of a vector. It might have been aggregated over the batch dimension (e.g. `loss.mean()`).",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_eq_0 = True
|
||||
continue
|
||||
|
||||
if output_dict[key].ndim > 2:
|
||||
if not warned_ndim_gt_2:
|
||||
warnings.warn(
|
||||
f"Ignore output key '{key}'. Its value is a tensor of {output_dict[key].ndim} dimensions instead of a vector.",
|
||||
stacklevel=1,
|
||||
)
|
||||
warned_ndim_gt_2 = True
|
||||
continue
|
||||
|
||||
if key not in inference_results:
|
||||
inference_results[key] = []
|
||||
inference_results[key].append(output_dict[key].to("cpu"))
|
||||
|
||||
for key in inference_results:
|
||||
inference_results[key] = torch.cat(inference_results[key])
|
||||
|
||||
return inference_results
|
||||
|
||||
|
||||
def get_ep_csv_fname(episode_id: int):
|
||||
ep_csv_fname = f"episode_{episode_id}.csv"
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
|
||||
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset, policy=None):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
|
||||
selected_columns.remove("timestamp")
|
||||
if policy is not None:
|
||||
inference_results = run_inference(
|
||||
dataset,
|
||||
episode_index,
|
||||
policy,
|
||||
policy_method="select_action",
|
||||
# num_workers=hydra_cfg.training.num_workers,
|
||||
# batch_size=hydra_cfg.training.batch_size,
|
||||
# device=hydra_cfg.device,
|
||||
)
|
||||
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.features
|
||||
has_action = "action" in dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = dataset.meta.shapes["observation.state"][0]
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = dataset.meta.shapes["action"][0]
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
if policy is not None:
|
||||
dim_action = dataset.meta.shapes["action"][0]
|
||||
header += [f"pred_action_{i}" for i in range(dim_action)]
|
||||
|
||||
for column_name in selected_columns:
|
||||
dim_state = (
|
||||
dataset.meta.shapes[column_name][0]
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.features[column_name].shape[0]
|
||||
)
|
||||
header += [f"{column_name}_{i}" for i in range(dim_state)]
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
columns += ["observation.state"]
|
||||
if has_action:
|
||||
columns += ["action"]
|
||||
data = dataset.hf_dataset.select_columns(columns)
|
||||
|
||||
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
|
||||
column_names = dataset.features[column_name]["names"]
|
||||
while not isinstance(column_names, list):
|
||||
column_names = list(column_names.values())[0]
|
||||
else:
|
||||
column_names = [f"motor_{i}" for i in range(dim_state)]
|
||||
columns.append({"key": column_name, "value": column_names})
|
||||
rows = []
|
||||
for i in range(from_idx, to_idx):
|
||||
row = [data[i]["timestamp"].item()]
|
||||
if has_state:
|
||||
row += data[i]["observation.state"].tolist()
|
||||
if has_action:
|
||||
row += data[i]["action"].tolist()
|
||||
if policy is not None:
|
||||
row += inference_results["action"][i].tolist()
|
||||
rows.append(row)
|
||||
|
||||
selected_columns.insert(0, "timestamp")
|
||||
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
data = (
|
||||
dataset.hf_dataset.select(range(from_idx, to_idx))
|
||||
.select_columns(selected_columns)
|
||||
.with_format("pandas")
|
||||
)
|
||||
else:
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
|
||||
)
|
||||
df = pd.read_parquet(url)
|
||||
data = df[selected_columns] # Select specific columns
|
||||
|
||||
rows = np.hstack(
|
||||
(
|
||||
np.expand_dims(data["timestamp"], axis=1),
|
||||
*[np.vstack(data[col]) for col in selected_columns[1:]],
|
||||
)
|
||||
).tolist()
|
||||
|
||||
# Convert data to CSV string
|
||||
csv_buffer = StringIO()
|
||||
csv_writer = csv.writer(csv_buffer)
|
||||
# Write header
|
||||
csv_writer.writerow(header)
|
||||
# Write data rows
|
||||
csv_writer.writerows(rows)
|
||||
csv_string = csv_buffer.getvalue()
|
||||
|
||||
return csv_string, columns
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
for row in rows:
|
||||
row_str = [str(col) for col in row]
|
||||
f.write(",".join(row_str) + "\n")
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
@@ -302,44 +279,25 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.features:
|
||||
return None
|
||||
|
||||
# get first frame index
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
    # TODO (michel-aractingi): hack to get the sentence; some strings in OpenX are badly stored,
    # with the tf.Tensor repr appearing inside the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
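A worked example of this cleanup on a raw stored value (the input string below is only illustrative of the OpenX formatting issue mentioned above):

```python
raw = "tf.Tensor(b'pick up the cube', shape=(), dtype=string)"
clean = raw.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
print(clean)  # pick up the cube
```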
|
||||
|
||||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
dataset_info = response.json()
|
||||
dataset_info["repo_id"] = repo_id
|
||||
return IterableNamespace(dataset_info)
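An illustrative call; it needs network access and a dataset whose meta/info.json exists on the Hub, and the repo id and fields printed are assumptions about a typical v2 dataset.

```python
info = get_dataset_info("lerobot/pusht")  # placeholder repo id
print(info.total_episodes, info.fps)
```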
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
dataset: LeRobotDataset | None,
|
||||
episodes: list[int] | None = None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 9090,
|
||||
force_override: bool = False,
|
||||
policy_method: str = "select_action",
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: list[str] | None = None,
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
if len(dataset.meta.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
# Create a temporary directory that will be automatically cleaned up
|
||||
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
|
||||
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
@@ -350,29 +308,44 @@ def visualize_dataset_html(
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
    # Create a symlink from the dataset video folder containing mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
static_dir = output_dir / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
|
||||
if dataset is None:
|
||||
if serve:
|
||||
run_server(
|
||||
dataset=None,
|
||||
episodes=None,
|
||||
host=host,
|
||||
port=port,
|
||||
static_folder=static_dir,
|
||||
template_folder=template_dir,
|
||||
)
|
||||
else:
|
||||
        # Create a symlink from the dataset video folder containing mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
if episodes is None:
|
||||
episodes = list(range(dataset.num_episodes))
|
||||
|
||||
pretrained_policy_name_or_path = "aliberts/act_reachy_test_model"
|
||||
|
||||
policy = None
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
logging.info("Loading policy")
|
||||
pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
|
||||
|
||||
hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", overrides=["device=mps"])
|
||||
# dataset = make_dataset(hydra_cfg)
|
||||
policy = make_policy(hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
|
||||
|
||||
if policy_method == "select_action":
|
||||
# Do not load previous observations or future actions, to simulate that the observations come from
|
||||
# an environment.
|
||||
dataset.delta_timestamps = None
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset, policy=policy)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir, has_policy=policy is not None)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -381,27 +354,15 @@ def main():
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-files-only",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-from-hf-hub",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Load videos and parquet files from HF Hub rather than local system.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
@@ -443,19 +404,9 @@ def main():
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||
root = kwargs.pop("root")
|
||||
local_files_only = kwargs.pop("local_files_only")
|
||||
|
||||
dataset = None
|
||||
if repo_id:
|
||||
dataset = (
|
||||
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
if not load_from_hf_hub
|
||||
else get_dataset_info(repo_id)
|
||||
)
|
||||
|
||||
visualize_dataset_html(dataset, **vars(args))
|
||||
kwargs.pop("root")
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
visualize_dataset_html(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Interactive Video Background Page</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
inputValue: '',
|
||||
navigateToDataset() {
|
||||
const trimmedValue = this.inputValue.trim();
|
||||
if (trimmedValue) {
|
||||
window.location.href = `/${trimmedValue}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<div class="fixed inset-0 w-full h-full overflow-hidden">
|
||||
<video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
|
||||
<source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
|
||||
Your browser does not support HTML5 video.
|
||||
</video>
|
||||
</div>
|
||||
<div class="fixed inset-0 bg-black bg-opacity-80"></div>
|
||||
<div class="relative z-10 flex flex-col items-center justify-center h-screen">
|
||||
<div class="text-center mb-8">
|
||||
<h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>
|
||||
|
||||
<a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>
|
||||
|
||||
<p class="text-xl mb-4"></p>
|
||||
<div class="text-left inline-block">
|
||||
<h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
|
||||
<ul class="list-disc list-inside">
|
||||
{% for dataset in featured_datasets %}
|
||||
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex w-full max-w-lg px-4 mb-4">
|
||||
<input
|
||||
type="text"
|
||||
x-model="inputValue"
|
||||
@keyup.enter="navigateToDataset"
|
||||
placeholder="enter dataset id (ex: lerobot/droid_100)"
|
||||
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
<button
|
||||
@click="navigateToDataset"
|
||||
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
Go
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<details class="mt-4 max-w-full px-4">
|
||||
<summary>More example datasets</summary>
|
||||
<ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
|
||||
{% for dataset in lerobot_datasets %}
|
||||
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</details>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -31,16 +31,11 @@
|
||||
}">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
|
||||
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
|
||||
<img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
|
||||
</a>
|
||||
<a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
</a>
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
Number of samples/frames: {{ dataset_info.num_frames }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
@@ -98,35 +93,10 @@
|
||||
</div>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="max-w-32 relative text-sm mb-4 select-none"
|
||||
@click.outside="isVideosDropdownOpen = false">
|
||||
<div
|
||||
@click="isVideosDropdownOpen = !isVideosDropdownOpen"
|
||||
class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
|
||||
>
|
||||
<span class="truncate">filter videos</span>
|
||||
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
|
||||
</div>
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
|
||||
<div>
|
||||
<template x-for="option in videosKeys" :key="option">
|
||||
<div
|
||||
@click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
|
||||
class="p-2 cursor-pointer bg-slate-900"
|
||||
:class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
|
||||
x-text="option"
|
||||
></div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap gap-x-2 gap-y-6">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{% for video_info in videos_info %}
|
||||
<div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
|
||||
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<div x-show="!videoCodecError" class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
@@ -212,12 +182,12 @@
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
|
||||
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
|
||||
<th class="border border-slate-700">
|
||||
<div class="flex gap-x-2 justify-between px-2">
|
||||
<input type="checkbox" :checked="isColumnChecked(colIndex)"
|
||||
@change="toggleColumn(colIndex)">
|
||||
<p x-text="`${columns[colIndex].key}`"></p>
|
||||
<p x-text="`${columnNames[colIndex]}`"></p>
|
||||
</div>
|
||||
</th>
|
||||
</template>
|
||||
@@ -227,10 +197,10 @@
|
||||
<template x-for="(row, rowIndex) in rows">
|
||||
<tr class="odd:bg-gray-800 even:bg-gray-900">
|
||||
<td class="border border-slate-700">
|
||||
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
|
||||
<div class="flex gap-x-2 w-24 font-semibold px-1">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`${rowLabels[rowIndex]}`"></p>
|
||||
<p x-text="`Motor ${rowIndex}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
@@ -252,20 +222,17 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const parentOrigin = "https://huggingface.co";
|
||||
const searchParams = new URLSearchParams();
|
||||
searchParams.set("dataset", "{{ dataset_info.repo_id }}");
|
||||
searchParams.set("episode", "{{ episode_id }}");
|
||||
window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
|
||||
</script>
|
||||
|
||||
<script>
|
||||
function createAlpineData() {
|
||||
return {
|
||||
// state
|
||||
dygraph: null,
|
||||
currentFrameData: null,
|
||||
columnNames: ["state", "action", "pred action"],
|
||||
hasPolicy: {% if has_policy %}true{% else %}false{% endif %},
|
||||
nColumns: {% if has_policy %}3{% else %}2{% endif %},
|
||||
nStates: 0,
|
||||
nActions: 0,
|
||||
checked: [],
|
||||
dygraphTime: 0.0,
|
||||
dygraphIndex: 0,
|
||||
@@ -275,11 +242,6 @@
|
||||
nVideos: {{ videos_info | length }},
|
||||
nVideoReadyToPlay: 0,
|
||||
videoCodecError: false,
|
||||
isVideosDropdownOpen: false,
|
||||
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
|
||||
videosKeysSelected: [],
|
||||
columns: {{ columns | tojson }},
|
||||
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
@@ -289,19 +251,11 @@
|
||||
if(!canPlayVideos){
|
||||
this.videoCodecError = true;
|
||||
}
|
||||
this.videosKeysSelected = this.videosKeys.map(opt => opt)
|
||||
|
||||
// process CSV data
|
||||
const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
|
||||
// Create a Blob with the CSV data
|
||||
const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
|
||||
// Create a URL for the Blob
|
||||
const csvUrl = URL.createObjectURL(blob);
|
||||
|
||||
// process CSV data
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
|
||||
pixelsPerPoint: 0.01,
|
||||
legend: 'always',
|
||||
labelsDiv: document.getElementById('labels'),
|
||||
@@ -322,17 +276,31 @@
|
||||
this.colors = this.dygraph.getColors();
|
||||
this.checked = Array(this.colors.length).fill(true);
|
||||
|
||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
|
||||
this.nActions = seriesNames.length - this.nStates;
|
||||
if(this.hasPolicy){
|
||||
this.nActions = Math.floor(this.nActions / 2);
|
||||
}
|
||||
const colors = [];
|
||||
let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
for(const column of this.columns){
|
||||
const nValues = column.value.length;
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
|
||||
const color = `hsl(${hue}, 100%, ${lightness}%)`;
|
||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
// colors for "state" lines
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
|
||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
// colors for "action" lines
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
|
||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
if(this.hasPolicy){
|
||||
// colors for "action" lines
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
|
||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[2]}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
lightness += 35;
|
||||
}
|
||||
|
||||
this.dygraph.updateOptions({ colors });
|
||||
this.colors = colors;
|
||||
|
||||
@@ -359,18 +327,20 @@
|
||||
return [];
|
||||
}
|
||||
const rows = [];
|
||||
const nRows = Math.max(...this.columns.map(column => column.value.length));
|
||||
const nRows = Math.max(this.nStates, this.nActions);
|
||||
let rowIndex = 0;
|
||||
while(rowIndex < nRows){
|
||||
const row = [];
|
||||
// number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
|
||||
const nullCell = { isNull: true };
|
||||
const stateValueIdx = rowIndex;
|
||||
const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
|
||||
// row consists of [state value, action value]
|
||||
let idx = rowIndex;
|
||||
for(const column of this.columns){
|
||||
const nColumn = column.value.length;
|
||||
row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
|
||||
idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
|
||||
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
|
||||
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
|
||||
if(this.hasPolicy){
|
||||
const predActionValueIdx = stateValueIdx + this.nStates + this.nActions; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN, pred_action1, ..., pred_actionN]
|
||||
row.push(rowIndex < this.nActions ? this.currentFrameData[predActionValueIdx] : nullCell); // push "action value" to row
|
||||
}
|
||||
rowIndex += 1;
|
||||
rows.push(row);
|
||||
|
||||
poetry.lock: 2921 lines changed (generated file, diff suppressed because it is too large)
@@ -70,9 +70,8 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
|
||||
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
|
||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||
pyserial = {version = ">=3.5", optional = true}
|
||||
reachy2-sdk = {git = "https://github.com/pollen-robotics/reachy2-sdk", branch="450-opencv-dependency-version", optional = true}
|
||||
jsonlines = ">=4.0.0"
|
||||
transformers = {version = ">=4.47.0", optional = true}
|
||||
torchmetrics = {version = ">=1.6.0", optional = true}
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
@@ -88,7 +87,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
|
||||
feetech = ["feetech-servo-sdk", "pynput"]
|
||||
intelrealsense = ["pyrealsense2"]
|
||||
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
|
||||
hilserl = ["transformers", "torchmetrics"]
|
||||
reachy2 = ["reachy2-sdk"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 110
|
||||
|
||||
@@ -14,11 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import traceback

import pytest
import torch
from serial import SerialException

from lerobot import available_cameras, available_motors, available_robots
@@ -126,14 +124,3 @@ def patch_builtins_input(monkeypatch):
        print(text)

    monkeypatch.setattr("builtins.input", print_text)


def pytest_addoption(parser):
    parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")


@pytest.fixture(autouse=True)
def set_random_seed(request):
    seed = int(request.config.getoption("--seed"))
    random.seed(seed)  # Python random
    torch.manual_seed(seed)  # PyTorch
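The `--seed` option and autouse fixture in this conftest hunk are what pin the Python and PyTorch RNGs for a test run. A hedged usage sketch, while that option is defined (the test node id is only an illustration, taken from the control-robot tests further down):

```bash
# Run one test with a fixed RNG seed; the option defaults to 42 when omitted.
pytest -sx "tests/test_control_robot.py::test_record_and_replay_and_policy" --seed 1337
```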
@@ -1,38 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub import DatasetCard

from lerobot.common.datasets.utils import create_lerobot_dataset_card


def test_default_parameters():
    card = create_lerobot_dataset_card()
    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]


def test_with_tags():
    tags = ["tag1", "tag2"]
    card = create_lerobot_dataset_card(tags=tags)
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]
@@ -1,244 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig

BATCH_SIZE = 1000
LR = 0.1
EPOCH_NUM = 2

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")


def train_evaluate_multiclass_classifier():
    logging.info(
        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )
    multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
    multiclass_classifier = Classifier(multiclass_config)

    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

    multiclass_num_classes = 10
    epoch = 1

    criterion = CrossEntropyLoss()
    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)

    multiclass_classifier.train()

    logging.info("Start multiclass classifier training")

    # Training loop
    while epoch < EPOCH_NUM:  # loop over the dataset multiple times
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = multiclass_classifier(inputs)

            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:  # print every 10 mini-batches
                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")

        epoch += 1

    print("Multiclass classifier training finished")

    multiclass_classifier.eval()

    test_loss = 0.0
    test_labels = []
    test_predictions = []
    test_probs = []

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = multiclass_classifier(images)
            loss = criterion(outputs.logits, labels)
            test_loss += loss.item() * BATCH_SIZE

            _, predicted = torch.max(outputs.logits, 1)
            test_labels.extend(labels.cpu())
            test_predictions.extend(predicted.cpu())
            test_probs.extend(outputs.probabilities.cpu())

    test_loss = test_loss / len(testset)

    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")

    test_labels = torch.stack(test_labels)
    test_predictions = torch.stack(test_predictions)
    test_probs = torch.stack(test_probs)

    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")

    # Calculate metrics
    acc = accuracy(test_predictions, test_labels)
    prec = precision(test_predictions, test_labels)
    rec = recall(test_predictions, test_labels)
    f1_score = f1(test_predictions, test_labels)
    auroc_score = auroc(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")


def train_evaluate_binary_classifier():
    logging.info(
        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )

    target_binary_class = 3

    def one_vs_rest(dataset, target_class):
        new_targets = []
        for _, label in dataset:
            new_label = float(1.0) if label == target_class else float(0.0)
            new_targets.append(new_label)

        dataset.targets = new_targets  # Replace the original labels with the binary ones
        return dataset

    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    # Apply one-vs-rest labeling
    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)

    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    binary_epoch = 1

    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
    binary_classifier = Classifier(binary_config)

    class_counts = np.bincount(binary_train_dataset.targets)
    n = len(binary_train_dataset)
    w0 = n / (2.0 * class_counts[0])
    w1 = n / (2.0 * class_counts[1])

    binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)

    binary_classifier.train()

    logging.info("Start binary classifier training")

    # Training loop
    while binary_epoch < EPOCH_NUM:  # loop over the dataset multiple times
        for i, data in enumerate(binary_trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)

            # Zero the parameter gradients
            binary_optimizer.zero_grad()

            # Forward pass
            outputs = binary_classifier(inputs)
            loss = binary_criterion(outputs.logits, labels)
            loss.backward()
            binary_optimizer.step()

            if i % 10 == 0:  # print every 10 mini-batches
                print(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
        binary_epoch += 1

    logging.info("Binary classifier training finished")
    logging.info("Start binary classifier evaluation")

    binary_classifier.eval()

    test_loss = 0.0
    test_labels = []
    test_predictions = []
    test_probs = []

    with torch.no_grad():
        for data in binary_testloader:
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
            outputs = binary_classifier(images)
            loss = binary_criterion(outputs.logits, labels)
            test_loss += loss.item() * BATCH_SIZE

            test_labels.extend(labels.cpu())
            test_predictions.extend(outputs.logits.cpu())
            test_probs.extend(outputs.probabilities.cpu())

    test_loss = test_loss / len(binary_test_dataset)

    logging.info(f"Binary classifier test loss {test_loss:.3f}")

    test_labels = torch.stack(test_labels)
    test_predictions = torch.stack(test_predictions)
    test_probs = torch.stack(test_probs)

    # Calculate metrics
    acc = Accuracy(task="binary")(test_predictions, test_labels)
    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")


if __name__ == "__main__":
    train_evaluate_multiclass_classifier()
    train_evaluate_binary_classifier()
@@ -1,85 +0,0 @@
import torch

from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
    ClassifierConfig,
    ClassifierOutput,
)
from tests.utils import require_package


def test_classifier_output():
    output = ClassifierOutput(
        logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
    )

    assert (
        f"{output}"
        == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
    )


@require_package("transformers")
def test_binary_classifier_with_default_params():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig()
    classifier = Classifier(config)

    batch_size = 10

    input = torch.rand(batch_size, 3, 224, 224)
    output = classifier(input)

    assert output is not None
    assert output.logits.shape == torch.Size([batch_size])
    assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
    assert output.probabilities.shape == torch.Size([batch_size])
    assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
    assert output.hidden_states.shape == torch.Size([batch_size, 2048])
    assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"


@require_package("transformers")
def test_multiclass_classifier():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    num_classes = 5
    config = ClassifierConfig(num_classes=num_classes)
    classifier = Classifier(config)

    batch_size = 10

    input = torch.rand(batch_size, 3, 224, 224)
    output = classifier(input)

    assert output is not None
    assert output.logits.shape == torch.Size([batch_size, num_classes])
    assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
    assert output.probabilities.shape == torch.Size([batch_size, num_classes])
    assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
    assert output.hidden_states.shape == torch.Size([batch_size, 2048])
    assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"


@require_package("transformers")
def test_default_device():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig()
    assert config.device == "cpu"

    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("cpu")


@require_package("transformers")
def test_explicit_device_setup():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig(device="meta")
    assert config.device == "meta"

    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("meta")
@@ -158,7 +158,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    assert dataset.meta.total_episodes == 2
    assert len(dataset) == 2

    replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False, local_files_only=True)
    replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)

    # TODO(rcadene, aliberts): rethink this design
    if robot_type == "aloha":
@@ -295,12 +295,24 @@ def test_resume_record(tmpdir, request, robot_type, mock):
    dataset = record(**record_kwargs)
    assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"

    # init_dataset_return_value = {}

    # def wrapped_init_dataset(*args, **kwargs):
    #     nonlocal init_dataset_return_value
    #     init_dataset_return_value = init_dataset(*args, **kwargs)
    #     return init_dataset_return_value

    # with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):

    with pytest.raises(FileExistsError):
        # Dataset already exists, but resume=False by default
        record(**record_kwargs)

    dataset = record(**record_kwargs, resume=True)
    assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
    # assert (
    #     init_dataset_return_value["num_episodes"] == 2
    # ), "`init_dataset` should load the previous episode"


@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@@ -383,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
   include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
   add the policies you want to update the test artifacts for.
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
   should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
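The numbered steps above compress into a short shell session. A hedged sketch follows; it assumes the policies were already added to the `__main__` block in step 2, uses the `DATA_DIR=tests/data` form from the updated instructions, and guesses that the backward-compatibility test lives in `tests/test_policies.py` (only the function name is shown in the hunk header):

```bash
# Step 3: regenerate the safetensors artifacts for the policies listed in __main__.
DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py

# Step 4: confirm the backward-compatibility test passes against the new artifacts
# (the test file path is an assumption; adjust to wherever test_backward_compatibility lives).
DATA_DIR=tests/data pytest -sx tests/test_policies.py::test_backward_compatibility

# Step 5: drop the local edits to the artifact-generation script.
git checkout -- tests/scripts/save_policy_to_safetensors.py
```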
@@ -5,7 +5,7 @@ we skip them for now in our CI.

Example to run backward compatibility tests locally:
```
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
@@ -330,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")
@@ -1,304 +0,0 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import torch
from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset

from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
    create_balanced_sampler,
    train,
    train_epoch,
    validate,
)


class MockDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.meta = MagicMock()
        self.meta.stats = {}

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


def make_dummy_model():
    model_config = ClassifierConfig(
        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
    )
    model = Classifier(config=model_config)
    return model


def test_create_balanced_sampler():
    # Mock dataset with imbalanced classes
    data = [
        {"label": 0},
        {"label": 0},
        {"label": 1},
        {"label": 0},
        {"label": 1},
        {"label": 1},
        {"label": 1},
        {"label": 1},
    ]
    dataset = MockDataset(data)
    cfg = MagicMock()
    cfg.training.label_key = "label"

    sampler = create_balanced_sampler(dataset, cfg)

    # Get weights from the sampler
    weights = sampler.weights.float()

    # Check that samples have appropriate weights
    labels = [item["label"] for item in data]
    class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
    class_weights = 1.0 / class_counts
    expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)

    # Test that the weights are correct
    assert torch.allclose(weights, expected_weights)


def test_train_epoch():
    model = make_dummy_model()
    # Mock components
    model.train = MagicMock()

    train_loader = [
        {
            "image": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]

    criterion = nn.BCEWithLogitsLoss()
    optimizer = MagicMock()
    grad_scaler = MagicMock()
    device = torch.device("cpu")
    logger = MagicMock()
    step = 0
    cfg = MagicMock()
    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call the function under test
    train_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        grad_scaler,
        device,
        logger,
        step,
        cfg,
    )

    # Check that model.train() was called
    model.train.assert_called_once()

    # Check that optimizer.zero_grad() was called
    optimizer.zero_grad.assert_called()

    # Check that logger.log_dict was called
    logger.log_dict.assert_called()


def test_validate():
    model = make_dummy_model()

    # Mock components
    model.eval = MagicMock()
    val_loader = [
        {
            "image": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]
    criterion = nn.BCEWithLogitsLoss()
    device = torch.device("cpu")
    logger = MagicMock()
    cfg = MagicMock()
    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call validate
    accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)

    # Check that model.eval() was called
    model.eval.assert_called_once()

    # Check accuracy/eval_info are calculated and of the correct type
    assert isinstance(accuracy, float)
    assert isinstance(eval_info, dict)


def test_train_epoch_multiple_cameras():
    model_config = ClassifierConfig(
        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
    )
    model = Classifier(config=model_config)

    # Mock components
    model.train = MagicMock()

    train_loader = [
        {
            "image_1": torch.rand(2, 3, 224, 224),
            "image_2": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]

    criterion = nn.BCEWithLogitsLoss()
    optimizer = MagicMock()
    grad_scaler = MagicMock()
    device = torch.device("cpu")
    logger = MagicMock()
    step = 0
    cfg = MagicMock()
    cfg.training.image_keys = ["image_1", "image_2"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call the function under test
    train_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        grad_scaler,
        device,
        logger,
        step,
        cfg,
    )

    # Check that model.train() was called
    model.train.assert_called_once()

    # Check that optimizer.zero_grad() was called
    optimizer.zero_grad.assert_called()

    # Check that logger.log_dict was called
    logger.log_dict.assert_called()


@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function(
    mock_get_model,
    mock_dataset,
    mock_logger,
    mock_get_last_pretrained_model_dir,
    mock_get_last_checkpoint_dir,
    mock_init_hydra_config,
    resume,
):
    # Initialize Hydra
    test_file_dir = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
    assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"

    with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
        cfg = compose(
            config_name="hilserl_classifier",
            overrides=[
                "device=cpu",
                "seed=42",
                f"output_dir={tempfile.mkdtemp()}",
                "wandb.enable=False",
                f"resume={resume}",
                "dataset_repo_id=dataset_repo_id",
                "train_split_proportion=0.8",
                "training.num_workers=0",
                "training.batch_size=2",
                "training.image_keys=[image]",
                "training.label_key=label",
                "training.use_amp=False",
                "training.num_epochs=1",
                "eval.batch_size=2",
            ],
        )

    # Mock the init_hydra_config function to return cfg
    mock_init_hydra_config.return_value = cfg

    # Mock dataset
    dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
    mock_dataset.return_value = dataset

    # Mock checkpoint handling
    mock_checkpoint_dir = MagicMock(spec=Path)
    mock_checkpoint_dir.exists.return_value = resume  # Only exists if resuming
    mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
    mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())

    # Mock logger
    logger = MagicMock()
    resumed_step = 1000
    if resume:
        logger.load_last_training_state.return_value = resumed_step
    else:
        logger.load_last_training_state.return_value = 0
    mock_logger.return_value = logger

    # Instantiate the model and set make_policy to return it
    model = make_dummy_model()
    mock_get_model.return_value = model

    # Call train
    train(cfg)

    # Check that checkpoint handling methods were called
    if resume:
        mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
        mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
        mock_checkpoint_dir.exists.assert_called_once()
        logger.load_last_training_state.assert_called_once()
    else:
        mock_get_last_checkpoint_dir.assert_not_called()
        mock_get_last_pretrained_model_dir.assert_not_called()
        mock_checkpoint_dir.exists.assert_not_called()
        logger.load_last_training_state.assert_not_called()

    # Collect the steps from logger.log_dict calls
    train_log_calls = logger.log_dict.call_args_list

    # Extract the steps used in the train logging
    steps = []
    for call in train_log_calls:
        mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
        if mode == "train":
            step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
            steps.append(step)

    expected_start_step = resumed_step if resume else 0

    # Calculate expected_steps
    train_size = int(cfg.train_split_proportion * len(dataset))
    batch_size = cfg.training.batch_size
    num_batches = (train_size + batch_size - 1) // batch_size

    expected_steps = [expected_start_step + i for i in range(num_batches)]

    assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
@@ -1,7 +1,6 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,15 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html


class HILSerlPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "hilserl"],
):
    pass


def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
    root = tmp_path / "dataset"
    output_dir = tmp_path / "outputs"
    dataset = lerobot_dataset_factory(root=root)
    visualize_dataset_html(
        dataset,
        episodes=[0],
        output_dir=output_dir,
        serve=False,
    )
    assert (output_dir / "static" / "episode_0.csv").exists()