forked from tangger/lerobot

# Compare commits

85 commits in `test/add_c...user/miche`
| SHA1 |
|---|
| a9e912a05c |
| ff47c0b0d3 |
| befa1fe9af |
| 446f434a8e |
| 2f3370e42f |
| 7ae368e983 |
| 36711d766a |
| c9e50bb9b1 |
| 95de8e273d |
| b07d95f0dd |
| d9a70376d8 |
| 0c32008466 |
| c462a478c7 |
| 459f22ed30 |
| dc086dc21f |
| b9217b06db |
| 6868c88ef1 |
| a1d16fb400 |
| a7db3959f5 |
| b5f89439ff |
| d51374ce12 |
| b63738674c |
| 12525242ce |
| 7d5a9530f7 |
| e0527b4a6b |
| efb1982eec |
| 2211209be5 |
| 506821c7df |
| f1c8bfe01e |
| 7c89bd1018 |
| 367dfe51c6 |
| e856ffc91e |
| 9aabe212ea |
| 42618f4bd6 |
| 36576c958f |
| 322a78a378 |
| d75b44f89f |
| 1fb03d4cf2 |
| 7d2970fdfe |
| 8105efb338 |
| c1d4bf4b63 |
| 86df8a433d |
| 956c547254 |
| be965019bd |
| a0a50de8c9 |
| c86dace4c2 |
| 472a7f58ad |
| 068efce3f8 |
| df7310ea40 |
| 100f54ee07 |
| c2f7af3339 |
| a1b5d0faf2 |
| d6498150bf |
| 31c34a4a49 |
| b1cfb6a710 |
| 4a43c83522 |
| 0a4e9e25d0 |
| 3bb5ed5e91 |
| c5bca1cf0f |
| 35de91ef2b |
| ee306e2f9b |
| bae3b02928 |
| 5b4adc00bb |
| 22fbc9ea4a |
| ca74a13d61 |
| 18a4598986 |
| dc54d357ca |
| 08ec971086 |
| b53d6e0ff2 |
| 70b652f791 |
| 7b68bfb73b |
| 7e0f20fbf2 |
| def42ff487 |
| c9af8e36a7 |
| ed66c92383 |
| 668d493bf9 |
| 67f4d7ea7a |
| 4b0c88ff8e |
| b19fef9d18 |
| 1612e00e63 |
| c3bc136420 |
| 1020bc3108 |
| 7fcf638c0d |
| e35546f58e |
| 1aa8d4ac91 |
**.github/workflows/quality.yml** (4 changes, vendored)
```diff
@@ -50,7 +50,7 @@ jobs:
         uses: actions/checkout@v3

       - name: Install poetry
-        run: pipx install poetry
+        run: pipx install "poetry<2.0.0"

       - name: Poetry check
         run: poetry check
@@ -64,7 +64,7 @@ jobs:
         uses: actions/checkout@v3

       - name: Install poetry
-        run: pipx install poetry
+        run: pipx install "poetry<2.0.0"

       - name: Install poetry-relax
         run: poetry self add poetry-relax
```
```diff
@@ -68,7 +68,7 @@

 ### Acknowledgment

-- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
+- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
 - Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
 - Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
 - Thanks to Antonio Loquercio and Ashish Kumar for their early support.
```
```diff
@@ -21,7 +21,7 @@ How to decode videos?

 ## Variables
 **Image content & size**
-We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
+We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
 For these reasons, we run this benchmark on four representative datasets:
 - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
 - `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie

 Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.

-Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
+Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
 - `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),

 However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
 **Average Structural Similarity Index Measure (higher is better)**
 `avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.

-One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
-h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
+One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
+h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
 - `yuv420p` is more widely supported across various platforms, including web browsers.
 - `yuv444p` offers higher color fidelity but might not be supported as broadly.

@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
 - `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
 - `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).

-See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
+See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.

 Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
 - `torchaudio`
```
**checkport.py** (new file, 18 lines)
```python
import socket


def check_port(host, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
        print(f"Connection successful to {host}:{port}!")
    except Exception as e:
        print(f"Connection failed to {host}:{port}: {e}")
    finally:
        s.close()


if __name__ == "__main__":
    host = "127.0.0.1"  # or "localhost"
    port = 51350
    check_port(host, port)
```
````diff
@@ -1,25 +1,31 @@
-This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
+# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot

-## Source the parts
+## A. Source the parts

 Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with links to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer already.

 **Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.

-## Install LeRobot
+## B. Install LeRobot

 On your computer:

 1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
 ```bash
 mkdir -p ~/miniconda3
 # Linux:
 wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
 # Mac M-series:
 # curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
 # Mac Intel:
 # curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
 bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
 rm ~/miniconda3/miniconda.sh
 ~/miniconda3/bin/conda init bash
 ```

-2. Restart shell or `source ~/.bashrc`
+2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell

 3. Create and activate a fresh conda environment for lerobot
 ```bash
````
````diff
@@ -36,23 +42,30 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
 cd ~/lerobot && pip install -e ".[feetech]"
 ```

-For Linux only (not Mac), install extra dependencies for recording datasets:
+*For Linux only (not Mac)*: install extra dependencies for recording datasets:
 ```bash
 conda install -y -c conda-forge ffmpeg
 pip uninstall -y opencv-python
 conda install -y -c conda-forge "opencv>=4.10.0"
 ```

-## Configure the motors
+## C. Configure the motors

-Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
-
-**Find USB ports associated to your arms**
-To find the correct ports for each arm, run the utility script twice:
+### 1. Find the USB ports associated to each arm
+
+Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
+
+#### a. Run the script to find ports
+
+Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
+
+To find the port for each bus servo adapter, run the utility script:
 ```bash
 python lerobot/scripts/find_motors_bus_port.py
 ```

+#### b. Example outputs
+
 Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
 ```
 Finding all available ports for the MotorBus.
````
````diff
@@ -64,7 +77,6 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
 The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
 Reconnect the usb cable.
 ```

 Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
 ```
 Finding all available ports for the MotorBus.
@@ -77,13 +89,20 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
 Reconnect the usb cable.
 ```

-Troubleshooting: On Linux, you might need to give access to the USB ports by running:
+#### c. Troubleshooting
+On Linux, you might need to give access to the USB ports by running:
 ```bash
 sudo chmod 666 /dev/ttyACM0
 sudo chmod 666 /dev/ttyACM1
 ```

-**Configure your motors**
+#### d. Update YAML file
+
+Now that you have the ports, modify the *port* sections in `so100.yaml`
+
+### 2. Configure the motors
+
+#### a. Set IDs for all 12 motors
 Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
 ```bash
 python lerobot/scripts/configure_motor.py \
````
````diff
@@ -94,7 +113,7 @@ python lerobot/scripts/configure_motor.py \
   --ID 1
 ```

-Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
+*Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
````
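To make the step-to-degree arithmetic in that note concrete, here is a minimal sketch (the 0..4096 range, 2048 midpoint, and ±180° come from the note above; the helper names are illustrative, not lerobot APIs):

```python
# Convert between motor steps (0..4096 = one full turn) and degrees,
# relative to the 2048 midpoint set during configuration.
def steps_to_degrees(steps: int) -> float:
    return (steps - 2048) * 180.0 / 2048.0

def degrees_to_steps(degrees: float) -> int:
    return round(degrees * 2048.0 / 180.0) + 2048

assert steps_to_degrees(2048) == 0.0    # configured midpoint
assert steps_to_degrees(4096) == 180.0  # full clockwise range
assert degrees_to_steps(-180.0) == 0    # full anticlockwise range
```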
````diff
 Then unplug your motor and plug the second motor and set its ID to 2.
 ```bash
@@ -108,23 +127,25 @@ python lerobot/scripts/configure_motor.py \

 Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.

-**Remove the gears of the 6 leader motors**
-Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
-
-**Add motor horn to the motors**
-Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
+#### b. Remove the gears of the 6 leader motors
+
+Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
+
+#### c. Add motor horn to all 12 motors
+Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
 Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensitive without the gears, but it's ok if it's a bit rotated.
````
````diff
-## Assemble the arms
+## D. Assemble the arms

-Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it in under 1 hour for the second arm.
+Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it in under 1 hour for the second arm.

-## Calibrate
+## E. Calibrate

 Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.

-**Manual calibration of follower arm**
-/!\ Contrary to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of the follower for now.
+#### a. Manual calibration of follower arm
+/!\ Contrary to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of the follower for now.

 You will need to move the follower arm to these positions sequentially:

@@ -139,8 +160,8 @@ python lerobot/scripts/control_robot.py calibrate \
   --robot-overrides '~cameras' --arms main_follower
 ```
````
````diff
-**Manual calibration of leader arm**
-Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
+#### b. Manual calibration of leader arm
+Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:

 | 1. Zero position | 2. Rotated position | 3. Rest position |
 |---|---|---|
@@ -153,7 +174,7 @@ python lerobot/scripts/control_robot.py calibrate \
   --robot-overrides '~cameras' --arms main_leader
 ```
````
````diff
-## Teleoperate
+## F. Teleoperate

 **Simple teleop**
 Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
@@ -165,14 +186,14 @@ python lerobot/scripts/control_robot.py teleoperate \
 ```

-**Teleop with displaying cameras**
+#### a. Teleop with displaying cameras
 Follow [this guide to set up your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
 ```bash
 python lerobot/scripts/control_robot.py teleoperate \
   --robot-path lerobot/configs/robot/so100.yaml
 ```
````
````diff
-## Record a dataset
+## G. Record a dataset

 Once you're familiar with teleoperation, you can record your first dataset with SO-100.
@@ -201,7 +222,7 @@ python lerobot/scripts/control_robot.py record \
   --push-to-hub 1
 ```
````
````diff
-## Visualize a dataset
+## H. Visualize a dataset

 If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id given by:
 ```bash
@@ -214,7 +235,7 @@ python lerobot/scripts/visualize_dataset_html.py \
   --repo-id ${HF_USER}/so100_test
 ```
````
````diff
-## Replay an episode
+## I. Replay an episode

 Now try to replay the first episode on your robot:
 ```bash
@@ -225,7 +246,7 @@ python lerobot/scripts/control_robot.py replay \
   --episode 0
 ```
````
````diff
-## Train a policy
+## J. Train a policy

 To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
 ```bash
@@ -248,7 +269,7 @@ Let's explain it:

 Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
````
````diff
-## Evaluate your policy
+## K. Evaluate your policy

 You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
 ```bash
@@ -268,7 +289,7 @@ As you can see, it's almost the same command as previously used to record your t
 1. There is an additional `-p` argument which indicates the path to your policy checkpoint (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
 2. The dataset name begins with `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
````
````diff
-## More
+## L. More Information

 Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
````
**examples/12_train_hilserl_classifier.md** (new file, 83 lines)
# Training a HIL-SERL Reward Classifier with LeRobot

This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.

---

## Training Script Overview

LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:

1. **Configuration Loading**
   The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)

2. **Dataset Initialization**
   It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling (see the sketch after this outline).

3. **Classifier Initialization**
   A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
   - A single probability (binary classification), or
   - Logits (multi-class classification).

4. **Training Loop Execution**
   The script performs:
   - Forward and backward passes,
   - Optimization steps,
   - Periodic logging, evaluation, and checkpoint saving.

---
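The class-balancing idea from step 2 can be sketched with PyTorch's `WeightedRandomSampler`; the labels below are made up for illustration and are not the actual lerobot sampling code:

```python
# Sketch: oversample the rare class so each batch is roughly balanced.
import torch
from torch.utils.data import WeightedRandomSampler

labels = torch.tensor([0, 0, 0, 0, 1])         # imbalanced sparse rewards
class_counts = torch.bincount(labels).float()  # tensor([4., 1.])
weights = 1.0 / class_counts[labels]           # rare class gets higher weight
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
# Pass `sampler=sampler` to a DataLoader to draw class-balanced batches.
```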
## Configuring with Hydra

For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.

### Config File Setup

The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:

Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:

```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
```

## Typical logs and metrics

When you start the training process, you will first see your full configuration printed in the terminal. Check it to make sure you configured it correctly and that your config is not overridden by other files. The final configuration will also be saved with the checkpoint.

After that, you will see a training log like this one:

```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```

or an evaluation log like:

```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```

### Metrics Tracking with Weights & Biases (WandB)

If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:

- **Training Metrics**:
  - `train/accuracy`
  - `train/loss`
  - `train/dataloading_s`
- **Evaluation Metrics**:
  - `eval/accuracy`
  - `eval/loss`
  - `eval/eval_s`

#### Additional Features

You can also log sample predictions during evaluation. Each logged sample will include:

- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.

These logs can be useful for diagnosing and debugging performance issues.
```diff
@@ -74,7 +74,23 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData

     image_transforms = None
     if cfg.training.image_transforms.enable:
-        cfg_tf = cfg.training.image_transforms
+        default_tf = OmegaConf.create(
+            {
+                "brightness": {"weight": 0.0, "min_max": None},
+                "contrast": {"weight": 0.0, "min_max": None},
+                "saturation": {"weight": 0.0, "min_max": None},
+                "hue": {"weight": 0.0, "min_max": None},
+                "sharpness": {"weight": 0.0, "min_max": None},
+                "max_num_transforms": None,
+                "random_order": False,
+                "image_size": None,
+                "interpolation": None,
+                "image_mean": None,
+                "image_std": None,
+            }
+        )
+        cfg_tf = OmegaConf.merge(OmegaConf.create(default_tf), cfg.training.image_transforms)

         image_transforms = get_image_transforms(
             brightness_weight=cfg_tf.brightness.weight,
             brightness_min_max=cfg_tf.brightness.min_max,
@@ -88,6 +104,10 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
             sharpness_min_max=cfg_tf.sharpness.min_max,
             max_num_transforms=cfg_tf.max_num_transforms,
             random_order=cfg_tf.random_order,
+            image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) if cfg_tf.image_size else None,
+            interpolation=cfg_tf.interpolation,
+            image_mean=cfg_tf.image_mean,
+            image_std=cfg_tf.image_std,
         )

     if isinstance(cfg.dataset_repo_id, str):
```
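The default-filling pattern above can be seen in isolation with a minimal sketch; the keys here are illustrative, not the actual lerobot config schema:

```python
# OmegaConf.merge overlays the user config on the defaults, so keys the user
# never set (e.g. image_size) fall back to their default values.
from omegaconf import OmegaConf

default_tf = OmegaConf.create({"brightness": {"weight": 0.0, "min_max": None}, "image_size": None})
user_tf = OmegaConf.create({"brightness": {"weight": 1.0, "min_max": [0.8, 1.2]}})

cfg_tf = OmegaConf.merge(default_tf, user_tf)
print(cfg_tf.brightness.weight)  # 1.0 (from user config)
print(cfg_tf.image_size)         # None (from defaults)
```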
```diff
@@ -84,7 +84,8 @@ class LeRobotDatasetMetadata:

         # Load metadata
         (self.root / "meta").mkdir(exist_ok=True, parents=True)
-        self.pull_from_repo(allow_patterns="meta/")
+        if not self.local_files_only:
+            self.pull_from_repo(allow_patterns="meta/")
         self.info = load_info(self.root)
         self.stats = load_stats(self.root)
         self.tasks = load_tasks(self.root)
@@ -291,7 +292,7 @@ class LeRobotDatasetMetadata:
         obj.root.mkdir(parents=True, exist_ok=False)

         if robot is not None:
-            features = get_features_from_robot(robot, use_videos)
+            features = {**(features or {}), **get_features_from_robot(robot)}
             robot_type = robot.robot_type
             if not all(cam.fps == fps for cam in robot.cameras.values()):
                 logging.warning(
@@ -537,7 +538,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
             ]
             files += video_files

-            self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
+            if not self.local_files_only:
+                self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def load_hf_dataset(self) -> datasets.Dataset:
         """hf_dataset contains all the observations, states, actions, rewards, etc."""
```
```diff
@@ -150,6 +150,10 @@ def get_image_transforms(
     sharpness_min_max: tuple[float, float] | None = None,
     max_num_transforms: int | None = None,
     random_order: bool = False,
+    interpolation: str | None = None,
+    image_size: tuple[int, int] | None = None,
+    image_mean: list[float] | None = None,
+    image_std: list[float] | None = None,
 ):
     def check_value(name, weight, min_max):
         if min_max is not None:
@@ -170,6 +174,18 @@ def get_image_transforms(

     weights = []
     transforms = []
+    if image_size is not None:
+        interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
+        if interpolation is None:
+            # Use BICUBIC as default interpolation
+            interpolation_mode = v2.InterpolationMode.BICUBIC
+        elif interpolation in interpolations:
+            interpolation_mode = v2.InterpolationMode(interpolation)
+        else:
+            raise ValueError("The interpolation passed is not supported")
+        # Weight for resizing is always 1
+        weights.append(1.0)
+        transforms.append(v2.Resize(size=(image_size[0], image_size[1]), interpolation=interpolation_mode))
     if brightness_min_max is not None and brightness_weight > 0.0:
         weights.append(brightness_weight)
         transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +201,15 @@ def get_image_transforms(
     if sharpness_min_max is not None and sharpness_weight > 0.0:
         weights.append(sharpness_weight)
         transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+    if image_mean is not None and image_std is not None:
+        # Weight for normalization is always 1
+        weights.append(1.0)
+        transforms.append(
+            v2.Normalize(
+                mean=image_mean,
+                std=image_std,
+            )
+        )

     n_subset = len(transforms)
     if max_num_transforms is not None:
```
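Standalone, the resize and normalize transforms added above behave like this minimal torchvision v2 sketch; the sizes and statistics are illustrative values, not lerobot defaults:

```python
import torch
from torchvision.transforms import v2

# Resize with explicit interpolation, then normalize per channel.
tf = v2.Compose([
    v2.Resize(size=(128, 128), interpolation=v2.InterpolationMode.BICUBIC),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = torch.rand(3, 240, 320)  # float image in [0, 1], channel first
out = tf(img)
print(out.shape)  # torch.Size([3, 128, 128])
```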
```diff
@@ -17,9 +17,11 @@ import importlib.resources
 import json
 import logging
-import textwrap
+from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
+from types import SimpleNamespace
+from typing import Any

 import datasets
@@ -273,6 +275,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
         hf_features[key] = datasets.Sequence(
             length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
         )
+    # TODO: (alibers, azouitine) Add support for ft["shape"] == 0 as Value

     return datasets.Features(hf_features)
```
```diff
@@ -477,7 +480,6 @@ def create_lerobot_dataset_card(
     Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
     """
     card_tags = ["LeRobot"]
-    card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")

     if tags:
         card_tags += tags
@@ -497,8 +499,65 @@ def create_lerobot_dataset_card(
         ],
     )

+    card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
+
     return DatasetCard.from_template(
         card_data=card_data,
-        template_path=str(card_template_path),
+        template_str=card_template,
         **kwargs,
     )
+
+
+class IterableNamespace(SimpleNamespace):
+    """
+    A namespace object that supports both dictionary-like iteration and dot notation access.
+    Automatically converts nested dictionaries into IterableNamespaces.
+
+    This class extends SimpleNamespace to provide:
+    - Dictionary-style iteration over keys
+    - Access to items via both dot notation (obj.key) and brackets (obj["key"])
+    - Dictionary-like methods: items(), keys(), values()
+    - Recursive conversion of nested dictionaries
+
+    Args:
+        dictionary: Optional dictionary to initialize the namespace
+        **kwargs: Additional keyword arguments passed to SimpleNamespace
+
+    Examples:
+        >>> data = {"name": "Alice", "details": {"age": 25}}
+        >>> ns = IterableNamespace(data)
+        >>> ns.name
+        'Alice'
+        >>> ns.details.age
+        25
+        >>> list(ns.keys())
+        ['name', 'details']
+        >>> for key, value in ns.items():
+        ...     print(f"{key}: {value}")
+        name: Alice
+        details: IterableNamespace(age=25)
+    """
+
+    def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
+        super().__init__(**kwargs)
+        if dictionary is not None:
+            for key, value in dictionary.items():
+                if isinstance(value, dict):
+                    setattr(self, key, IterableNamespace(value))
+                else:
+                    setattr(self, key, value)
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(vars(self))
+
+    def __getitem__(self, key: str) -> Any:
+        return vars(self)[key]
+
+    def items(self):
+        return vars(self).items()
+
+    def values(self):
+        return vars(self).values()
+
+    def keys(self):
+        return vars(self).keys()
```
```diff
@@ -159,11 +159,11 @@ DATASETS = {
         **ALOHA_STATIC_INFO,
     },
     "aloha_static_vinh_cup": {
-        "single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
+        "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
         **ALOHA_STATIC_INFO,
     },
     "aloha_static_vinh_cup_left": {
-        "single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
+        "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
         **ALOHA_STATIC_INFO,
     },
     "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
```
```diff
@@ -14,9 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import importlib
+from collections import deque

 import gymnasium as gym
+import numpy as np
+import torch
 from omegaconf import DictConfig
+# from mani_skill.utils import common


 def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
@@ -30,6 +34,10 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
     if cfg.env.name == "real_world":
         return

+    if "maniskill" in cfg.env.name:
+        env = make_maniskill_env(cfg, n_envs if n_envs is not None else cfg.eval.batch_size)
+        return env
+
     package_name = f"gym_{cfg.env.name}"

     try:
@@ -56,3 +64,88 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv
     )

     return env
+
+
+def make_maniskill_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None:
+    """Make ManiSkill3 gym environment"""
+    from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
+
+    env = gym.make(
+        cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size),
+        num_envs=n_envs,
+    )
+    # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode
+    env = ManiSkillVectorEnv(env, ignore_terminations=True)
+    # state should have the size of 25
+    # env = ConvertToLeRobotEnv(env, n_envs)
+    # env = PixelWrapper(cfg, env, n_envs)
+    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
+    env.unwrapped.metadata["render_fps"] = 20
+
+    return env
+
+
+class PixelWrapper(gym.Wrapper):
+    """
+    Wrapper for pixel observations. Works with Maniskill vectorized environments
+    """
+
+    def __init__(self, cfg, env, num_envs, num_frames=3):
+        super().__init__(env)
+        self.cfg = cfg
+        self.env = env
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
+            dtype=np.uint8,
+        )
+        self._frames = deque([], maxlen=num_frames)
+        self._render_size = cfg.env.render_size
+
+    def _get_obs(self, obs):
+        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
+        self._frames.append(frame)
+        return {"pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(self.env.device)}
+
+    def reset(self, seed):
+        obs, info = self.env.reset()  # (seed=seed)
+        for _ in range(self._frames.maxlen):
+            obs_frames = self._get_obs(obs)
+        return obs_frames, info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return self._get_obs(obs), reward, terminated, truncated, info
+
+
+class ConvertToLeRobotEnv(gym.Wrapper):
+    def __init__(self, env, num_envs):
+        super().__init__(env)
+
+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options={})
+        return self._get_obs(obs), info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return self._get_obs(obs), reward, terminated, truncated, info
+
+    def _get_obs(self, observation):
+        sensor_data = observation.pop("sensor_data")
+        del observation["sensor_param"]
+        images = []
+        for cam_data in sensor_data.values():
+            images.append(cam_data["rgb"])
+
+        images = torch.concat(images, axis=-1)
+        # flatten the rest of the data which should just be state data
+        observation = common.flatten_state_dict(observation, use_torch=True, device=self.base_env.device)
+        ret = dict()
+        ret["state"] = observation
+        ret["pixels"] = images
+        return ret
```
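The frame-stacking idea in `PixelWrapper` (keep the last N frames in a deque and concatenate them channel-wise) can be sketched standalone; shapes here are illustrative:

```python
from collections import deque
import numpy as np

num_frames = 3
frames = deque(maxlen=num_frames)

def stack(frame: np.ndarray) -> np.ndarray:
    # frame: (num_envs, 3, H, W); the deque keeps only the last `num_frames`.
    frames.append(frame)
    while len(frames) < num_frames:  # on reset, pad by repeating the first frame
        frames.append(frame)
    return np.concatenate(frames, axis=1)  # (num_envs, num_frames * 3, H, W)

obs = stack(np.zeros((4, 3, 64, 64), dtype=np.uint8))
print(obs.shape)  # (4, 9, 64, 64)
```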
```diff
@@ -28,28 +28,30 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     """
     # map to expected inputs for the policy
     return_observations = {}
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
-
-    for imgkey, img in imgs.items():
-        img = torch.from_numpy(img)
+    # TODO: You have to merge all tensors from agent key and extra key
+    # You don't keep sensor param key in the observation
+    # And you keep sensor data rgb
+    for key, img in observations.items():
+        if "images" not in key:
+            continue
+
+        if img.ndim == 3:
+            img = img.unsqueeze(0)

         # sanity check that images are channel last
         _, h, w, c = img.shape
         assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

         # sanity check that images are uint8
         assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

         # convert to channel first of type float32 in range [0,1]
         img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
         img = img.type(torch.float32)
         img /= 255

-        return_observations[imgkey] = img
+        return_observations[key] = img
+    # obs state agent qpos and qvel
+    # image

     if "environment_state" in observations:
         return_observations["observation.environment_state"] = torch.from_numpy(
@@ -58,5 +60,41 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten

     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = observations["observation.state"].float()
     return return_observations
+
+
+def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
+    """Convert environment observation to LeRobot format observation.
+    Args:
+        observation: Dictionary of observation batches from a Gym vector environment.
+    Returns:
+        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+    """
+    # map to expected inputs for the policy
+    return_observations = {}
+    # TODO: You have to merge all tensors from agent key and extra key
+    # You don't keep sensor param key in the observation
+    # And you keep sensor data rgb
+    q_pos = observations["agent"]["qpos"]
+    q_vel = observations["agent"]["qvel"]
+    tcp_pos = observations["extra"]["tcp_pose"]
+    img = observations["sensor_data"]["base_camera"]["rgb"]
+
+    _, h, w, c = img.shape
+    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+
+    # sanity check that images are uint8
+    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+    # convert to channel first of type float32 in range [0,1]
+    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+    img = img.type(torch.float32)
+    img /= 255
+
+    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
+
+    return_observations["observation.image"] = img
+    return_observations["observation.state"] = state
+    return return_observations
```
```diff
@@ -25,6 +25,7 @@ from glob import glob
 from pathlib import Path

 import torch
+import wandb
 from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
@@ -107,8 +108,6 @@ class Logger:
             self._wandb = None
         else:
             os.environ["WANDB_SILENT"] = "true"
-            import wandb
-
             wandb_run_id = None
             if cfg.resume:
                 wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
@@ -128,6 +127,8 @@ class Logger:
                 job_type="train_eval",
                 resume="must" if cfg.resume else None,
             )
+            # Handle custom step key for rl asynchronous training.
+            self._wandb_custom_step_key: set[str] | None = None
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
@@ -173,18 +174,32 @@ class Logger:
         self,
         save_dir: Path,
         train_step: int,
-        optimizer: Optimizer,
+        optimizer: Optimizer | dict,
         scheduler: LRScheduler | None,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the global training_step, optimizer state, scheduler state, and random state.

         All of these are saved as "training_state.pth" under the checkpoint directory.
         """
+        # In Sac, for example, we have a dictionary of torch.optim.Optimizer
+        if type(optimizer) is dict:
+            optimizer_state_dict = {}
+            for k in optimizer:
+                optimizer_state_dict[k] = optimizer[k].state_dict()
+        else:
+            optimizer_state_dict = optimizer.state_dict()
+
         training_state = {
             "step": train_step,
-            "optimizer": optimizer.state_dict(),
+            "optimizer": optimizer_state_dict,
             **get_global_random_state(),
         }
+        # Interaction step is related to the distributed training code
+        # In that setup, we have two kinds of steps, the online step of the env and the optimization step
+        # We need to save both in order to resume the optimization properly and not break the logs dependent on the interaction step
+        if interaction_step is not None:
+            training_state["interaction_step"] = interaction_step
         if scheduler is not None:
             training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
@@ -196,6 +211,7 @@ class Logger:
         optimizer: Optimizer,
         scheduler: LRScheduler | None,
         identifier: str,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the model weights and the training state."""
         checkpoint_dir = self.checkpoints_dir / str(identifier)
@@ -207,16 +223,24 @@ class Logger:
         self.save_model(
             checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
         )
-        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
+        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler, interaction_step)
         os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)

-    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+    def load_last_training_state(self, optimizer: Optimizer | dict, scheduler: LRScheduler | None) -> int:
         """
         Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and
         random state, and return the global training step.
         """
         training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name)
-        optimizer.load_state_dict(training_state["optimizer"])
+        # For the case where the optimizer is a dictionary of optimizers (e.g., sac)
+        if type(training_state["optimizer"]) is dict:
+            assert set(training_state["optimizer"].keys()) == set(optimizer.keys()), (
+                "Optimizer dictionaries do not have the same keys during resume!"
+            )
+            for k, v in training_state["optimizer"].items():
+                optimizer[k].load_state_dict(v)
+        else:
+            optimizer.load_state_dict(training_state["optimizer"])
         if scheduler is not None:
             scheduler.load_state_dict(training_state["scheduler"])
         elif "scheduler" in training_state:
```
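The dict-of-optimizers save/load handled above boils down to this roundtrip sketch; the "actor"/"critic" keys are hypothetical parameter groups, not the actual SAC layout:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(3)) for _ in range(2)]
optimizer = {
    "actor": torch.optim.Adam([params[0]], lr=3e-4),
    "critic": torch.optim.Adam([params[1]], lr=3e-4),
}

# Save: one state_dict per optimizer key.
state = {k: opt.state_dict() for k, opt in optimizer.items()}

# Load: keys must match, mirroring the assert in load_last_training_state.
assert set(state) == set(optimizer)
for k, v in state.items():
    optimizer[k].load_state_dict(v)
```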
```diff
@@ -227,17 +251,44 @@ class Logger:
         set_global_random_state({k: training_state[k] for k in get_global_random_state()})
         return training_state["step"]

-    def log_dict(self, d, step, mode="train"):
+    def log_dict(self, d, step: int | None = None, mode="train", custom_step_key: str | None = None):
+        """Log a dictionary of metrics to WandB."""
         assert mode in {"train", "eval"}
         # TODO(alexander-soare): Add local text log.
+        if step is None and custom_step_key is None:
+            raise ValueError("Either step or custom_step_key must be provided.")
+
         if self._wandb is not None:
+            # NOTE: This is not simple. The wandb step must always monotonically increase and it
+            # increases with each wandb.log call, but in the case of asynchronous RL, multiple
+            # kinds of steps are possible (for example the interaction step with the environment,
+            # the training step, the evaluation step, etc.). So we need to define a custom step key
+            # to log the correct step for each metric.
+            if custom_step_key is not None:
+                if self._wandb_custom_step_key is None:
+                    self._wandb_custom_step_key = set()
+                new_custom_key = f"{mode}/{custom_step_key}"
+                if new_custom_key not in self._wandb_custom_step_key:
+                    self._wandb_custom_step_key.add(new_custom_key)
+                    self._wandb.define_metric(new_custom_key, hidden=True)
+
             for k, v in d.items():
-                if not isinstance(v, (int, float, str)):
+                if not isinstance(v, (int, float, str, wandb.Table)):
                     logging.warning(
                         f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
                     )
                     continue
-                self._wandb.log({f"{mode}/{k}": v}, step=step)
+
+                # Do not log the custom step key itself.
+                if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
+                    continue
+
+                if custom_step_key is not None:
+                    value_custom_step = d[custom_step_key]
+                    self._wandb.log({f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step})
+                    continue
+
+                self._wandb.log(data={f"{mode}/{k}": v}, step=step)

     def log_video(self, video_path: str, step: int, mode: str = "train"):
         assert mode in {"train", "eval"}
```
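A minimal sketch of the custom-step pattern above, outside the Logger wrapper: wandb's implicit step must increase monotonically, so an asynchronous step (like the interaction step) is logged as a regular value next to each metric rather than as the global `step=` argument. The project name is hypothetical:

```python
import wandb

run = wandb.init(project="example", mode="offline")  # hypothetical project
wandb.define_metric("train/interaction_step", hidden=True)

# The custom step travels inside the logged dict, not as wandb's `step=`.
wandb.log({"train/reward": 1.0, "train/interaction_step": 120})
```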
```diff
@@ -66,12 +66,21 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
         from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTPolicy

         return VQBeTPolicy, VQBeTConfig
+    elif name == "sac":
+        from lerobot.common.policies.sac.configuration_sac import SACConfig
+        from lerobot.common.policies.sac.modeling_sac import SACPolicy
+
+        return SACPolicy, SACConfig
     else:
         raise NotImplementedError(f"Policy with name {name} is not implemented.")


 def make_policy(
-    hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
+    hydra_cfg: DictConfig,
+    pretrained_policy_name_or_path: str | None = None,
+    dataset_stats=None,
+    *args,
+    **kwargs,
 ) -> Policy:
     """Make an instance of a policy class.

@@ -85,17 +94,19 @@ def make_policy(
        be provided when initializing a new policy, and must not be provided when loading a pretrained
        policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
     """
-    if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
-        raise ValueError(
-            "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
-        )
+    # if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
+    #     raise ValueError(
+    #         "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
+    #     )

     policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)

     policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
     if pretrained_policy_name_or_path is None:
         # Make a fresh policy.
-        policy = policy_cls(policy_cfg, dataset_stats)
+        # HACK: We pass *args and **kwargs to the policy constructor to allow for additional arguments
+        # for example device for the sac policy.
+        policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
     else:
         # Load a pretrained policy and override the config if needed (for example, if there are inference-time
         # hyperparameters that we want to vary).
```
**configuration_classifier.py** (new file, 35 lines)

```python
import json
import os
from dataclasses import asdict, dataclass


@dataclass
class ClassifierConfig:
    """Configuration for the Classifier model."""

    num_classes: int = 2
    hidden_dim: int = 256
    dropout_rate: float = 0.1
    model_name: str = "helper2424/resnet10"
    device: str = "cpu"
    model_type: str = "cnn"  # "transformer" or "cnn"
    num_cameras: int = 2

    def save_pretrained(self, save_dir):
        """Save config to json file."""
        os.makedirs(save_dir, exist_ok=True)

        # Convert to dict and save as JSON
        config_dict = asdict(self)
        with open(os.path.join(save_dir, "config.json"), "w") as f:
            json.dump(config_dict, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path):
        """Load config from json file."""
        config_file = os.path.join(pretrained_model_name_or_path, "config.json")

        with open(config_file) as f:
            config_dict = json.load(f)

        return cls(**config_dict)
```
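A hedged usage sketch for the config above, a save/load roundtrip based only on the methods shown:

```python
import tempfile

config = ClassifierConfig(num_classes=2, num_cameras=2)
with tempfile.TemporaryDirectory() as save_dir:
    config.save_pretrained(save_dir)                   # writes config.json
    restored = ClassifierConfig.from_pretrained(save_dir)

assert restored == config  # dataclass equality compares all fields
```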
@@ -0,0 +1,154 @@
import logging
from typing import Optional

import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn

from .configuration_classifier import ClassifierConfig

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class ClassifierOutput:
    """Wrapper for classifier outputs with additional metadata."""

    def __init__(
        self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
    ):
        self.logits = logits
        self.probabilities = probabilities
        self.hidden_states = hidden_states

    def __repr__(self):
        return (
            f"ClassifierOutput(logits={self.logits}, "
            f"probabilities={self.probabilities}, "
            f"hidden_states={self.hidden_states})"
        )


class Classifier(
    nn.Module,
    PyTorchModelHubMixin,
    # Add Hub metadata
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "vision-classifier"],
):
    """Image classifier built on top of a pre-trained encoder."""

    # Add name attribute for factory
    name = "classifier"

    def __init__(self, config: ClassifierConfig):
        from transformers import AutoImageProcessor, AutoModel

        super().__init__()
        self.config = config
        # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
        encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
        # Extract the vision tower if we are given a multimodal model
        if hasattr(encoder, "vision_model"):
            logging.info("Multimodal model detected - using vision encoder only")
            self.encoder = encoder.vision_model
            self.vision_config = encoder.config.vision_config
        else:
            self.encoder = encoder
            self.vision_config = getattr(encoder, "config", None)

        # Model type from config
        self.is_cnn = self.config.model_type == "cnn"

        # For CNNs, initialize the backbone
        if self.is_cnn:
            self._setup_cnn_backbone()

        self._freeze_encoder()
        self._build_classifier_head()

    def _setup_cnn_backbone(self):
        """Set up the CNN encoder."""
        if hasattr(self.encoder, "fc"):
            self.feature_dim = self.encoder.fc.in_features
            self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
        elif hasattr(self.encoder.config, "hidden_sizes"):
            self.feature_dim = self.encoder.config.hidden_sizes[-1]  # Last channel dimension
        else:
            raise ValueError("Unsupported CNN architecture")

        self.encoder = self.encoder.to(self.config.device)

    def _freeze_encoder(self) -> None:
        """Freeze the encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def _build_classifier_head(self) -> None:
        """Initialize the classifier head architecture."""
        # Get the input dimension based on the model type
        if self.is_cnn:
            input_dim = self.feature_dim
        else:  # Transformer models
            if hasattr(self.encoder.config, "hidden_size"):
                input_dim = self.encoder.config.hidden_size
            else:
                raise ValueError("Unsupported transformer architecture: hidden_size not found")

        self.classifier_head = nn.Sequential(
            nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim),
            nn.Dropout(self.config.dropout_rate),
            nn.LayerNorm(self.config.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
        )
        self.classifier_head = self.classifier_head.to(self.config.device)

    def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
        """Extract the appropriate output from the encoder."""
        # Process images with the processor (handles resizing and normalization)
        # processed = self.processor(
        #     images=x,  # LeRobotDataset already provides the proper tensor format
        #     return_tensors="pt",
        # )
        # processed = processed["pixel_values"].to(x.device)
        processed = x

        with torch.no_grad():
            if self.is_cnn:
                # The HF ResNet applies pooling internally
                outputs = self.encoder(processed)
                # Get the pooled output directly
                features = outputs.pooler_output

                if features.dim() > 2:
                    features = features.squeeze(-1).squeeze(-1)
                return features
            else:  # Transformer models
                outputs = self.encoder(processed)
                if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
                    return outputs.pooler_output
                return outputs.last_hidden_state[:, 0, :]

    def forward(self, xs: torch.Tensor) -> ClassifierOutput:
        """Forward pass of the classifier."""
        # For training, we expect the input to be a tensor directly from LeRobotDataset
        encoder_outputs = torch.hstack([self._get_encoder_output(x) for x in xs])
        logits = self.classifier_head(encoder_outputs)

        if self.config.num_classes == 2:
            logits = logits.squeeze(-1)
            probabilities = torch.sigmoid(logits)
        else:
            probabilities = torch.softmax(logits, dim=-1)

        return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs)

    def predict_reward(self, x, threshold=0.6):
        if self.config.num_classes == 2:
            probs = self.forward(x).probabilities
            logging.debug(f"Predicted reward images: {probs}")
            return (probs > threshold).float()
        else:
            return torch.argmax(self.forward(x).probabilities, dim=1)
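A minimal usage sketch of the classifier as a binary reward detector. This is not part of the diff; the import paths are inferred from the relative import above, and the config fields and image shapes are illustrative assumptions:

# Illustrative sketch (not part of this diff); paths, config values, and shapes are assumptions.
import torch
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

config = ClassifierConfig(
    model_name="helper2424/resnet10", model_type="cnn", num_cameras=2, num_classes=2, device="cpu"
)
classifier = Classifier(config)
# One (B, C, H, W) tensor per camera; forward() hstacks the per-camera encoder features.
front = torch.randn(4, 3, 128, 128)
side = torch.randn(4, 3, 128, 128)
out = classifier([front, side])  # ClassifierOutput with logits / probabilities / hidden_states
rewards = classifier.predict_reward([front, side], threshold=0.6)  # 0. or 1. per sample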
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,17 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
+from dataclasses import dataclass
 
 
-def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
-    root = tmp_path / "dataset"
-    output_dir = tmp_path / "outputs"
-    dataset = lerobot_dataset_factory(root=root)
-    visualize_dataset_html(
-        dataset,
-        episodes=[0],
-        output_dir=output_dir,
-        serve=False,
-    )
-    assert (output_dir / "static" / "episode_0.csv").exists()
+@dataclass
+class HILSerlConfig:
+    pass
29 lerobot/common/policies/hilserl/modeling_hilserl.py Normal file
@@ -0,0 +1,29 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin


class HILSerlPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "hilserl"],
):
    pass
@@ -130,7 +130,7 @@ class Normalize(nn.Module):
         setattr(self, "buffer_" + key.replace(".", "_"), buffer)
 
     # TODO(rcadene): should we remove torch.no_grad?
-    @torch.no_grad
+    # @torch.no_grad
     def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
         batch = dict(batch)  # shallow copy avoids mutating the input batch
         for key, mode in self.modes.items():
100 lerobot/common/policies/sac/configuration_sac.py Normal file
@@ -0,0 +1,100 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SACConfig:
    input_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "observation.image": [3, 84, 84],
            "observation.state": [4],
        }
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
            "action": [2],
        }
    )
    input_normalization_modes: dict[str, str] = field(
        default_factory=lambda: {
            "observation.image": "mean_std",
            "observation.state": "min_max",
            "observation.environment_state": "min_max",
        }
    )
    input_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "observation.image": {"mean": [[0.485, 0.456, 0.406]], "std": [[0.229, 0.224, 0.225]]},
            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
    output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"})
    output_normalization_params: dict[str, dict[str, list[float]]] = field(
        default_factory=lambda: {
            "action": {"min": [-1, -1], "max": [1, 1]},
        }
    )
    # TODO: Move it outside of the config
    actor_learner_config: dict[str, str | int] = field(
        default_factory=lambda: {
            "actor_ip": "127.0.0.1",
            "port": 50051,
            "learner_ip": "127.0.0.1",
        }
    )
    camera_number: int = 1
    vision_encoder_name: str | None = field(default="helper2424/resnet10")
    freeze_vision_encoder: bool = True
    image_encoder_hidden_dim: int = 32
    shared_encoder: bool = True
    discount: float = 0.99
    temperature_init: float = 1.0
    num_critics: int = 2
    num_subsample_critics: int | None = None
    critic_lr: float = 3e-4
    actor_lr: float = 3e-4
    temperature_lr: float = 3e-4
    critic_target_update_weight: float = 0.005
    utd_ratio: int = 1  # To enable UTD, set this to a value greater than 1
    state_encoder_hidden_dim: int = 256
    latent_dim: int = 256
    target_entropy: float | None = None
    use_backup_entropy: bool = True
    critic_network_kwargs: dict[str, Any] = field(
        default_factory=lambda: {
            "hidden_dims": [256, 256],
            "activate_final": True,
        }
    )
    actor_network_kwargs: dict[str, Any] = field(
        default_factory=lambda: {
            "hidden_dims": [256, 256],
            "activate_final": True,
        }
    )
    policy_kwargs: dict[str, Any] = field(
        default_factory=lambda: {
            "use_tanh_squash": True,
            "log_std_min": -5,
            "log_std_max": 2,
            "init_final": 0.005,
        }
    )
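Since `SACConfig` is a plain dataclass, experiments can override individual fields at construction time. A small sketch (not part of the diff; the override values are illustrative only):

# Illustrative sketch (not part of this diff).
from lerobot.common.policies.sac.configuration_sac import SACConfig

config = SACConfig(
    num_critics=2,
    discount=0.99,
    target_entropy=None,  # resolved to -dim(A)/2 inside SACPolicy.__init__
)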
764 lerobot/common/policies/sac/modeling_sac.py Normal file
@@ -0,0 +1,764 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: (1) better device management

from typing import Callable, Optional, Tuple

import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor

from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
from lerobot.common.policies.utils import get_device_from_parameters


class SACPolicy(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="lerobot",
    repo_url="https://github.com/huggingface/lerobot",
    tags=["robotics", "RL", "SAC"],
):
    name = "sac"

    def __init__(
        self,
        config: SACConfig | None = None,
        dataset_stats: dict[str, dict[str, Tensor]] | None = None,
    ):
        super().__init__()

        if config is None:
            config = SACConfig()
        self.config = config

        if config.input_normalization_modes is not None:
            input_normalization_params = _convert_normalization_params_to_tensor(
                config.input_normalization_params
            )
            self.normalize_inputs = Normalize(
                config.input_shapes, config.input_normalization_modes, input_normalization_params
            )
        else:
            self.normalize_inputs = nn.Identity()

        output_normalization_params = _convert_normalization_params_to_tensor(
            config.output_normalization_params
        )

        # HACK: This is hacky and should be removed
        dataset_stats = dataset_stats or output_normalization_params
        self.normalize_targets = Normalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )
        self.unnormalize_outputs = Unnormalize(
            config.output_shapes, config.output_normalization_modes, dataset_stats
        )

        # NOTE: For images the encoder should be shared between the actor and critic
        if config.shared_encoder:
            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
            encoder_critic = torch.compile(encoder_critic)
            encoder_actor: SACObservationEncoder = encoder_critic
        else:
            encoder_critic = SACObservationEncoder(config, self.normalize_inputs)
            encoder_actor = SACObservationEncoder(config, self.normalize_inputs)

        self.critic_ensemble = CriticEnsemble(
            encoder=encoder_critic,
            network_list=nn.ModuleList(
                [
                    MLP(
                        input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)
                ]
            ),
            output_normalization=self.normalize_targets,
        )
        self.critic_ensemble = torch.compile(self.critic_ensemble)

        self.critic_target = CriticEnsemble(
            encoder=encoder_critic,
            network_list=nn.ModuleList(
                [
                    MLP(
                        input_dim=encoder_critic.output_dim + config.output_shapes["action"][0],
                        **config.critic_network_kwargs,
                    )
                    for _ in range(config.num_critics)
                ]
            ),
            output_normalization=self.normalize_targets,
        )
        self.critic_target = torch.compile(self.critic_target)

        self.critic_target.load_state_dict(self.critic_ensemble.state_dict())

        self.actor = Policy(
            encoder=encoder_actor,
            network=MLP(input_dim=encoder_actor.output_dim, **config.actor_network_kwargs),
            action_dim=config.output_shapes["action"][0],
            encoder_is_shared=config.shared_encoder,
            **config.policy_kwargs,
        )

        # self.actor = torch.compile(self.actor)

        if config.target_entropy is None:
            config.target_entropy = -np.prod(config.output_shapes["action"][0]) / 2  # (-dim(A)/2)

        # TODO (azouitine): Handle the case where the temperature parameter is fixed
        # TODO (michel-aractingi): Put log_alpha on CUDA by default, because otherwise
        # it triggers "can't optimize a non-leaf Tensor"
        self.log_alpha = nn.Parameter(torch.tensor([0.0]))
        self.temperature = self.log_alpha.exp().item()

    def reset(self):
        """Reset the policy"""
        pass

    def to(self, *args, **kwargs):
        """Override .to(device) to also move the actor's fixed_std (a plain tensor, not a parameter)"""
        if self.actor.fixed_std is not None:
            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
        super().to(*args, **kwargs)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select action for inference/evaluation"""
        actions, _, _ = self.actor(batch)
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions

    def critic_forward(
        self,
        observations: dict[str, Tensor],
        actions: Tensor,
        use_target: bool = False,
        image_features: Tensor | None = None,
    ) -> Tensor:
        """Forward pass through a critic network ensemble

        Args:
            observations: Dictionary of observations
            actions: Action tensor
            use_target: If True, use target critics, otherwise use ensemble critics

        Returns:
            Tensor of Q-values from all critics
        """
        critics = self.critic_target if use_target else self.critic_ensemble
        q_values = critics(observations, actions, image_features=image_features)
        return q_values

    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...

    def update_target_networks(self):
        """Update target networks with exponential moving average"""
        for target_param, param in zip(
            self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=False
        ):
            target_param.data.copy_(
                param.data * self.config.critic_target_update_weight
                + target_param.data * (1.0 - self.config.critic_target_update_weight)
            )

    def compute_loss_critic(
        self,
        observations,
        actions,
        rewards,
        next_observations,
        done,
        image_features: Tensor | None = None,
        next_image_features: Tensor | None = None,
    ) -> Tensor:
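        # TD target (SAC): y = r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')),
        # with a' sampled from the current actor and the entropy term gated by use_backup_entropy.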
        temperature = self.log_alpha.exp().item()
        with torch.no_grad():
            next_action_preds, next_log_probs, _ = self.actor(next_observations, next_image_features)

            # 2- compute q targets
            q_targets = self.critic_forward(
                observations=next_observations,
                actions=next_action_preds,
                use_target=True,
                image_features=next_image_features,
            )

            # Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
            if self.config.num_subsample_critics is not None:
                indices = torch.randperm(self.config.num_critics)
                indices = indices[: self.config.num_subsample_critics]
                q_targets = q_targets[indices]

            # critics subsample size
            min_q, _ = q_targets.min(dim=0)  # Get values from min operation
            if self.config.use_backup_entropy:
                min_q = min_q - (temperature * next_log_probs)

            td_target = rewards + (1 - done) * self.config.discount * min_q

        # 3- compute predicted qs
        q_preds = self.critic_forward(observations, actions, use_target=False)

        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
        # Compute the mean loss over the batch for each critic, then sum across critics for the final loss
        critics_loss = (
            F.mse_loss(
                input=q_preds,
                target=td_target_duplicate,
                reduction="none",
            ).mean(1)
        ).sum()
        return critics_loss

    def compute_loss_temperature(self, observations, image_features: Tensor | None = None) -> Tensor:
        """Compute the temperature loss"""
        # calculate temperature loss
        with torch.no_grad():
            _, log_probs, _ = self.actor(observations, image_features)
        temperature_loss = (-self.log_alpha.exp() * (log_probs + self.config.target_entropy)).mean()
        return temperature_loss

    def compute_loss_actor(self, observations, image_features: Tensor | None = None) -> Tensor:
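        # Actor objective (SAC): J(pi) = E[alpha * log pi(a|s) - min_i Q_i(s, a)], with a reparameterized.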
        temperature = self.log_alpha.exp().item()

        actions_pi, log_probs, _ = self.actor(observations, image_features)

        q_preds = self.critic_forward(observations, actions_pi, use_target=False)
        min_q_preds = q_preds.min(dim=0)[0]

        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
        return actor_loss


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        hidden_dims: list[int],
        activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
        activate_final: bool = False,
        dropout_rate: Optional[float] = None,
    ):
        super().__init__()
        self.activate_final = activate_final
        layers = []

        # First layer uses input_dim
        layers.append(nn.Linear(input_dim, hidden_dims[0]))

        # Add activation after first layer
        if dropout_rate is not None and dropout_rate > 0:
            layers.append(nn.Dropout(p=dropout_rate))
        layers.append(nn.LayerNorm(hidden_dims[0]))
        layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())

        # Rest of the layers
        for i in range(1, len(hidden_dims)):
            layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i]))

            if i + 1 < len(hidden_dims) or activate_final:
                if dropout_rate is not None and dropout_rate > 0:
                    layers.append(nn.Dropout(p=dropout_rate))
                layers.append(nn.LayerNorm(hidden_dims[i]))
                layers.append(
                    activations if isinstance(activations, nn.Module) else getattr(nn, activations)()
                )

        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class CriticEnsemble(nn.Module):
    """
    ┌──────────────────┬─────────────────────────────────────────────────────┐
    │ Critic Ensemble  │                                                     │
    ├──────────────────┘                                                     │
    │                                                                        │
    │        ┌────┐             ┌────┐                  ┌────┐               │
    │        │ Q1 │             │ Q2 │                  │ Qn │               │
    │        └────┘             └────┘                  └────┘               │
    │  ┌──────────────┐   ┌──────────────┐        ┌──────────────┐          │
    │  │              │   │              │        │              │          │
    │  │    MLP 1     │   │    MLP 2     │  ...   │     MLP      │          │
    │  │              │   │              │        │ num_critics  │          │
    │  │              │   │              │        │              │          │
    │  └──────────────┘   └──────────────┘        └──────────────┘          │
    │          ▲                  ▲                       ▲                  │
    │          └──────────────────┴──────────┬────────────┘                  │
    │                                        │                               │
    │                            ┌───────────────────┐                       │
    │                            │     Embedding     │                       │
    │                            └───────────────────┘                       │
    │                                        ▲                               │
    │                                        │                               │
    │                           ┌────────────┴─────────────┐                 │
    │                           │                          │                 │
    │                           │  SACObservationEncoder   │                 │
    │                           │                          │                 │
    │                           └──────────────────────────┘                 │
    │                                        ▲                               │
    │                                        │                               │
    └────────────────────────────────────────┼───────────────────────────────┘
                                  ┌──────────┴─────────┐
                                  │    Observation     │
                                  └────────────────────┘
    """

    def __init__(
        self,
        encoder: Optional[nn.Module],
        network_list: nn.ModuleList,
        output_normalization: nn.Module,
        init_final: Optional[float] = None,
    ):
        super().__init__()
        self.encoder = encoder
        self.network_list = network_list
        self.init_final = init_final
        self.output_normalization = output_normalization

        self.parameters_to_optimize = []
        # Handle the case where part of the encoder is frozen
        if self.encoder is not None:
            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)

        self.parameters_to_optimize += list(self.network_list.parameters())
        # Find the last Linear layer's output dimension
        for layer in reversed(network_list[0].net):
            if isinstance(layer, nn.Linear):
                out_features = layer.out_features
                break

        # Output layer
        self.output_layers = []
        if init_final is not None:
            for _ in network_list:
                output_layer = nn.Linear(out_features, 1)
                nn.init.uniform_(output_layer.weight, -init_final, init_final)
                nn.init.uniform_(output_layer.bias, -init_final, init_final)
                self.output_layers.append(output_layer)
        else:
            for _ in network_list:
                output_layer = nn.Linear(out_features, 1)
                orthogonal_init()(output_layer.weight)
                self.output_layers.append(output_layer)

        self.output_layers = nn.ModuleList(self.output_layers)
        self.parameters_to_optimize += list(self.output_layers.parameters())

    def forward(
        self,
        observations: dict[str, torch.Tensor],
        actions: torch.Tensor,
        image_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
        device = get_device_from_parameters(self)
        # Move each tensor in observations to device
        observations = {k: v.to(device) for k, v in observations.items()}
        # NOTE: We normalize actions; it helps with sample efficiency
        actions: dict[str, torch.Tensor] = {"action": actions}
        # NOTE: The normalization layer takes a dict as input and returns a dict as output
        actions = self.output_normalization(actions)["action"]
        actions = actions.to(device)

        obs_enc = (
            image_features
            if image_features is not None
            else (observations if self.encoder is None else self.encoder(observations))
        )

        inputs = torch.cat([obs_enc, actions], dim=-1)
        list_q_values = []
        for network, output_layer in zip(self.network_list, self.output_layers, strict=False):
            x = network(inputs)
            value = output_layer(x)
            list_q_values.append(value.squeeze(-1))
        return torch.stack(list_q_values)


class Policy(nn.Module):
    def __init__(
        self,
        encoder: Optional[nn.Module],
        network: nn.Module,
        action_dim: int,
        log_std_min: float = -5,
        log_std_max: float = 2,
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        use_tanh_squash: bool = False,
        encoder_is_shared: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
        self.fixed_std = fixed_std
        self.use_tanh_squash = use_tanh_squash
        self.parameters_to_optimize = []

        self.parameters_to_optimize += list(self.network.parameters())

        if self.encoder is not None and not encoder_is_shared:
            self.parameters_to_optimize += list(self.encoder.parameters())
        # Find the last Linear layer's output dimension
        for layer in reversed(network.net):
            if isinstance(layer, nn.Linear):
                out_features = layer.out_features
                break
        # Mean layer
        self.mean_layer = nn.Linear(out_features, action_dim)
        if init_final is not None:
            nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
            nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
        else:
            orthogonal_init()(self.mean_layer.weight)

        self.parameters_to_optimize += list(self.mean_layer.parameters())
        # Standard deviation layer or parameter
        if fixed_std is None:
            self.std_layer = nn.Linear(out_features, action_dim)
            if init_final is not None:
                nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
                nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
            else:
                orthogonal_init()(self.std_layer.weight)
            self.parameters_to_optimize += list(self.std_layer.parameters())

    def forward(
        self,
        observations: torch.Tensor,
        image_features: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Encode observations if encoder exists
        obs_enc = (
            image_features
            if image_features is not None
            else (observations if self.encoder is None else self.encoder(observations))
        )

        # Get network outputs
        outputs = self.network(obs_enc)
        means = self.mean_layer(outputs)

        # Compute standard deviations
        if self.fixed_std is None:
            log_std = self.std_layer(outputs)
            assert not torch.isnan(log_std).any(), "[ERROR] log_std became NaN after std_layer!"

            if self.use_tanh_squash:
                log_std = torch.tanh(log_std)
                log_std = self.log_std_min + 0.5 * (self.log_std_max - self.log_std_min) * (log_std + 1.0)
            else:
                log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
        else:
            log_std = self.fixed_std.expand_as(means)

        # Use the tanh activation to squash the action into the range [-1, 1]
        normal = torch.distributions.Normal(means, torch.exp(log_std))
        x_t = normal.rsample()  # Reparameterization trick (mean + std * N(0, 1))
        log_probs = normal.log_prob(x_t)  # Base log probability before Tanh

        if self.use_tanh_squash:
            actions = torch.tanh(x_t)
            log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)  # Adjust log-probs for Tanh
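            # Change of variables for a = tanh(x): log pi(a) = log N(x) - sum log(1 - tanh(x)^2 + eps).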
        else:
            actions = x_t  # No Tanh; raw Gaussian sample

        log_probs = log_probs.sum(-1)  # Sum over action dimensions
        means = torch.tanh(means) if self.use_tanh_squash else means
        return actions, log_probs, means

    def get_features(self, observations: torch.Tensor) -> torch.Tensor:
        """Get encoded features from observations"""
        device = get_device_from_parameters(self)
        observations = observations.to(device)
        if self.encoder is not None:
            with torch.inference_mode():
                return self.encoder(observations)
        return observations


class SACObservationEncoder(nn.Module):
    """Encode image and/or state vector observations."""

    def __init__(self, config: SACConfig, input_normalizer: nn.Module):
        """
        Creates encoders for pixel and/or state modalities.
        """
        super().__init__()
        self.config = config
        self.input_normalization = input_normalizer
        self.has_pretrained_vision_encoder = False
        self.parameters_to_optimize = []

        self.aggregation_size: int = 0
        if any("observation.image" in key for key in config.input_shapes):
            self.camera_number = config.camera_number

            if self.config.vision_encoder_name is not None:
                self.image_enc_layers = PretrainedImageEncoder(config)
                self.has_pretrained_vision_encoder = True
            else:
                self.image_enc_layers = DefaultImageEncoder(config)

            self.aggregation_size += config.latent_dim * self.camera_number

            if config.freeze_vision_encoder:
                freeze_image_encoder(self.image_enc_layers)
            else:
                self.parameters_to_optimize += list(self.image_enc_layers.parameters())

        if "observation.state" in config.input_shapes:
            self.state_enc_layers = nn.Sequential(
                nn.Linear(
                    in_features=config.input_shapes["observation.state"][0], out_features=config.latent_dim
                ),
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),
            )
            self.aggregation_size += config.latent_dim

            self.parameters_to_optimize += list(self.state_enc_layers.parameters())

        if "observation.environment_state" in config.input_shapes:
            self.env_state_enc_layers = nn.Sequential(
                nn.Linear(
                    in_features=config.input_shapes["observation.environment_state"][0],
                    out_features=config.latent_dim,
                ),
                nn.LayerNorm(normalized_shape=config.latent_dim),
                nn.Tanh(),
            )
            self.aggregation_size += config.latent_dim
            self.parameters_to_optimize += list(self.env_state_enc_layers.parameters())

        self.aggregation_layer = nn.Linear(in_features=self.aggregation_size, out_features=config.latent_dim)
        self.parameters_to_optimize += list(self.aggregation_layer.parameters())

    def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
        """Encode the image and/or state vector.

        Each modality is encoded into a feature vector of size (latent_dim,); the per-modality
        features are concatenated and projected back to (latent_dim,) by a linear aggregation layer.
        """
        feat = []
        obs_dict = self.input_normalization(obs_dict)
        # Encode every image key separately.
        image_keys = [k for k in obs_dict if k.startswith("observation.image")]
        for image_key in image_keys:
            enc_feat = self.image_enc_layers(obs_dict[image_key])

            # if not self.has_pretrained_vision_encoder:
            #     enc_feat = flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])
            feat.append(enc_feat)
        if "observation.environment_state" in self.config.input_shapes:
            feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
        if "observation.state" in self.config.input_shapes:
            feat.append(self.state_enc_layers(obs_dict["observation.state"]))

        features = torch.cat(tensors=feat, dim=-1)
        features = self.aggregation_layer(features)

        return features

    @property
    def output_dim(self) -> int:
        """Returns the dimension of the encoder output"""
        return self.config.latent_dim


class DefaultImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.image_enc_layers = nn.Sequential(
            nn.Conv2d(
                in_channels=config.input_shapes["observation.image"][0],
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=7,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=5,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=config.image_encoder_hidden_dim,
                out_channels=config.image_encoder_hidden_dim,
                kernel_size=3,
                stride=2,
            ),
            nn.ReLU(),
        )
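        # Run a dummy forward pass once to infer the conv stack's output shape, so the linear
        # projection below can be sized without hand-computing the convolution arithmetic.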
        dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
        with torch.inference_mode():
            self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:]
        self.image_enc_layers.extend(
            nn.Sequential(
                nn.Flatten(),
                nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
                nn.LayerNorm(config.latent_dim),
                nn.Tanh(),
            )
        )

    def forward(self, x):
        return self.image_enc_layers(x)


class PretrainedImageEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config)
        self.image_enc_proj = nn.Sequential(
            nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.Tanh(),
        )

    def _load_pretrained_vision_encoder(self, config):
        """Set up the pretrained CNN encoder"""
        from transformers import AutoModel

        self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True)
        # self.image_enc_layers.pooler = Identity()

        if hasattr(self.image_enc_layers.config, "hidden_sizes"):
            self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1]  # Last channel dimension
        elif hasattr(self.image_enc_layers, "fc"):
            self.image_enc_out_shape = self.image_enc_layers.fc.in_features
        else:
            raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN")
        return self.image_enc_layers, self.image_enc_out_shape

    def forward(self, x):
        # TODO: (maractingi, azouitine) check that the forward pass of the pretrained model
        # doesn't reach the classifier layer, because we don't need it
        enc_feat = self.image_enc_layers(x).pooler_output
        enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1))
        return enc_feat


def freeze_image_encoder(image_encoder: nn.Module):
    """Freeze all parameters in the encoder"""
    for param in image_encoder.parameters():
        param.requires_grad = False


def orthogonal_init():
    return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


# TODO (azouitine): This function may not be useful in our case; we should remove it
# after some investigation
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than 1 dimension, generally different from *.
    Returns:
        A return value from the callable reshaped to (**, *).
    """
    if image_tensor.ndim == 4:
        return fn(image_tensor)
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))


def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
    converted_params = {}
    for outer_key, inner_dict in normalization_params.items():
        converted_params[outer_key] = {}
        for key, value in inner_dict.items():
            converted_params[outer_key][key] = torch.tensor(value)
            if "image" in outer_key:
                # Reshape to (3, 1, 1) so the stats broadcast over (C, H, W) image tensors
                converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)

    return converted_params


if __name__ == "__main__":
    # Test the SACObservationEncoder
    import time

    config = SACConfig()
    config.num_critics = 10
    # NOTE: SACObservationEncoder requires an input normalizer; nn.Identity() is a no-op stand-in.
    encoder = SACObservationEncoder(config, nn.Identity())
    actor_encoder = SACObservationEncoder(config, nn.Identity())
    encoder = torch.compile(encoder)
    critic_ensemble = CriticEnsemble(
        encoder=encoder,
        network_list=nn.ModuleList(
            [
                MLP(
                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
                    **config.critic_network_kwargs,
                )
                for _ in range(config.num_critics)
            ]
        ),
        # NOTE: output_normalization is required; nn.Identity() passes the action dict through unchanged.
        output_normalization=nn.Identity(),
    )
    actor = Policy(
        encoder=actor_encoder,
        network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
        action_dim=config.output_shapes["action"][0],
        encoder_is_shared=config.shared_encoder,
        **config.policy_kwargs,
    )
    encoder = encoder.to("cuda:0")
    critic_ensemble = torch.compile(critic_ensemble)
    critic_ensemble = critic_ensemble.to("cuda:0")
    actor = torch.compile(actor)
    actor = actor.to("cuda:0")
    obs_dict = {
        "observation.image": torch.randn(1, 3, 84, 84),
        "observation.state": torch.randn(1, 4),
    }
    actions = torch.randn(1, 2).to("cuda:0")
    obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()}
    print("compiling...")
    # q_value = critic_ensemble(obs_dict, actions)
    action = actor(obs_dict)
    print("compiled")
    start = time.perf_counter()
    for _ in range(1000):
        # features = encoder(obs_dict)
        action = actor(obs_dict)
        # q_value = critic_ensemble(obs_dict, actions)
    print("Time taken:", time.perf_counter() - start)
@@ -11,6 +11,7 @@ from copy import copy
 from functools import cache
 
+import cv2
 import numpy as np
 import torch
 import tqdm
 from deepdiff import DeepDiff
@@ -35,7 +36,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
 
     def log_dt(shortname, dt_val_s):
         nonlocal log_items, fps
-        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)"
+        info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)"
         if fps is not None:
             actual_fps = 1 / dt_val_s
             if actual_fps < fps - 1:
@@ -120,14 +121,22 @@ def predict_action(observation, policy, device, use_amp):
     return action
 
 
-def init_keyboard_listener():
-    # Allow to exit early while recording an episode or resetting the environment,
-    # by tapping the right arrow key '->'. This might require a sudo permission
-    # to allow your terminal to monitor keyboard events.
+def init_keyboard_listener(assign_rewards=False):
+    """
+    Initializes a keyboard listener to enable early termination of an episode
+    or environment reset by pressing the right arrow key ('->'). This may require
+    sudo permissions to allow the terminal to monitor keyboard events.
+
+    Args:
+        assign_rewards (bool): If True, allows annotating the collected trajectory
+            with a binary reward at the end of the episode to indicate success.
+    """
     events = {}
     events["exit_early"] = False
     events["rerecord_episode"] = False
     events["stop_recording"] = False
+    if assign_rewards:
+        events["next.reward"] = 0
 
     if is_headless():
         logging.warning(
@@ -152,6 +161,13 @@ def init_keyboard_listener():
             print("Escape key pressed. Stopping data recording...")
             events["stop_recording"] = True
             events["exit_early"] = True
+        elif assign_rewards and key == keyboard.Key.space:
+            events["next.reward"] = 1 if events["next.reward"] == 0 else 0
+            print(
+                "Space key pressed. Assigning new reward to the subsequent frames. New reward:",
+                events["next.reward"],
+            )
 
     except Exception as e:
         print(f"Error handling key press: {e}")
@@ -209,6 +225,7 @@ def record_episode(
     device,
     use_amp,
     fps,
+    record_delta_actions,
 ):
     control_loop(
         robot=robot,
@@ -220,6 +237,7 @@ def record_episode(
         device=device,
         use_amp=use_amp,
         fps=fps,
+        record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
     )
@@ -236,6 +254,7 @@ def control_loop(
     device=None,
     use_amp=None,
     fps=None,
+    record_delta_actions=False,
 ):
     # TODO(rcadene): Add option to record logs
     if not robot.is_connected:
@@ -258,8 +277,12 @@ def control_loop(
     while timestamp < control_time_s:
         start_loop_t = time.perf_counter()
 
+        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
+
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
+            if record_delta_actions:
+                action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
 
@@ -272,8 +295,14 @@ def control_loop(
 
         if dataset is not None:
             frame = {**observation, **action}
+            if "next.reward" in events:
+                frame["next.reward"] = events["next.reward"]
+                frame["next.done"] = (events["next.reward"] == 1) or (events["exit_early"])
             dataset.add_frame(frame)
 
+            # if frame["next.done"]:
+            #     break
 
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
@@ -301,6 +330,8 @@ def reset_environment(robot, events, reset_time_s):
 
     timestamp = 0
     start_vencod_t = time.perf_counter()
+    if "next.reward" in events:
+        events["next.reward"] = 0
 
     # Wait if necessary
     with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
@@ -313,6 +344,16 @@ def reset_environment(robot, events, reset_time_s):
             break
 
 
+def reset_follower_position(robot: Robot, target_position):
+    current_position = robot.follower_arms["main"].read("Present_Position")
+    trajectory = torch.from_numpy(
+        np.linspace(current_position, target_position, 30)
+    )  # NOTE: 30 is just an arbitrary number
+    for pose in trajectory:
+        robot.send_action(pose)
+        busy_wait(0.015)
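+    # Note: 0.015 s per step paces the 30-step interpolation at roughly 66 Hz (~0.45 s total).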
 
 
 def stop_recording(robot, listener, display_cameras):
     robot.disconnect()
@@ -343,12 +384,16 @@ def sanity_check_dataset_name(repo_id, policy):
 
 
 def sanity_check_dataset_robot_compatibility(
-    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
+    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
 ) -> None:
     features_from_robot = get_features_from_robot(robot, use_videos)
+    if extra_features is not None:
+        features_from_robot.update(extra_features)
+
     fields = [
         ("robot_type", dataset.meta.robot_type, robot.robot_type),
         ("fps", dataset.fps, fps),
-        ("features", dataset.features, get_features_from_robot(robot, use_videos)),
+        ("features", dataset.features, features_from_robot),
     ]
 
     mismatches = []
@@ -32,7 +32,7 @@ def ensure_safe_goal_position(
     safe_goal_pos = present_pos + safe_diff
 
     if not torch.allclose(goal_pos, safe_goal_pos):
-        logging.warning(
+        logging.debug(
             "Relative goal position magnitude had to be clamped to be safe.\n"
             f"  requested relative goal position target: {diff}\n"
             f"    clamped relative goal position target: {safe_diff}"
@@ -67,6 +67,8 @@ class ManipulatorRobotConfig:
     # gripper is not put in torque mode.
     gripper_open_degree: float | None = None
 
+    joint_position_relative_bounds: dict[str, np.ndarray] | None = None
+
     def __setattr__(self, prop: str, val):
         if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
             for name in self.follower_arms:
@@ -78,6 +80,9 @@ class ManipulatorRobotConfig:
                 "Note: This feature does not yet work with robots where different follower arms have "
                 "different numbers of motors."
             )
+        if prop == "joint_position_relative_bounds" and val is not None:
+            for key in val:
+                val[key] = torch.tensor(val[key])
         super().__setattr__(prop, val)
 
     def __post_init__(self):
@@ -523,6 +528,14 @@ class ManipulatorRobot:
             before_fwrite_t = time.perf_counter()
             goal_pos = leader_pos[name]
 
+            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
+            if self.config.joint_position_relative_bounds is not None:
+                goal_pos = torch.clamp(
+                    goal_pos,
+                    self.config.joint_position_relative_bounds["min"],
+                    self.config.joint_position_relative_bounds["max"],
+                )
+
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:
@@ -644,6 +657,14 @@ class ManipulatorRobot:
             goal_pos = action[from_idx:to_idx]
             from_idx = to_idx
 
+            # If specified, clip the goal positions within predefined bounds specified in the config of the robot
+            if self.config.joint_position_relative_bounds is not None:
+                goal_pos = torch.clamp(
+                    goal_pos,
+                    self.config.joint_position_relative_bounds["min"],
+                    self.config.joint_position_relative_bounds["max"],
+                )
+
             # Cap goal position when too far away from present position.
             # Slower fps expected due to reading from the follower.
             if self.config.max_relative_target is not None:
@@ -656,6 +677,7 @@ class ManipulatorRobot:
 
             # Send goal position to each follower
             goal_pos = goal_pos.numpy().astype(np.int32)
+
             self.follower_arms[name].write("Goal_Position", goal_pos)
 
     return torch.cat(action_sent)
@@ -18,6 +18,7 @@ import os
 import os.path as osp
 import platform
 import random
+import time
 from contextlib import contextmanager
 from datetime import datetime, timezone
 from pathlib import Path
@@ -217,3 +218,28 @@ def log_say(text, play_sounds, blocking=False):
 
     if play_sounds:
         say(text, blocking)
 
 
+class TimerManager:
+    def __init__(self, elapsed_time_list: list[float] | None = None, label="Elapsed time", log=True):
+        self.label = label
+        self.elapsed_time_list = elapsed_time_list
+        self.log = log
+        self.elapsed = 0.0
+
+    def __enter__(self):
+        self.start = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.elapsed: float = time.perf_counter() - self.start
+
+        if self.elapsed_time_list is not None:
+            self.elapsed_time_list.append(self.elapsed)
+
+        if self.log:
+            print(f"{self.label}: {self.elapsed:.6f} seconds")
+
+    @property
+    def elapsed_seconds(self):
+        return self.elapsed
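A quick usage sketch of `TimerManager` (illustrative, not part of the diff): it works both as a one-off context-manager timer and as a collector of per-iteration latencies via `elapsed_time_list`.

# Illustrative sketch (not part of this diff).
latencies: list[float] = []
for _ in range(3):
    with TimerManager(elapsed_time_list=latencies, label="inference", log=False):
        pass  # timed work goes here
print(len(latencies), "samples;", f"last={latencies[-1]:.6f}s")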
@@ -2,6 +2,7 @@ defaults:
   - _self_
   - env: pusht
   - policy: diffusion
+  - robot: so100
 
 hydra:
   run:
20 lerobot/configs/env/maniskill_example.yaml vendored Normal file
@@ -0,0 +1,20 @@
# @package _global_

fps: 20

env:
  name: maniskill/pushcube
  task: PushCube-v1
  image_size: 128
  control_mode: pd_ee_delta_pose
  state_dim: 25
  action_dim: 7
  fps: ${fps}
  obs: rgb
  render_mode: rgb_array
  render_size: 128
  device: cuda

  reward_classifier:
    pretrained_path: null
    config_path: null
23 lerobot/configs/env/so100_real.yaml vendored
@@ -1,6 +1,6 @@
 # @package _global_
 
-fps: 30
+fps: 10
 
 env:
   name: real_world
@@ -8,3 +8,24 @@ env:
   state_dim: 6
   action_dim: 6
   fps: ${fps}
+  device: mps
+
+  wrapper:
+    crop_params_dict:
+      observation.images.front: [102, 43, 358, 523]
+      observation.images.side: [92, 123, 379, 349]
+      # observation.images.front: [109, 37, 361, 557]
+      # observation.images.side: [94, 161, 372, 315]
+    resize_size: [128, 128]
+    control_time_s: 20
+    reset_follower_pos: true
+    use_relative_joint_positions: true
+    reset_time_s: 5
+    display_cameras: false
+    delta_action: 0.1
+    joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper
+
+  reward_classifier:
+    pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model
+    config_path: lerobot/configs/policy/hilserl_classifier.yaml
53 lerobot/configs/policy/hilserl_classifier.yaml Normal file
@@ -0,0 +1,53 @@
# @package _global_

defaults:
  - _self_

seed: 13
dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized
# aractingi/push_cube_square_reward_1_cropped_resized
dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized
local_files_only: true
train_split_proportion: 0.8

# Required by logger
env:
  name: "classifier"
  task: "binary_classification"


training:
  num_epochs: 6
  batch_size: 16
  learning_rate: 1e-4
  num_workers: 4
  grad_clip_norm: 10
  use_amp: true
  log_freq: 1
  eval_freq: 1 # How often to run validation (in epochs)
  save_freq: 1 # How often to save checkpoints (in epochs)
  save_checkpoint: true
  image_keys: ["observation.images.front", "observation.images.side"]
  label_key: "next.reward"
  profile_inference_time: false
  profile_inference_time_iters: 20

eval:
  batch_size: 16
  num_samples_to_log: 30 # Number of validation samples to log in the table

policy:
  name: "hilserl/classifier"
  model_name: "helper2424/resnet10" # "facebook/convnext-base-224"
  model_type: "cnn"
  num_cameras: 2 # Has to be len(training.image_keys)

wandb:
  enable: false
  project: "classifier-training"
  job_name: "classifier_training_0"
  disable_artifact: false

device: "mps"
resume: false
output_dir: "outputs/classifier/old_trainer_resnet10_frozen"
108 lerobot/configs/policy/sac_maniskill.yaml Normal file
@@ -0,0 +1,108 @@
# @package _global_

# Train with:
#
# python lerobot/scripts/train.py \
#     +dataset=lerobot/pusht_keypoints
#     env=pusht \
#     env.gym.obs_type=environment_state_agent_pos \

seed: 1
dataset_repo_id: null

training:
  # Offline training dataloader
  num_workers: 4

  # batch_size: 256
  batch_size: 512
  grad_clip_norm: 10.0
  lr: 3e-4

  eval_freq: 2500
  log_freq: 10
  save_freq: 2000000

  online_steps: 1000000
  online_rollout_n_episodes: 10
  online_rollout_batch_size: 10
  online_steps_between_rollouts: 1000
  online_sampling_ratio: 1.0
  online_env_seed: 10000
  online_buffer_capacity: 1000000
  online_buffer_seed_size: 0
  online_step_before_learning: 500
  do_online_rollout_async: false
  policy_update_freq: 1

  # delta_timestamps:
  #   observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   action: "[i / ${fps} for i in range(${policy.horizon})]"
  #   next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

policy:
  name: sac

  pretrained_model_path:

  # Input / output structure.
  n_action_repeats: 1
  horizon: 1
  n_action_steps: 1

  shared_encoder: true
  # vision_encoder_name: null
  vision_encoder_name: "helper2424/resnet10"
  freeze_vision_encoder: true
  # freeze_vision_encoder: false
  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.state: ["${env.state_dim}"]
    observation.image: [3, 128, 128]
  output_shapes:
    action: [7]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.state: min_max
  input_normalization_params:
    observation.state:
      min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01,
            1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00,
            -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00,
            -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01,
            8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01]

      max: [0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400,
            0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163,
            7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135,
            0.4001]

  output_normalization_modes:
    action: min_max
  output_normalization_params:
    action:
      min: [-10.0, -10.0, -10.0, -10.0, -10.0, -10.0, -10.0]
      max: [10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]
  output_normalization_shapes:
    action: [7]

  # Architecture / modeling.
  # Neural networks.
  image_encoder_hidden_dim: 32
  # discount: 0.99
  discount: 0.80
  temperature_init: 1.0
  num_critics: 2 #10
  num_subsample_critics: null
  critic_lr: 3e-4
  actor_lr: 3e-4
  temperature_lr: 3e-4
  # critic_target_update_weight: 0.005
  critic_target_update_weight: 0.01
  utd_ratio: 2 # 10

  actor_learner_config:
    actor_ip: "127.0.0.1"
    port: 50051
89 lerobot/configs/policy/sac_pusht_keypoints.yaml Normal file
@@ -0,0 +1,89 @@
|
||||
# @package _global_

# Train with:
#
# python lerobot/scripts/train.py \
#   env=pusht \
#   +dataset=lerobot/pusht_keypoints

seed: 1
dataset_repo_id: lerobot/pusht_keypoints

training:
  offline_steps: 0

  # Offline training dataloader
  num_workers: 4

  batch_size: 128
  grad_clip_norm: 10.0
  lr: 3e-4

  eval_freq: 50000
  log_freq: 500
  save_freq: 50000

  online_steps: 1000000
  online_rollout_n_episodes: 10
  online_rollout_batch_size: 10
  online_steps_between_rollouts: 1000
  online_sampling_ratio: 1.0
  online_env_seed: 10000
  online_buffer_capacity: 40000
  online_buffer_seed_size: 0
  do_online_rollout_async: false

  delta_timestamps:
    observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    action: "[i / ${fps} for i in range(${policy.horizon})]"
    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

policy:
  name: sac

  pretrained_model_path:

  # Input / output structure.
  n_action_repeats: 1
  horizon: 5
  n_action_steps: 5

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.environment_state: [16]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.environment_state: min_max
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Neural networks.
  # image_encoder_hidden_dim: 32
  discount: 0.99
  temperature_init: 1.0
  num_critics: 2
  num_subsample_critics: null
  critic_lr: 3e-4
  actor_lr: 3e-4
  temperature_lr: 3e-4
  critic_target_update_weight: 0.005
  utd_ratio: 2

  # # Loss coefficients.
  # reward_coeff: 0.5
  # expectile_weight: 0.9
  # value_coeff: 0.1
  # consistency_coeff: 20.0
  # advantage_scaling: 3.0
  # pi_coeff: 0.5
  # temporal_decay_coeff: 0.5
  # # Target model.
  # target_model_momentum: 0.995
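The `delta_timestamps` entries above are Python list-comprehension strings: Hydra first interpolates `${fps}` and `${policy.horizon}`, and the dataset loader then evaluates the resulting expression into a list of time offsets in seconds. A minimal sketch of that resolution, with illustrative values:

```python
# With fps=10 and horizon=5, the interpolated string for `action` becomes:
expr = "[i / 10 for i in range(5)]"
# The loader evaluates the (trusted) config string into concrete offsets.
offsets = eval(expr)  # noqa: S307
print(offsets)  # [0.0, 0.1, 0.2, 0.3, 0.4]
```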
lerobot/configs/policy/sac_real.yaml (new file, 127 lines)
@@ -0,0 +1,127 @@
# @package _global_

# Train with:
#
# python lerobot/scripts/train.py \
#   +dataset=lerobot/pusht_keypoints \
#   env=pusht \
#   env.gym.obs_type=environment_state_agent_pos

seed: 1
dataset_repo_id: aractingi/push_cube_overfit_cropped_resized
# aractingi/push_cube_square_offline_demo_cropped_resized

training:
  # Offline training dataloader
  num_workers: 4

  # batch_size: 256
  batch_size: 512
  grad_clip_norm: 10.0
  lr: 3e-4

  eval_freq: 2500
  log_freq: 1
  save_freq: 2000000

  online_steps: 1000000
  online_rollout_n_episodes: 10
  online_rollout_batch_size: 10
  online_steps_between_rollouts: 1000
  online_sampling_ratio: 1.0
  online_env_seed: 10000
  online_buffer_capacity: 1000000
  online_buffer_seed_size: 0
  online_step_before_learning: 100 # 5000
  do_online_rollout_async: false
  policy_update_freq: 1

  # delta_timestamps:
  #   observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
  #   action: "[i / ${fps} for i in range(${policy.horizon})]"
  #   next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

policy:
  name: sac

  pretrained_model_path:

  # Input / output structure.
  n_action_repeats: 1
  horizon: 1
  n_action_steps: 1

  shared_encoder: true
  vision_encoder_name: "helper2424/resnet10"
  freeze_vision_encoder: true
  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.state: ["${env.state_dim}"]
    observation.images.front: [3, 128, 128]
    observation.images.side: [3, 128, 128]
    # observation.image: [3, 128, 128]
  output_shapes:
    action: [4] # ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.front: mean_std
    observation.images.side: mean_std
    observation.state: min_max
  input_normalization_params:
    observation.images.front:
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    observation.images.side:
      mean: [0.485, 0.456, 0.406]
      std: [0.229, 0.224, 0.225]
    observation.state:
      min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]
      max: [7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]

      # min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375, 0.53691274]
      # max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685]
      # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274]
      # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792]

  output_normalization_modes:
    action: min_max
  output_normalization_params:
    # action:
    #   min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0]
    #   max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
    action:
      min: [-149.23828125, -97.734375, -100.1953125, -73.740234375]
      max: [149.23828125, 97.734375, 100.1953125, 73.740234375]

  # Architecture / modeling.
  # Neural networks.
  image_encoder_hidden_dim: 32
  # discount: 0.99
  discount: 0.97
  temperature_init: 1.0
  num_critics: 2 # 10
  camera_number: 2
  num_subsample_critics: null
  critic_lr: 3e-4
  actor_lr: 3e-4
  temperature_lr: 3e-4
  # critic_target_update_weight: 0.005
  critic_target_update_weight: 0.01
  utd_ratio: 2 # 10

  actor_learner_config:
    actor_ip: "127.0.0.1"
    port: 50051

  # # Loss coefficients.
  # reward_coeff: 0.5
  # expectile_weight: 0.9
  # value_coeff: 0.1
  # consistency_coeff: 20.0
  # advantage_scaling: 3.0
  # pi_coeff: 0.5
  # temporal_decay_coeff: 0.5
  # # Target model.
  # target_model_momentum: 0.995
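For reference, `min_max` normalization maps each dimension of a vector into a fixed range using the per-dimension `min`/`max` vectors given above. A hedged sketch of the transform and its inverse (the target range of [-1, 1] is an assumption; the authoritative implementation is lerobot's normalization module):

```python
import torch

def min_max_normalize(x: torch.Tensor, vmin: torch.Tensor, vmax: torch.Tensor) -> torch.Tensor:
    # Map each dimension from [vmin, vmax] to [-1, 1] (assumed range).
    return 2.0 * (x - vmin) / (vmax - vmin) - 1.0

def min_max_unnormalize(y: torch.Tensor, vmin: torch.Tensor, vmax: torch.Tensor) -> torch.Tensor:
    # Inverse transform: map policy outputs back to raw action units.
    return (y + 1.0) / 2.0 * (vmax - vmin) + vmin
```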
@@ -10,7 +10,7 @@ max_relative_target: null
leader_arms:
  main:
    _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
    port: /dev/tty.usbmodem575E0031751
    port: /dev/tty.usbmodem58760430441
    motors:
      # name: (index, model)
      shoulder_pan: [1, "xl330-m077"]
@@ -23,7 +23,7 @@ leader_arms:
follower_arms:
  main:
    _target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
    port: /dev/tty.usbmodem575E0032081
    port: /dev/tty.usbmodem585A0083391
    motors:
      # name: (index, model)
      shoulder_pan: [1, "xl430-w250"]
@@ -14,11 +14,14 @@ calibration_dir: .cache/calibration/so100
# Set this to a positive scalar to have the same value for all motors, or a list that is the same length as
# the number of motors in your follower arms.
max_relative_target: null
joint_position_relative_bounds:
  max: [7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01]
  min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786]

leader_arms:
  main:
    _target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
    port: /dev/tty.usbmodem585A0077581
    port: /dev/tty.usbmodem58760433331
    motors:
      # name: (index, model)
      shoulder_pan: [1, "sts3215"]
@@ -31,7 +34,7 @@ leader_arms:
follower_arms:
  main:
    _target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
    port: /dev/tty.usbmodem585A0080971
    port: /dev/tty.usbmodem58760431631
    motors:
      # name: (index, model)
      shoulder_pan: [1, "sts3215"]
@@ -42,13 +45,13 @@ follower_arms:
      gripper: [6, "sts3215"]

cameras:
  laptop:
  front:
    _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
    camera_index: 0
    fps: 30
    width: 640
    height: 480
  phone:
  side:
    _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera
    camera_index: 1
    fps: 30
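The new `joint_position_relative_bounds` entry mirrors the state `min`/`max` used in `sac_real.yaml` and bounds how far the follower joints may be driven. A hedged sketch of how such bounds would be enforced (the actual enforcement lives in the robot/device code, which is not shown in this diff; `clamp_joint_target` is a hypothetical helper):

```python
import torch

# Bounds copied from the so100.yaml config above.
bounds_min = torch.tensor([-77.08008, 56.25, 60.55664, 19.511719, 0.0, -0.63829786])
bounds_max = torch.tensor([72.158203, 153.98438, 160.75195, 93.251953, 0.0, -0.14184397])

def clamp_joint_target(target: torch.Tensor) -> torch.Tensor:
    # Keep a commanded joint position inside the configured bounds.
    return torch.clamp(target, bounds_min, bounds_max)
```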
@@ -109,6 +109,7 @@ from lerobot.common.robot_devices.control_utils import (
    log_control_info,
    record_episode,
    reset_environment,
    reset_follower_position,
    sanity_check_dataset_name,
    sanity_check_dataset_robot_compatibility,
    stop_recording,
@@ -191,6 +192,7 @@ def record(
    single_task: str,
    pretrained_policy_name_or_path: str | None = None,
    policy_overrides: List[str] | None = None,
    assign_rewards: bool = False,
    fps: int | None = None,
    warmup_time_s: int | float = 2,
    episode_time_s: int | float = 10,
@@ -204,6 +206,8 @@ def record(
    num_image_writer_threads_per_camera: int = 4,
    display_cameras: bool = True,
    play_sounds: bool = True,
    reset_follower: bool = False,
    record_delta_actions: bool = False,
    resume: bool = False,
    # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
    local_files_only: bool = False,
@@ -214,6 +218,14 @@ def record(
    policy = None
    device = None
    use_amp = None
    extra_features = (
        {
            "next.reward": {"dtype": "int64", "shape": (1,), "names": None},
            "next.done": {"dtype": "bool", "shape": (1,), "names": None},
        }
        if assign_rewards
        else None
    )

    if single_task:
        task = single_task
@@ -242,7 +254,7 @@ def record(
            num_processes=num_image_writer_processes,
            num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
        )
        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
        sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features)
    else:
        # Create empty dataset or load existing saved episodes
        sanity_check_dataset_name(repo_id, policy)
@@ -254,12 +266,15 @@ def record(
            use_videos=video,
            image_writer_processes=num_image_writer_processes,
            image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
            features=extra_features,
        )

    if not robot.is_connected:
        robot.connect()
    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)

    listener, events = init_keyboard_listener()
    if reset_follower:
        initial_position = robot.follower_arms["main"].read("Present_Position")

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
@@ -293,6 +308,7 @@ def record(
            device=device,
            use_amp=use_amp,
            fps=fps,
            record_delta_actions=record_delta_actions,
        )

        # Execute a few seconds without recording to give time to manually reset the environment
@@ -300,9 +316,11 @@ def record(
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
            (recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", play_sounds)
            if reset_follower:
                reset_follower_position(robot, initial_position)
            reset_environment(robot, events, reset_time_s)

        if events["rerecord_episode"]:
@@ -342,21 +360,24 @@ def replay(
    fps: int | None = None,
    play_sounds: bool = True,
    local_files_only: bool = False,
    replay_delta_actions: bool = False,
):
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs

    dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
    actions = dataset.hf_dataset.select_columns("action")

    if not robot.is_connected:
        robot.connect()

    log_say("Replaying episode", play_sounds, blocking=True)
    for idx in range(dataset.num_frames):
        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
        start_episode_t = time.perf_counter()

        action = actions[idx]["action"]
        if replay_delta_actions:
            action = action + current_joint_positions
        robot.send_action(action)

        dt_s = time.perf_counter() - start_episode_t
@@ -469,12 +490,12 @@ if __name__ == "__main__":
        default=1,
        help="Upload dataset to Hugging Face hub.",
    )
    parser_record.add_argument(
        "--tags",
        type=str,
        nargs="*",
        help="Add tags to your dataset on the hub.",
    )
    # parser_record.add_argument(
    #     "--tags",
    #     type=str,
    #     nargs="*",
    #     help="Add tags to your dataset on the hub.",
    # )
    parser_record.add_argument(
        "--num-image-writer-processes",
        type=int,
@@ -517,6 +538,24 @@ if __name__ == "__main__":
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides).",
    )
    parser_record.add_argument(
        "--assign-rewards",
        type=int,
        default=0,
        help="Enables the assignment of rewards to frames (no assignment by default). When enabled, frames receive a 0 reward until the space bar is pressed, which assigns a 1 reward. Press the space bar a second time to return to a 0 reward. The assigned reward is reset to 0 when the episode ends.",
    )
    parser_record.add_argument(
        "--record-delta-actions",
        type=int,
        default=0,
        help="Enables the recording of delta actions instead of absolute actions.",
    )
    parser_record.add_argument(
        "--reset-follower",
        type=int,
        default=0,
        help="Resets the follower to its initial position while resetting the environment, to avoid having the follower start the next episode in an awkward position.",
    )
    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
@@ -540,6 +579,12 @@ if __name__ == "__main__":
        default=0,
        help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
    )
    parser_replay.add_argument(
        "--replay-delta-actions",
        type=int,
        default=0,
        help="Enables the replay of delta actions instead of absolute actions.",
    )
    parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")

    args = parser.parse_args()
@@ -183,8 +183,14 @@ def record(
    resume: bool = False,
    local_files_only: bool = False,
    run_compute_stats: bool = True,
    assign_rewards: bool = False,
) -> LeRobotDataset:
    # Load pretrained policy

    extra_features = (
        {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
    )

    policy = None
    if pretrained_policy_name_or_path is not None:
        policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
@@ -197,7 +203,7 @@ def record(
        raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.")

    # initialize listener before sim env
    listener, events = init_keyboard_listener()
    listener, events = init_keyboard_listener(assign_rewards=assign_rewards)

    # create sim env
    env = env()
@@ -237,6 +243,7 @@ def record(
    }

    features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None}
    features = {**features, **extra_features}

    # Create empty dataset or load existing saved episodes
    sanity_check_dataset_name(repo_id, policy)
@@ -288,6 +295,13 @@ def record(
            "timestamp": env_timestamp,
        }

        # Overwrite environment reward with manually assigned reward
        if assign_rewards:
            frame["next.reward"] = events["next.reward"]

        # Should success always be false to match what we do in control_utils?
        frame["next.success"] = False

        for key in image_keys:
            if not key.startswith("observation.image"):
                frame["observation.image." + key] = observation[key]
@@ -472,6 +486,13 @@ if __name__ == "__main__":
        default=0,
        help="Resume recording on an existing dataset.",
    )
    parser_record.add_argument(
        "--assign-rewards",
        type=int,
        default=0,
        help="Enables the assignment of rewards to frames (no assignment by default). When enabled, frames receive a 0 reward until the space bar is pressed, which assigns a 1 reward. Press the space bar a second time to return to a 0 reward. The assigned reward is reset to 0 when the episode ends.",
    )

    parser_replay = subparsers.add_parser("replay", parents=[base_parser])
    parser_replay.add_argument(
        "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
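The `--record-delta-actions` / `--replay-delta-actions` pair introduced above store and restore actions relative to the follower's current joint positions: the replay loop reads `Present_Position` and adds it back before sending the action. A short sketch of the convention, assuming recording stores the target minus the current position (the inverse of the replay step shown above):

```python
import numpy as np

current = np.array([10.0, 20.0, 30.0])   # follower's Present_Position
target = np.array([12.0, 19.0, 33.0])    # teleoperated target

delta = target - current                 # what recording would store as "action" (assumed)
replayed = delta + current               # what replay sends to the robot (per the diff)
assert np.allclose(replayed, target)
```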
lerobot/scripts/eval_on_robot.py (new file, 394 lines)
@@ -0,0 +1,394 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy by running rollouts on the real robot and computing metrics.

Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.

```
python lerobot/scripts/eval_on_robot.py \
    -p outputs/train/model/checkpoints/005000/pretrained_model \
    eval.n_episodes=10
```

Test the reward classifier with teleoperation (you need to press space to take over):

```
python lerobot/scripts/eval_on_robot.py \
    --robot-path lerobot/configs/robot/so100.yaml \
    --reward-classifier-pretrained-path outputs/classifier/checkpoints/best/pretrained_model \
    --reward-classifier-config-file lerobot/configs/policy/hilserl_classifier.yaml \
    --display-cameras 1
```

**NOTE** (michel-aractingi): This script is incomplete and is being prepared
for running training on the real robot.
"""

import argparse
import logging
import time

import cv2
import numpy as np
import torch
from tqdm import trange

from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
    init_hydra_config,
    init_logging,
    log_say,
)


def get_classifier(pretrained_path, config_path):
    if pretrained_path is None or config_path is None:
        return

    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
    from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    cfg = init_hydra_config(config_path)

    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    classifier_config.num_cameras = len(cfg.training.image_keys)  # TODO automate these paths
    model = Classifier(classifier_config)
    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
    model = model.to("mps")
    return model


def rollout(
    robot: Robot,
    policy: Policy,
    reward_classifier,
    fps: int,
    control_time_s: float = 20,
    use_amp: bool = True,
    display_cameras: bool = False,
) -> dict:
    """Run a batched policy rollout on the real robot.

    The return dictionary contains:
        "robot": A dictionary of (batch, sequence + 1, *) tensors mapped to observation
            keys. Note that this has an extra sequence element relative to the other keys in the
            dictionary, because an extra observation is included for after the environment is
            terminated or truncated.
        "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
            including the last observations).
        "reward": A (batch, sequence) tensor of rewards received for applying the actions.
        "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
            environment termination/truncation).
        "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
            the first True is followed by True's all the way till the end. This can be used for masking
            extraneous elements from the sequences above.

    Args:
        robot: The robot class that defines the interface with the real robot.
        policy: The policy. Must be a PyTorch nn module.

    Returns:
        The dictionary described above.
    """
    # TODO (michel-aractingi): Infer the device from policy parameters when policy is added
    # assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
    # device = get_device_from_parameters(policy)

    # define keyboard listener
    listener, events = init_keyboard_listener()

    # Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
    # policy.reset()

    # NOTE: sorting to make sure the key sequence is the same during training and testing.
    observation = robot.capture_observation()
    image_keys = [key for key in observation if "image" in key]
    image_keys.sort()

    all_actions = []
    all_rewards = []
    all_successes = []

    start_episode_t = time.perf_counter()
    init_pos = robot.follower_arms["main"].read("Present_Position")
    timestamp = 0.0
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        # Apply the next action.
        while events["pause_policy"] and not events["human_intervention_step"]:
            busy_wait(0.5)

        if events["human_intervention_step"]:
            # take over the robot's actions
            observation, action = robot.teleop_step(record_data=True)
            action = action["action"]  # teleop step returns torch tensors but in a dict
        else:
            # explore with policy
            with torch.inference_mode():
                # TODO (michel-aractingi) replace this part with policy (predict_action)
                action = robot.follower_arms["main"].read("Present_Position")
                action = torch.from_numpy(action)
                robot.send_action(action)
                # action = predict_action(observation, policy, device, use_amp)

        observation = robot.capture_observation()
        images = []
        for key in image_keys:
            if display_cameras:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
                cv2.waitKey(1)
            images.append(observation[key].to("mps"))

        reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0
        all_rewards.append(reward)

        # print("REWARD : ", reward)

        all_actions.append(action)
        all_successes.append(torch.tensor([False]))

        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)
        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            events["human_intervention_step"] = False
            events["pause_policy"] = False
            break

    reset_follower_position(robot, target_position=init_pos)

    dones = torch.tensor([False] * len(all_actions))
    dones[-1] = True
    # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
    ret = {
        "action": torch.stack(all_actions, dim=1),
        "next.reward": torch.stack(all_rewards, dim=1),
        "next.success": torch.stack(all_successes, dim=1),
        "done": dones,
    }

    listener.stop()

    return ret


def eval_policy(
    robot: Robot,
    policy: torch.nn.Module,
    fps: float,
    n_episodes: int,
    control_time_s: int = 20,
    use_amp: bool = True,
    display_cameras: bool = False,
    reward_classifier_pretrained_path: str | None = None,
    reward_classifier_config_file: str | None = None,
) -> dict:
    """
    Args:
        robot: The robot interface to roll out on.
        policy: The policy.
        n_episodes: The number of episodes to evaluate.
    Returns:
        Dictionary with metrics and data regarding the rollouts.
    """
    # TODO (michel-aractingi) comment this out for testing with a fixed policy
    # assert isinstance(policy, Policy)
    # policy.eval()

    sum_rewards = []
    max_rewards = []
    successes = []
    rollouts = []

    start_eval = time.perf_counter()
    progbar = trange(n_episodes, desc="Evaluating policy on real robot")
    reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file)

    for _ in progbar:
        rollout_data = rollout(
            robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras
        )

        rollouts.append(rollout_data)
        sum_rewards.append(sum(rollout_data["next.reward"]))
        max_rewards.append(max(rollout_data["next.reward"]))
        successes.append(rollout_data["next.success"][-1])

    info = {
        "per_episode": [
            {
                "episode_ix": i,
                "sum_reward": sum_reward,
                "max_reward": max_reward,
                "pc_success": success * 100,
            }
            for i, (sum_reward, max_reward, success) in enumerate(
                zip(
                    sum_rewards[:n_episodes],
                    max_rewards[:n_episodes],
                    successes[:n_episodes],
                    strict=False,
                )
            )
        ],
        "aggregated": {
            "avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
            "avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
            "pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
            "eval_s": time.perf_counter() - start_eval,
            "eval_ep_s": (time.perf_counter() - start_eval) / n_episodes,
        },
    }

    if robot.is_connected:
        robot.disconnect()

    return info


def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["rerecord_episode"] = False
    events["pause_policy"] = False
    events["human_intervention_step"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecording the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.space:
                # check if first space press then pause the policy for the user to get ready
                # if second space press then the user is ready to start intervention
                if not events["pause_policy"]:
                    print(
                        "Space key pressed. Human intervention required.\n"
                        "Place the leader in a similar pose to the follower and press space again."
                    )
                    events["pause_policy"] = True
                    log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
                else:
                    events["human_intervention_step"] = True
                    print("Space key pressed. Human intervention starting.")
                    log_say("Starting human intervention.", play_sounds=True)

        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events


if __name__ == "__main__":
    init_logging()

    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    group.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides).",
    )
    group.add_argument(
        "-p",
        "--pretrained-policy-name-or-path",
        help=(
            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
            "saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
            "(useful for debugging). This argument is mutually exclusive with `--config`."
        ),
    )
    group.add_argument(
        "--config",
        help=(
            "Path to a yaml config you want to use for initializing a policy from scratch (useful for "
            "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
        ),
    )
    parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
    parser.add_argument(
        "--out-dir",
        help=(
            "Where to save the evaluation outputs. If not provided, outputs are saved in "
            "outputs/eval/{timestamp}_{env_name}_{policy_name}"
        ),
    )
    parser.add_argument(
        "--display-cameras", help="Whether to display the camera feed while the rollout is happening"
    )
    parser.add_argument(
        "--reward-classifier-pretrained-path",
        type=str,
        default=None,
        help="Path to the pretrained classifier weights.",
    )
    parser.add_argument(
        "--reward-classifier-config-file",
        type=str,
        default=None,
        help="Path to a yaml config file that is necessary to build the reward classifier model.",
    )

    args = parser.parse_args()

    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
    robot = make_robot(robot_cfg)
    if not robot.is_connected:
        robot.connect()

    eval_policy(
        robot,
        None,
        fps=40,
        n_episodes=2,
        control_time_s=100,
        display_cameras=args.display_cameras,
        reward_classifier_config_file=args.reward_classifier_config_file,
        reward_classifier_pretrained_path=args.reward_classifier_pretrained_path,
    )
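A minimal usage sketch of `eval_policy`, matching how the `__main__` block above calls it (assumes a connected robot; the policy argument is still `None` while the script is being finished):

```python
# `robot` is a connected Robot instance, as built in the __main__ block above.
info = eval_policy(robot, None, fps=30, n_episodes=2, control_time_s=20)
print(info["aggregated"]["avg_sum_reward"], info["aggregated"]["pc_success"])
```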
lerobot/scripts/server/actor_server.py (new file, 375 lines)
@@ -0,0 +1,375 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import time
from concurrent import futures
from statistics import mean, quantiles

# from lerobot.scripts.eval import eval_policy
from threading import Thread

import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn

# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
# from lerobot.common.envs.utils import preprocess_maniskill_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.control_utils import busy_wait
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.utils.utils import (
    TimerManager,
    get_safe_torch_device,
    set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server.buffer import Transition, move_state_dict_to_device, move_transition_to_device
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env

logging.basicConfig(level=logging.INFO)

parameters_queue = queue.Queue(maxsize=1)
message_queue = queue.Queue(maxsize=1_000_000)


class ActorInformation:
    """
    This helper class is used to differentiate between two types of messages that are placed in the same queue during streaming:

    - **Transition Data:** Contains experience tuples (observation, action, reward, next observation) collected during interaction.
    - **Interaction Messages:** Encapsulates statistics related to the interaction process.

    Attributes:
        transition (Optional): Transition data to be sent to the learner.
        interaction_message (Optional): Interaction message providing additional statistics for logging.
    """

    def __init__(self, transition=None, interaction_message=None):
        self.transition = transition
        self.interaction_message = interaction_message


class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
    """
    gRPC service for actor-learner communication in reinforcement learning.

    This service is responsible for:
    1. Streaming batches of transition data and statistical metrics from the actor to the learner.
    2. Receiving updated network parameters from the learner.
    """

    def StreamTransition(self, request, context):  # noqa: N802
        """
        Streams data from the actor to the learner.

        This function continuously retrieves messages from the queue and processes them based on their type:

        - **Transition Data:**
            - A batch of transitions (observation, action, reward, next observation) is collected.
            - Transitions are moved to the CPU and serialized using PyTorch.
            - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner.

        - **Interaction Messages:**
            - Contains useful statistics about episodic rewards and policy timings.
            - The message is serialized using `pickle` and sent to the learner.

        Yields:
            hilserl_pb2.ActorInformation: The response message containing either transition data or an interaction message.
        """
        while True:
            message = message_queue.get(block=True)

            if message.transition is not None:
                transition_to_send_to_learner: list[Transition] = [
                    move_transition_to_device(transition=T, device="cpu") for T in message.transition
                ]
                # Check for NaNs in transitions before sending to learner
                for transition in transition_to_send_to_learner:
                    for key, value in transition["state"].items():
                        if torch.isnan(value).any():
                            logging.warning(f"Found NaN values in transition {key}")
                buf = io.BytesIO()
                torch.save(transition_to_send_to_learner, buf)
                transition_bytes = buf.getvalue()

                transition_message = hilserl_pb2.Transition(transition_bytes=transition_bytes)

                response = hilserl_pb2.ActorInformation(transition=transition_message)

            elif message.interaction_message is not None:
                content = hilserl_pb2.InteractionMessage(
                    interaction_message_bytes=pickle.dumps(message.interaction_message)
                )
                response = hilserl_pb2.ActorInformation(interaction_message=content)

            yield response

    def SendParameters(self, request, context):  # noqa: N802
        """
        Receives updated parameters from the learner and updates the actor.

        The learner calls this method to send new model parameters. The received parameters are deserialized
        and placed in a queue to be consumed by the actor.

        Args:
            request (hilserl_pb2.ParameterUpdate): The request containing serialized network parameters.
            context (grpc.ServicerContext): The gRPC context.

        Returns:
            hilserl_pb2.Empty: An empty response to acknowledge receipt.
        """
        buffer = io.BytesIO(request.parameter_bytes)
        params = torch.load(buffer)
        parameters_queue.put(params)
        return hilserl_pb2.Empty()


def serve_actor_service(port=50052):
    """
    Runs a gRPC server to start streaming the data from the actor to the learner.
    Through this server the learner can push parameters to the actor as well.
    """
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=20),
        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
    )
    hilserl_pb2_grpc.add_ActorServiceServicer_to_server(ActorServiceServicer(), server)
    server.add_insecure_port(f"[::]:{port}")
    server.start()
    logging.info(f"[ACTOR] gRPC server listening on port {port}")
    server.wait_for_termination()


def update_policy_parameters(policy: SACPolicy, parameters_queue: queue.Queue, device):
    if not parameters_queue.empty():
        logging.info("[ACTOR] Load new parameters from Learner.")
        state_dict = parameters_queue.get()
        state_dict = move_state_dict_to_device(state_dict, device=device)
        policy.load_state_dict(state_dict)


def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module):
    """
    Executes policy interaction within the environment.

    This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
    Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.

    Args:
        cfg (DictConfig): Configuration settings for the interaction process.
        robot (Robot): The robot interface used to build the online environment.
        reward_classifier (nn.Module): Optional classifier that predicts rewards from images.
    """

    logging.info("make_env online")

    online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)

    set_global_seed(cfg.seed)
    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")

    # HACK: This is an ugly hack to pass the normalization parameters to the policy.
    # Because the action space is dynamic, we override the output normalization parameters.
    # It's ugly, we know ... and we will fix it.
    min_action_space: list = online_env.action_space.spaces[0].low.tolist()
    max_action_space: list = online_env.action_space.spaces[0].high.tolist()
    output_normalization_params: dict[dict[str, list]] = {
        "action": {"min": min_action_space, "max": max_action_space}
    }
    cfg.policy.output_normalization_params = output_normalization_params
    cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape

    # Instantiate the policy in both the actor and learner processes.
    # To avoid sending a SACPolicy object through the port, we create a policy instance
    # on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
    # TODO: At some point we should just need to make the SAC policy.
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        # TODO: Handle resume training
        device=device,
    )
    policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    obs, info = online_env.reset()

    # NOTE: For the moment we will solely handle the case of a single environment
    sum_reward_episode = 0
    list_transition_to_send_to_learner = []
    list_policy_time = []
    episode_intervention = False

    for interaction_step in range(cfg.training.online_steps):
        if interaction_step >= cfg.training.online_step_before_learning:
            # Time policy inference and check if it meets the FPS requirement
            with TimerManager(
                elapsed_time_list=list_policy_time, label="Policy inference time", log=False
            ) as timer:  # noqa: F841
                action = policy.select_action(batch=obs)
            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)

            log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)

            next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
        else:
            # TODO (azouitine): Make a custom space for torch tensors
            action = online_env.action_space.sample()
            next_obs, reward, done, truncated, info = online_env.step(action)

            # HACK: We have only one env but we want to batch it; it will be resolved with the torch box
            action = (
                torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
            )

        sum_reward_episode += float(reward)

        # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
        if "is_intervention" in info and info["is_intervention"]:
            # TODO: Check the shape
            # NOTE: The action space for demonstrations beforehand is the full action space,
            # but sometimes, for example, we want to deactivate the gripper
            action = info["action_intervention"]
            episode_intervention = True

        # Check for NaN values in observations
        for key, tensor in obs.items():
            if torch.isnan(tensor).any():
                logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")

        list_transition_to_send_to_learner.append(
            Transition(
                state=obs,
                action=action,
                reward=reward,
                next_state=next_obs,
                done=done,
                complementary_info=info,  # TODO Handle information for the transition, is_demonstration: bool
            )
        )

        # assign obs to the next obs and continue the rollout
        obs = next_obs

        # HACK: We have only one env but we want to batch it; it will be resolved with the torch box
        # Because we are using a single environment we can index at zero
        if done or truncated:
            # TODO: Handle logging for episode information
            logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")

            update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)

            if len(list_transition_to_send_to_learner) > 0:
                send_transitions_in_chunks(
                    transitions=list_transition_to_send_to_learner, message_queue=message_queue, chunk_size=4
                )
                list_transition_to_send_to_learner = []

            stats = get_frequency_stats(list_policy_time)
            list_policy_time.clear()

            # Send episodic reward to the learner
            message_queue.put(
                ActorInformation(
                    interaction_message={
                        "Episodic reward": sum_reward_episode,
                        "Interaction step": interaction_step,
                        "Episode intervention": int(episode_intervention),
                        **stats,
                    }
                )
            )
            sum_reward_episode = 0.0
            episode_intervention = False
            obs, info = online_env.reset()


def send_transitions_in_chunks(transitions: list, message_queue, chunk_size: int = 100):
    """Send transitions to the learner in smaller chunks to avoid network issues.

    Args:
        transitions: List of transitions to send
        message_queue: Queue to send messages to the learner
        chunk_size: Size of each chunk to send
    """
    for i in range(0, len(transitions), chunk_size):
        chunk = transitions[i : i + chunk_size]
        logging.debug(f"[ACTOR] Sending chunk of {len(chunk)} transitions to Learner.")
        message_queue.put(ActorInformation(transition=chunk))


def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
    stats = {}
    list_policy_fps = [1.0 / t for t in list_policy_time]
    if len(list_policy_fps) > 1:
        policy_fps = mean(list_policy_fps)
        quantiles_90 = quantiles(list_policy_fps, n=10)[-1]
        logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}")
        logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}")
        stats = {"Policy frequency [Hz]": policy_fps, "Policy frequency 90th-p [Hz]": quantiles_90}
    return stats


def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
    if policy_fps < cfg.fps:
        logging.warning(
            f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
        )


@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def actor_cli(cfg: dict):
    robot = make_robot(cfg=cfg.robot)

    server_thread = Thread(target=serve_actor_service, args=(cfg.actor_learner_config.port,), daemon=True)

    # HACK: FOR MANISKILL we do not have a reward classifier
    # TODO: Remove this once we merge into main
    reward_classifier = None
    if (
        cfg.env.reward_classifier.pretrained_path is not None
        and cfg.env.reward_classifier.config_path is not None
    ):
        reward_classifier = get_classifier(
            pretrained_path=cfg.env.reward_classifier.pretrained_path,
            config_path=cfg.env.reward_classifier.config_path,
        )
    policy_thread = Thread(
        target=act_with_policy,
        daemon=True,
        args=(cfg, robot, reward_classifier),
    )
    server_thread.start()
    policy_thread.start()
    policy_thread.join()
    server_thread.join()


if __name__ == "__main__":
    actor_cli()
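For context, the learner consumes `StreamTransition`'s responses by reversing the serialization above. A hedged sketch of that decode (the actual learner server is not part of this excerpt; `decode_actor_information` is a hypothetical helper):

```python
import io
import pickle

import torch

def decode_actor_information(msg):
    # msg is a hilserl_pb2.ActorInformation; mirror StreamTransition's encoding.
    if msg.HasField("transition"):
        transitions = torch.load(io.BytesIO(msg.transition.transition_bytes))
        return "transition", transitions
    interaction = pickle.loads(msg.interaction_message.interaction_message_bytes)
    return "interaction", interaction
```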
lerobot/scripts/server/buffer.py (new file, 609 lines)
@@ -0,0 +1,609 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import multiprocessing
import random
from typing import Any, Callable, Optional, Sequence, TypedDict

import torch
import torch.nn.functional as F  # noqa: N812
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    # TypedDict fields cannot carry default values; mark as optional instead.
    complementary_info: Optional[dict[str, Any]]


class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor
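A small sketch of building one `Transition`, matching how the actor fills the buffer each environment step (shapes are illustrative, not real robot data):

```python
import torch

t = Transition(
    state={"observation.state": torch.zeros(1, 6)},
    action=torch.zeros(1, 4),
    reward=0.0,
    next_state={"observation.state": torch.zeros(1, 6)},
    done=False,
    complementary_info=None,
)
```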
def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    # Move state tensors to the target device
    device = torch.device(device)
    transition["state"] = {
        key: val.to(device, non_blocking=device.type == "cuda") for key, val in transition["state"].items()
    }

    # Move action to the target device
    transition["action"] = transition["action"].to(device, non_blocking=device.type == "cuda")

    # Reward and done are plain floats/bools unless they are already tensors
    if isinstance(transition["reward"], torch.Tensor):
        transition["reward"] = transition["reward"].to(device=device, non_blocking=device.type == "cuda")

    if isinstance(transition["done"], torch.Tensor):
        transition["done"] = transition["done"].to(device, non_blocking=device.type == "cuda")

    # Move next_state tensors to the target device
    transition["next_state"] = {
        key: val.to(device, non_blocking=device.type == "cuda")
        for key, val in transition["next_state"].items()
    }

    # If complementary_info is present, move its tensors to the target device
    # if transition["complementary_info"] is not None:
    #     transition["complementary_info"] = {
    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
    #     }
    return transition


def move_state_dict_to_device(state_dict, device):
    """
    Recursively move all tensors in a (potentially) nested
    dict/list/tuple structure to the given device.
    """
    if isinstance(state_dict, torch.Tensor):
        return state_dict.to(device)
    elif isinstance(state_dict, dict):
        return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()}
    elif isinstance(state_dict, list):
        return [move_state_dict_to_device(v, device=device) for v in state_dict]
    elif isinstance(state_dict, tuple):
        return tuple(move_state_dict_to_device(v, device=device) for v in state_dict)
    else:
        return state_dict
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    """
    B, C, H, W = images.shape  # noqa: N806
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))
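Usage sketch of the DrQ-style augmentation defined above, as applied when sampling with `use_drq=True`:

```python
import torch

imgs = torch.rand(8, 3, 128, 128)   # (B, C, H, W) batch of images
aug = random_shift(imgs, pad=4)     # pad by 4 pixels, then random-crop back to 128x128
assert aug.shape == imgs.shape
```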
class ReplayBuffer:
    def __init__(
        self,
        capacity: int,
        device: str = "cuda:0",
        state_keys: Optional[list[str]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
        storage_device: str = "cpu",
        use_shared_memory: bool = False,
    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a default augmentation function is used.
            use_drq (bool): Whether to use the default DrQ image augmentation style when sampling from the buffer.
            storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored when adding transitions to the buffer.
                Using "cpu" can help save GPU memory.
            use_shared_memory (bool): Whether to use shared memory for the buffer.
        """
        self.capacity = capacity
        self.device = device
        self.storage_device = storage_device
        self.memory: list[Transition] = torch.multiprocessing.Manager().list() if use_shared_memory else []
        self.position = 0

        # Convert state_keys to a list for pickling
        self.state_keys = list(state_keys) if state_keys is not None else []

        if image_augmentation_function is None:
            self.image_augmentation_function = functools.partial(random_shift, pad=4)
        self.use_drq = use_drq

    def __len__(self):
        return len(self.memory)

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Saves a transition, ensuring tensors are stored on the designated storage device."""
        # Move tensors to the storage device
        state = {key: tensor.to(self.storage_device) for key, tensor in state.items()}
        next_state = {key: tensor.to(self.storage_device) for key, tensor in next_state.items()}
        action = action.to(self.storage_device)
        # if complementary_info is not None:
        #     complementary_info = {
        #         key: tensor.to(self.storage_device) for key, tensor in complementary_info.items()
        #     }

        if len(self.memory) < self.capacity:
            self.memory.append({})  # Need to append something first for Manager().list()

        # Create and store the Transition
        self.memory[self.position] = Transition(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            complementary_info=complementary_info,
        )
        self.position = (self.position + 1) % self.capacity

    # TODO: Add image_augmentation and use_drq arguments to this function in order to instantiate the class with them
    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        capacity: Optional[int] = None,
        action_mask: Optional[Sequence[int]] = None,
        action_delta: Optional[float] = None,
        use_shared_memory: bool = False,
    ) -> "ReplayBuffer":
        """
        Convert a LeRobotDataset into a ReplayBuffer.

        Args:
            lerobot_dataset (LeRobotDataset): The dataset to convert.
            device (str): The target device. Defaults to "cuda:0".
            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
                Defaults to None.

        Returns:
            ReplayBuffer: The replay buffer with offline dataset transitions.
        """
        # We convert the LeRobotDataset into a replay buffer because it is more efficient to sample from
        # a replay buffer than from a lerobot dataset.
        if capacity is None:
            capacity = len(lerobot_dataset)

        if capacity < len(lerobot_dataset):
            raise ValueError(
                "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset."
            )

        replay_buffer = cls(
            capacity=capacity, device=device, state_keys=state_keys, use_shared_memory=use_shared_memory
        )
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
        # Fill the replay buffer with the lerobot dataset transitions
        for data in list_transition:
            for k, v in data.items():
                if isinstance(v, dict):
                    for key, tensor in v.items():
                        v[key] = tensor.to(device)
                elif isinstance(v, torch.Tensor):
                    data[k] = v.to(device)

            if action_mask is not None:
                if data["action"].dim() == 1:
                    data["action"] = data["action"][action_mask]
                else:
                    data["action"] = data["action"][:, action_mask]

            if action_delta is not None:
                data["action"] = data["action"] / action_delta

            replay_buffer.add(
                state=data["state"],
                action=data["action"],
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
            )
        return replay_buffer

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever your 'state_keys' specify.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in 'state' and 'next_state'. Their names
                will be kept as-is in the output transitions. E.g.
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
"""
|
||||
|
||||
# If not provided, you can either raise an error or define a default:
|
||||
if state_keys is None:
|
||||
raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")
|
||||
|
||||
transitions: list[Transition] = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
for i in tqdm(range(num_frames)):
|
||||
current_sample = dataset[i]
|
||||
|
||||
# ----- 1) Current state -----
|
||||
current_state: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = current_sample[key]
|
||||
current_state[key] = val.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 2) Action -----
|
||||
action = current_sample["action"].unsqueeze(0) # Add batch dimension
|
||||
|
||||
# ----- 3) Reward and done -----
|
||||
reward = float(current_sample["next.reward"].item()) # ensure float
|
||||
done = bool(current_sample["next.done"].item()) # ensure bool
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
# Otherwise (done=True or next sample crosses to a new episode), next_state = current_state.
|
||||
next_state = current_state # default
|
||||
if not done and (i < num_frames - 1):
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] == current_sample["episode_index"]:
|
||||
# Build next_state from the same keys
|
||||
next_state_data: dict[str, torch.Tensor] = {}
|
||||
for key in state_keys:
|
||||
val = next_sample[key]
|
||||
next_state_data[key] = val.unsqueeze(0) # Add batch dimension
|
||||
next_state = next_state_data
|
||||
|
||||
# ----- Construct the Transition -----
|
||||
transition = Transition(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
|
||||
def sample(self, batch_size: int) -> BatchTransition:
|
||||
"""Sample a random batch of transitions and collate them into batched tensors."""
|
||||
batch_size = min(batch_size, len(self.memory))
|
||||
# Different sampling approach for shared memory list vs regular list
|
||||
|
||||
list_of_transitions = random.sample(list(self.memory), batch_size)
|
||||
# if isinstance(self.memory, multiprocessing.managers.ListProxy):
|
||||
# # For shared memory list, we need to be careful about thread safety
|
||||
# with torch.multiprocessing.Lock():
|
||||
# # Get indices first to minimize lock time
|
||||
# indices = torch.randint(len(self.memory), size=(batch_size,)).tolist()
|
||||
# # Convert to list to avoid multiple proxy accesses
|
||||
# list_of_transitions = [self.memory[i] for i in indices]
|
||||
# else:
|
||||
# # For regular list, use faster random.sample
|
||||
# list_of_transitions = random.sample(self.memory, batch_size)
|
||||
|
||||
# -- Build batched states --
|
||||
batch_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_state[key] = self.image_augmentation_function(batch_state[key])
|
||||
|
||||
# -- Build batched actions --
|
||||
batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)
|
||||
|
||||
# -- Build batched rewards --
|
||||
batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# -- Build batched next states --
|
||||
batch_next_state = {}
|
||||
for key in self.state_keys:
|
||||
batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
|
||||
self.device
|
||||
)
|
||||
if key.startswith("observation.image") and self.use_drq:
|
||||
batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])
|
||||
|
||||
# -- Build batched dones --
|
||||
batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
|
||||
self.device
|
||||
)
|
||||
|
||||
# Return a BatchTransition typed dict
|
||||
return BatchTransition(
|
||||
state=batch_state,
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
)
|
||||
|
||||
def to_lerobot_dataset(
|
||||
self,
|
||||
repo_id: str,
|
||||
fps=1, # If you have real timestamps, adjust this
|
||||
root=None,
|
||||
task_name="from_replay_buffer",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts all transitions in this ReplayBuffer into a single LeRobotDataset object,
|
||||
splitting episodes by transitions where 'done=True'.
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: The resulting offline dataset.
|
||||
"""
|
||||
if len(self.memory) == 0:
|
||||
raise ValueError("The replay buffer is empty. Cannot convert to a dataset.")
|
||||
|
||||
# Infer the shapes and dtypes of your features
|
||||
# We'll create a features dict that is suitable for LeRobotDataset
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# First, grab one transition to inspect shapes
|
||||
first_transition = self.memory[0]
|
||||
|
||||
# We'll store default metadata for every episode: indexes, timestamps, etc.
|
||||
features = {
|
||||
"index": {"dtype": "int64", "shape": [1]}, # global index across episodes
|
||||
"episode_index": {"dtype": "int64", "shape": [1]}, # which episode
|
||||
"frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode
|
||||
"timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy
|
||||
"task_index": {"dtype": "int64", "shape": [1]},
|
||||
}
|
||||
|
||||
# Add "action"
|
||||
act_info = guess_feature_info(
|
||||
first_transition["action"].squeeze(dim=0), "action"
|
||||
) # Remove batch dimension
|
||||
features["action"] = act_info
|
||||
|
||||
# Add "reward" (scalars)
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,)}
|
||||
|
||||
# Add "done" (boolean scalars)
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,)}
|
||||
|
||||
# Add state keys
|
||||
for key in self.state_keys:
|
||||
sample_val = first_transition["state"][key].squeeze(dim=0) # Remove batch dimension
|
||||
if not isinstance(sample_val, torch.Tensor):
|
||||
raise ValueError(
|
||||
f"State key '{key}' is not a torch.Tensor. Please ensure your states are stored as torch.Tensors."
|
||||
)
|
||||
f_info = guess_feature_info(sample_val, key)
|
||||
features[key] = f_info
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Create an empty LeRobotDataset
|
||||
# We'll store all frames as separate images only if we detect shape = (3, H, W) or (1, H, W).
|
||||
# By default we won't do videos, but feel free to adapt if you have them.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
lerobot_dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=fps, # If you have real timestamps, adjust this
|
||||
root=root, # Or some local path where you'd like the dataset files to go
|
||||
robot=None,
|
||||
robot_type=None,
|
||||
features=features,
|
||||
use_videos=True, # We won't do actual video encoding for a replay buffer
|
||||
)
|
||||
|
||||
# Start writing images if needed. If you have no image features, this is harmless.
|
||||
# Set num_processes or num_threads if you want concurrency.
|
||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=2)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Convert transitions into episodes and frames
|
||||
# We detect episode boundaries by `done == True`.
|
||||
# --------------------------------------------------------------------------------------------
|
||||
episode_index = 0
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
|
||||
|
||||
frame_idx_in_episode = 0
|
||||
for global_frame_idx, transition in enumerate(self.memory):
|
||||
frame_dict = {}
|
||||
|
||||
# Fill the data for state keys
|
||||
for key in self.state_keys:
|
||||
# Expand dimension to match what the dataset expects (the dataset wants the raw shape)
|
||||
# We assume your buffer has shape [C, H, W] (if image) or [D] if vector
|
||||
# This is typically already correct, but if needed you can reshape below.
|
||||
frame_dict[key] = transition["state"][key].cpu().squeeze(dim=0) # Remove batch dimension
|
||||
|
||||
# Fill action, reward, done
|
||||
# Make sure they are shape (X,) or (X,Y,...) as needed.
|
||||
frame_dict["action"] = transition["action"].cpu().squeeze(dim=0) # Remove batch dimension
|
||||
frame_dict["next.reward"] = (
|
||||
torch.tensor([transition["reward"]], dtype=torch.float32).cpu().squeeze(dim=0)
|
||||
)
|
||||
frame_dict["next.done"] = (
|
||||
torch.tensor([transition["done"]], dtype=torch.bool).cpu().squeeze(dim=0)
|
||||
)
|
||||
# Add to the dataset's buffer
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# Move to next frame
|
||||
frame_idx_in_episode += 1
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if transition["done"]:
|
||||
# Use some placeholder name for the task
|
||||
lerobot_dataset.save_episode(task="from_replay_buffer")
|
||||
episode_index += 1
|
||||
frame_idx_in_episode = 0
|
||||
# Start a new buffer for the next episode
|
||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index)
|
||||
|
||||
# We are done adding frames
|
||||
# If the last transition wasn't done=True, we still have an open buffer with frames.
|
||||
# We'll consider that an incomplete episode and still save it:
|
||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||
lerobot_dataset.save_episode(task=task_name)
|
||||
|
||||
lerobot_dataset.stop_image_writer()
|
||||
|
||||
lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False)
|
||||
|
||||
return lerobot_dataset
|
||||
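
# Illustrative sketch (not part of the original file): a minimal add/sample round trip with
# synthetic tensors. The key names and shapes below are assumptions chosen for the demo.
def _demo_replay_buffer() -> None:
    buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["observation.state"], storage_device="cpu")
    state = {"observation.state": torch.zeros(1, 6)}
    next_state = {"observation.state": torch.ones(1, 6)}
    for _ in range(10):
        buffer.add(state=state, action=torch.zeros(1, 6), reward=0.0, next_state=next_state, done=False)
    batch = buffer.sample(batch_size=4)
    # Collation concatenates per-transition tensors along the batch dimension.
    assert batch["state"]["observation.state"].shape == (4, 6)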

# Utility function to guess shapes/dtypes from a tensor
def guess_feature_info(t: torch.Tensor, name: str):
    """
    Return a dictionary with the 'dtype' and 'shape' for a given tensor or array.
    If it looks like a 3D (C, H, W) shape, we might consider it an 'image'.
    Otherwise default to 'float32' for numeric. You can customize as needed.
    """
    shape = tuple(t.shape)
    # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
    if len(shape) == 3 and shape[0] in [1, 3]:
        return {
            "dtype": "image",
            "shape": shape,
        }
    else:
        # Otherwise treat as numeric
        return {
            "dtype": "float32",
            "shape": shape,
        }
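
# Illustrative sketch (not part of the original file): what the heuristic above returns
# for an image-like tensor versus a flat state vector. Key names are assumptions.
def _demo_guess_feature_info() -> None:
    image_info = guess_feature_info(torch.zeros(3, 128, 128), "observation.images.front")
    state_info = guess_feature_info(torch.zeros(6), "observation.state")
    assert image_info == {"dtype": "image", "shape": (3, 128, 128)}
    assert state_info == {"dtype": "float32", "shape": (6,)}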

def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """NOTE: Be careful, this mutates `left_batch_transitions` in place."""
    left_batch_transitions["state"] = {
        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
        )
        for key in left_batch_transitions["next_state"]
    }
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    return left_batch_transitions
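
# Illustrative sketch (not part of the original file): the typical use of
# `concatenate_batch_transitions` is to mix an offline (demonstration) batch with an
# online batch before a gradient update. The buffer names are assumptions for the demo.
def _demo_concatenate_batches(online_buffer: ReplayBuffer, offline_buffer: ReplayBuffer) -> BatchTransition:
    online_batch = online_buffer.sample(batch_size=32)
    offline_batch = offline_buffer.sample(batch_size=32)
    # Note: `online_batch` is mutated in place and returned with 64 transitions.
    return concatenate_batch_transitions(online_batch, offline_batch)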

# if __name__ == "__main__":
#     dataset_name = "aractingi/push_green_cube_hf_cropped_resized"
#     dataset = LeRobotDataset(repo_id=dataset_name)

#     replay_buffer = ReplayBuffer.from_lerobot_dataset(
#         lerobot_dataset=dataset, state_keys=["observation.image", "observation.state"]
#     )
#     replay_buffer_converted = replay_buffer.to_lerobot_dataset(repo_id="AdilZtn/pusht_image_converted")
#     for i in range(len(replay_buffer_converted)):
#         replay_convert = replay_buffer_converted[i]
#         dataset_convert = dataset[i]
#         for key in replay_convert.keys():
#             if key in {"index", "episode_index", "frame_index", "timestamp", "task_index"}:
#                 continue
#             if key in dataset_convert.keys():
#                 assert torch.equal(replay_convert[key], dataset_convert[key])
#                 print(f"Key {key} is equal : {replay_convert[key].size()}, {dataset_convert[key].size()}")
#     re_reconverted_dataset = ReplayBuffer.from_lerobot_dataset(
#         replay_buffer_converted, state_keys=["observation.image", "observation.state"], device="cpu"
#     )
#     for _ in range(20):
#         batch = re_reconverted_dataset.sample(32)

#         for key in batch.keys():
#             if key in {"state", "next_state"}:
#                 for key_state in batch[key].keys():
#                     print(key_state, batch[key][key_state].size())
#                 continue
#             print(key, batch[key].size())
287 lerobot/scripts/server/crop_dataset_roi.py Normal file
@@ -0,0 +1,287 @@
import argparse  # noqa: I001
import json
from copy import deepcopy
from typing import Dict, Tuple
from pathlib import Path
import cv2

# import torch.nn.functional as F  # noqa: N812
import torchvision.transforms.functional as F  # type: ignore # noqa: N812
from tqdm import tqdm  # type: ignore

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset


def select_rect_roi(img):
    """
    Allows the user to draw a rectangular ROI on the image.

    The user must click and drag to draw the rectangle.
    - While dragging, the rectangle is dynamically drawn.
    - On mouse button release, the rectangle is fixed.
    - Press 'c' to confirm the selection.
    - Press 'r' to reset the selection.
    - Press ESC to cancel.

    Returns:
        A tuple (top, left, height, width) representing the rectangular ROI,
        or None if no valid ROI is selected.
    """
    # Create a working copy of the image
    clone = img.copy()
    working_img = clone.copy()

    roi = None  # Will store the final ROI as (top, left, height, width)
    drawing = False
    ix, iy = -1, -1  # Initial click coordinates

    def mouse_callback(event, x, y, flags, param):
        nonlocal ix, iy, drawing, roi, working_img

        if event == cv2.EVENT_LBUTTONDOWN:
            # Start drawing: record starting coordinates
            drawing = True
            ix, iy = x, y

        elif event == cv2.EVENT_MOUSEMOVE:
            if drawing:
                # Compute the top-left and bottom-right corners regardless of drag direction
                top = min(iy, y)
                left = min(ix, x)
                bottom = max(iy, y)
                right = max(ix, x)
                # Show a temporary image with the current rectangle drawn
                temp = working_img.copy()
                cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
                cv2.imshow("Select ROI", temp)

        elif event == cv2.EVENT_LBUTTONUP:
            # Finish drawing
            drawing = False
            top = min(iy, y)
            left = min(ix, x)
            bottom = max(iy, y)
            right = max(ix, x)
            height = bottom - top
            width = right - left
            roi = (top, left, height, width)
            # Draw the final rectangle on the working image and display it
            working_img = clone.copy()
            cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
            cv2.imshow("Select ROI", working_img)

    # Create the window and set the callback
    cv2.namedWindow("Select ROI")
    cv2.setMouseCallback("Select ROI", mouse_callback)
    cv2.imshow("Select ROI", working_img)

    print("Instructions for ROI selection:")
    print("  - Click and drag to draw a rectangular ROI.")
    print("  - Press 'c' to confirm the selection.")
    print("  - Press 'r' to reset and draw again.")
    print("  - Press ESC to cancel the selection.")

    # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
    while True:
        key = cv2.waitKey(1) & 0xFF
        # Confirm ROI if one has been drawn
        if key == ord("c") and roi is not None:
            break
        # Reset: clear the ROI and restore the original image
        elif key == ord("r"):
            working_img = clone.copy()
            roi = None
            cv2.imshow("Select ROI", working_img)
        # Cancel selection for this image
        elif key == 27:  # ESC key
            roi = None
            break

    cv2.destroyWindow("Select ROI")
    return roi


def select_square_roi_for_images(images: dict) -> dict:
    """
    For each image in the provided dictionary, open a window to allow the user
    to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
    (top, left, height, width) representing the ROI.

    Parameters:
        images (dict): Dictionary where keys are identifiers and values are OpenCV images.

    Returns:
        dict: Mapping of image keys to the selected rectangular ROI.
    """
    selected_rois = {}

    for key, img in images.items():
        if img is None:
            print(f"Image for key '{key}' is None, skipping.")
            continue

        print(f"\nSelect rectangular ROI for image with key: '{key}'")
        roi = select_rect_roi(img)

        if roi is None:
            print(f"No valid ROI selected for '{key}'.")
        else:
            selected_rois[key] = roi
            print(f"ROI for '{key}': {roi}")

    return selected_rois


def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
    """
    Fetch the first row of the dataset and extract its images so they can be used for the crop.
    """
    row = dataset[0]
    image_dict = {}
    for k in row:
        if "image" in k:
            image_dict[k] = deepcopy(row[k])
    return image_dict


def convert_lerobot_dataset_to_cropper_lerobot_dataset(
    original_dataset: LeRobotDataset,
    crop_params_dict: Dict[str, Tuple[int, int, int, int]],
    new_repo_id: str,
    new_dataset_root: str,
    resize_size: Tuple[int, int] = (128, 128),
) -> LeRobotDataset:
    """
    Converts an existing LeRobotDataset by iterating over its episodes and frames,
    applying cropping and resizing to image observations, and saving a new dataset
    with the transformed data.

    Args:
        original_dataset (LeRobotDataset): The source dataset.
        crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
            A dictionary mapping observation keys to crop parameters (top, left, height, width).
        new_repo_id (str): Repository id for the new dataset.
        new_dataset_root (str): The root directory where the new dataset will be written.
        resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
            Defaults to (128, 128).

    Returns:
        LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
        and resized.
    """
    # 1. Create a new (empty) LeRobotDataset for writing.
    new_dataset = LeRobotDataset.create(
        repo_id=new_repo_id,
        fps=original_dataset.fps,
        root=new_dataset_root,
        robot_type=original_dataset.meta.robot_type,
        features=original_dataset.meta.info["features"],
        use_videos=len(original_dataset.meta.video_keys) > 0,
    )

    # Update the metadata for every image key that will be cropped:
    # (Here we simply set the shape to be the final resize_size.)
    for key in crop_params_dict:
        if key in new_dataset.meta.info["features"]:
            new_dataset.meta.info["features"][key]["shape"] = list(resize_size)

    # 2. Process each episode in the original dataset.
    episodes_info = original_dataset.meta.episodes
    # (Sort episodes by episode_index for consistency.)
    episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
    # Use the first task from the episode metadata (or "unknown" if not provided)
    task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"

    last_episode_index = 0
    for sample in tqdm(original_dataset):
        episode_index = sample.pop("episode_index")
        if episode_index != last_episode_index:
            new_dataset.save_episode(task, encode_videos=True)
            last_episode_index = episode_index
        sample.pop("frame_index")
        # Make a shallow copy of the sample (the values, e.g. torch tensors, are assumed immutable)
        new_sample = sample.copy()
        # Loop over each observation key that should be cropped/resized.
        for key, params in crop_params_dict.items():
            if key in new_sample:
                top, left, height, width = params
                # Apply crop then resize.
                cropped = F.crop(new_sample[key], top, left, height, width)
                resized = F.resize(cropped, resize_size)
                new_sample[key] = resized
        # Add the transformed frame to the new dataset.
        new_dataset.add_frame(new_sample)

    # Save the last episode
    new_dataset.save_episode(task, encode_videos=True)

    # Optionally, consolidate the new dataset to compute statistics and update video info.
    new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)

    new_dataset.push_to_hub(tags=None)

    return new_dataset


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
    parser.add_argument(
        "--repo-id",
        type=str,
        default="lerobot",
        help="The repository id of the LeRobot dataset to process.",
    )
    parser.add_argument(
        "--root",
        type=str,
        default=None,
        help="The root directory of the LeRobot dataset.",
    )
    parser.add_argument(
        "--crop-params-path",
        type=str,
        default=None,
        help="The path to the JSON file containing the ROIs.",
    )
    args = parser.parse_args()

    local_files_only = args.root is not None
    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root, local_files_only=local_files_only)

    images = get_image_from_lerobot_dataset(dataset)
    images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
    images = {k: (v * 255).astype("uint8") for k, v in images.items()}

    if args.crop_params_path is None:
        rois = select_square_roi_for_images(images)
    else:
        with open(args.crop_params_path) as f:
            rois = json.load(f)

    # rois = {
    #     "observation.images.front": [102, 43, 358, 523],
    #     "observation.images.side": [92, 123, 379, 349],
    # }

    # Print the selected rectangular ROIs
    print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
    for key, roi in rois.items():
        print(f"{key}: {roi}")

    new_repo_id = args.repo_id + "_cropped_resized"
    new_dataset_root = Path(str(dataset.root) + "_cropped_resized")

    cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
        original_dataset=dataset,
        crop_params_dict=rois,
        new_repo_id=new_repo_id,
        new_dataset_root=new_dataset_root,
        resize_size=(128, 128),
    )

    meta_dir = new_dataset_root / "meta"
    meta_dir.mkdir(exist_ok=True)

    with open(meta_dir / "crop_params.json", "w") as f:
        json.dump(rois, f, indent=4)
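
# Illustrative usage (not part of the original file). Both invocations are sketches;
# the repo ids and paths are assumptions:
#
#   # Select ROIs interactively, then write the cropped/resized dataset:
#   python lerobot/scripts/server/crop_dataset_roi.py --repo-id aractingi/push_green_cube_hf
#
#   # Reuse previously saved crop parameters instead of selecting them by hand:
#   python lerobot/scripts/server/crop_dataset_roi.py --repo-id aractingi/push_green_cube_hf \
#       --crop-params-path outputs/crop_params.json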
65 lerobot/scripts/server/find_joint_limits.py Normal file
@@ -0,0 +1,65 @@
import argparse
import time

import cv2
import numpy as np

from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config


def find_joint_bounds(
    robot,
    control_time_s=20,
    display_cameras=False,
):
    # TODO(rcadene): Add option to record logs
    if not robot.is_connected:
        robot.connect()

    # NOTE: the `control_time_s` argument is currently overridden; teleoperation runs
    # until the 60 s cutoff inside the loop below.
    control_time_s = float("inf")

    timestamp = 0
    start_episode_t = time.perf_counter()
    pos_list = []
    while timestamp < control_time_s:
        observation, action = robot.teleop_step(record_data=True)

        pos_list.append(robot.follower_arms["main"].read("Present_Position"))

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            for key in image_keys:
                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

        timestamp = time.perf_counter() - start_episode_t
        if timestamp > 60:
            pos_max = np.max(np.stack(pos_list), 0)
            pos_min = np.min(np.stack(pos_list), 0)
            print(f"Max angle position per joint {pos_max}")
            print(f"Min angle position per joint {pos_min}")
            break


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--robot-path",
        type=str,
        default="lerobot/configs/robot/koch.yaml",
        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
    )
    parser.add_argument(
        "--robot-overrides",
        type=str,
        nargs="*",
        help="Any key=value arguments to override config values (use dots for nested overrides, e.g. key.nested=value)",
    )
    parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
    args = parser.parse_args()
    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)

    robot = make_robot(robot_cfg)
    find_joint_bounds(robot, control_time_s=args.control_time_s)
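
# Illustrative usage (not part of the original file); the config path is an assumption:
#
#   python lerobot/scripts/server/find_joint_limits.py \
#       --robot-path lerobot/configs/robot/koch.yaml --control-time-s 30
#
# Teleoperate the arm through its full workspace; the observed per-joint min/max positions
# are printed so they can be copied into `joint_position_relative_bounds` in the config.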
857 lerobot/scripts/server/gym_manipulator.py Normal file
@@ -0,0 +1,857 @@
import argparse
import logging
import time
from threading import Lock
from typing import Annotated, Any, Callable, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
import torch
import torchvision.transforms.functional as F  # noqa: N812

from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.utils.utils import init_hydra_config, log_say
from lerobot.scripts.server.maniskill_manipulator import make_maniskill

logging.basicConfig(level=logging.INFO)


class HILSerlRobotEnv(gym.Env):
    """
    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.

    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both
    relative (delta) and absolute joint position commands and automatically configures its observation and action
    spaces based on the robot's sensors and configuration.

    The environment can switch between executing actions from a policy or using teleoperated actions (human
    intervention) during each step. When teleoperation is used, the override action is captured and returned in
    the `info` dict along with a flag `is_intervention`.
    """

    def __init__(
        self,
        robot,
        use_delta_action_space: bool = True,
        delta: float | None = None,
        display_cameras: bool = False,
    ):
        """
        Initialize the HILSerlRobotEnv environment.

        The environment is set up with a robot interface, which is used to capture observations and send joint
        commands. The setup supports both relative (delta) adjustments and absolute joint positions for
        controlling the robot.

        Args:
            robot: The robot interface object used to connect and interact with the physical robot.
            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control.
                Otherwise, absolute joint positions are used.
            delta (float or None): A scaling factor for the relative adjustments applied to joint positions.
                Should be a value between 0 and 1 when using a delta action space.
            display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
        """
        super().__init__()

        self.robot = robot
        self.display_cameras = display_cameras

        # Connect to the robot if not already connected.
        if not self.robot.is_connected:
            self.robot.connect()

        self.initial_follower_position = robot.follower_arms["main"].read("Present_Position")

        # Episode tracking.
        self.current_step = 0
        self.episode_data = None

        self.delta = delta
        self.use_delta_action_space = use_delta_action_space
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")

        # Retrieve the size of the joint position interval bound.
        self.relative_bounds_size = (
            self.robot.config.joint_position_relative_bounds["max"]
            - self.robot.config.joint_position_relative_bounds["min"]
        )

        self.delta_relative_bounds_size = self.relative_bounds_size * self.delta

        self.robot.config.max_relative_target = self.delta_relative_bounds_size.float()

        # Dynamically configure the observation and action spaces.
        self._setup_spaces()

    def _setup_spaces(self):
        """
        Dynamically configure the observation and action spaces based on the robot's capabilities.

        Observation Space:
            - For keys with "image": A Box space with pixel values ranging from 0 to 255.
            - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range.

        Action Space:
            - The action space is defined as a Tuple where:
                - The first element is a Box space representing joint position commands. It is defined as
                  relative (delta) or absolute, based on the configuration.
                - The second element is a Discrete space (with 2 values) serving as a flag for intervention
                  (teleoperation).
        """
        example_obs = self.robot.capture_observation()

        # Define observation spaces for images and other states.
        image_keys = [key for key in example_obs if "image" in key]
        state_keys = [key for key in example_obs if "image" not in key]
        observation_spaces = {
            key: gym.spaces.Box(low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8)
            for key in image_keys
        }
        observation_spaces["observation.state"] = gym.spaces.Dict(
            {
                key: gym.spaces.Box(low=0, high=10, shape=example_obs[key].shape, dtype=np.float32)
                for key in state_keys
            }
        )

        self.observation_space = gym.spaces.Dict(observation_spaces)

        # Define the action space for joint positions along with setting an intervention flag.
        action_dim = len(self.robot.follower_arms["main"].read("Present_Position"))
        if self.use_delta_action_space:
            action_space_robot = gym.spaces.Box(
                low=-self.relative_bounds_size.cpu().numpy(),
                high=self.relative_bounds_size.cpu().numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )
        else:
            action_space_robot = gym.spaces.Box(
                low=self.robot.config.joint_position_relative_bounds["min"].cpu().numpy(),
                high=self.robot.config.joint_position_relative_bounds["max"].cpu().numpy(),
                shape=(action_dim,),
                dtype=np.float32,
            )

        self.action_space = gym.spaces.Tuple(
            (
                action_space_robot,
                gym.spaces.Discrete(2),
            ),
        )

    def reset(self, seed=None, options=None) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]:
        """
        Reset the environment to its initial state.
        This method resets the step counter and clears any episodic data.

        Args:
            seed (Optional[int]): A seed for random number generation to ensure reproducibility.
            options (Optional[dict]): Additional options to influence the reset behavior.

        Returns:
            A tuple containing:
                - observation (dict): The initial sensor observation.
                - info (dict): A dictionary with supplementary information, including the key "initial_position".
        """
        super().reset(seed=seed, options=options)

        # Capture the initial observation.
        observation = self.robot.capture_observation()

        # Reset episode tracking variables.
        self.current_step = 0
        self.episode_data = None

        return observation, {"initial_position": self.initial_follower_position}

    def step(
        self, action: Tuple[np.ndarray, bool]
    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
        """
        Execute a single step within the environment using the specified action.

        The provided action is a tuple comprised of:
            - A policy action (joint position commands) that may be either in absolute values or as a delta.
            - A boolean flag indicating whether teleoperation (human intervention) should be used for this step.

        Behavior:
            - When the intervention flag is False, the environment processes and sends the policy action to the
              robot.
            - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop
              action is converted to a relative change based on the current joint positions.

        Args:
            action (tuple): A tuple with two elements:
                - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
                - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation
                  input.

        Returns:
            tuple: A tuple containing:
                - observation (dict): The new sensor observation after taking the step.
                - reward (float): The step reward (default is 0.0 within this wrapper).
                - terminated (bool): True if the episode has reached a terminal state.
                - truncated (bool): True if the episode was truncated (e.g., time constraints).
                - info (dict): Additional debugging information including:
                    - "action_intervention": The teleop action if intervention was used.
                    - "is_intervention": Flag indicating whether teleoperation was employed.
        """
        policy_action, intervention_bool = action
        teleop_action = None
        self.current_joint_positions = self.robot.follower_arms["main"].read("Present_Position")
        if isinstance(policy_action, torch.Tensor):
            policy_action = policy_action.cpu().numpy()
            policy_action = np.clip(policy_action, self.action_space[0].low, self.action_space[0].high)
        if not intervention_bool:
            if self.use_delta_action_space:
                target_joint_positions = self.current_joint_positions + self.delta * policy_action
            else:
                target_joint_positions = policy_action
            self.robot.send_action(torch.from_numpy(target_joint_positions))
            observation = self.robot.capture_observation()
        else:
            observation, teleop_action = self.robot.teleop_step(record_data=True)
            teleop_action = teleop_action["action"]  # Convert tensor to appropriate format

            # When applying the delta action space, convert teleop absolute values to relative differences.
            if self.use_delta_action_space:
                teleop_action = (teleop_action - self.current_joint_positions) / self.delta
                if torch.any(teleop_action < -self.relative_bounds_size) and torch.any(
                    teleop_action > self.relative_bounds_size
                ):
                    logging.debug(
                        f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
                        f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
                        f"upper bounds condition {teleop_action > self.relative_bounds_size}"
                    )

                    teleop_action = torch.clamp(
                        teleop_action, -self.relative_bounds_size, self.relative_bounds_size
                    )
            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
            if teleop_action.dim() == 1:
                teleop_action = teleop_action.unsqueeze(0)

        # self.render()

        self.current_step += 1

        reward = 0.0
        terminated = False
        truncated = False

        return (
            observation,
            reward,
            terminated,
            truncated,
            {"action_intervention": teleop_action, "is_intervention": teleop_action is not None},
        )

    def render(self):
        """
        Render the current state of the environment by displaying the robot's camera feeds.
        """
        import cv2

        observation = self.robot.capture_observation()
        image_keys = [key for key in observation if "image" in key]

        for key in image_keys:
            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)

    def close(self):
        """
        Close the environment and clean up resources by disconnecting the robot.

        If the robot is currently connected, this method properly terminates the connection to ensure that all
        associated resources are released.
        """
        if self.robot.is_connected:
            self.robot.disconnect()
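
# Illustrative sketch (not part of the original file): how a policy loop drives the
# (action, intervention_flag) tuple interface above. `policy` is an assumption: any
# callable mapping an observation to a joint-position command tensor.
def _demo_policy_loop(env: HILSerlRobotEnv, policy: Callable[[Any], torch.Tensor], n_steps: int = 100) -> None:
    obs, info = env.reset()
    for _ in range(n_steps):
        action = policy(obs)
        # The second tuple element is False: let the policy act. A human can still take
        # over via the keyboard wrapper, which flips this flag internally.
        obs, reward, terminated, truncated, info = env.step((action, False))
        if terminated or truncated:
            obs, info = env.reset()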

class ActionRepeatWrapper(gym.Wrapper):
    def __init__(self, env, nb_repeat: int = 1):
        super().__init__(env)
        self.nb_repeat = nb_repeat

    def step(self, action):
        for _ in range(self.nb_repeat):
            obs, reward, done, truncated, info = self.env.step(action)
            if done or truncated:
                break
        return obs, reward, done, truncated, info


class RewardWrapper(gym.Wrapper):
    def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
        """
        Wrapper that adds reward prediction to the environment using a trained classifier.

        Args:
            env: The environment to wrap
            reward_classifier: The reward classifier model
            device: The device to run the model on
        """
        super().__init__(env)

        # NOTE: We got 15% speedup by compiling the model
        self.reward_classifier = torch.compile(reward_classifier)

        if isinstance(device, str):
            device = torch.device(device)
        self.device = device

    def step(self, action):
        observation, _, terminated, truncated, info = self.env.step(action)
        images = [
            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
            for key in observation
            if "image" in key
        ]
        start_time = time.perf_counter()
        with torch.inference_mode():
            reward = (
                self.reward_classifier.predict_reward(images, threshold=0.8)
                if self.reward_classifier is not None
                else 0.0
            )
        info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)

        # logging.info(f"Reward: {reward}")

        if reward == 1.0:
            terminated = True
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        return self.env.reset(seed=seed, options=options)


class JointMaskingActionSpace(gym.Wrapper):
    def __init__(self, env, mask):
        """
        Wrapper to mask out dimensions of the action space.

        Args:
            env: The environment to wrap
            mask: Binary mask array where 0 indicates dimensions to remove
        """
        super().__init__(env)

        # Keep only dimensions where mask is 1
        self.active_dims = np.where(mask)[0]

        # Validate the mask against the action space and create a new action space
        # with the masked dimensions removed.
        if isinstance(env.action_space, gym.spaces.Box):
            if len(mask) != env.action_space.shape[0]:
                raise ValueError("Mask length must match action space dimensions")
            low = env.action_space.low[self.active_dims]
            high = env.action_space.high[self.active_dims]
            self.action_space = gym.spaces.Box(low=low, high=high, dtype=env.action_space.dtype)

        if isinstance(env.action_space, gym.spaces.Tuple):
            if len(mask) != env.action_space[0].shape[0]:
                raise ValueError("Mask length must match the dimensions of the first action space element")

            low = env.action_space[0].low[self.active_dims]
            high = env.action_space[0].high[self.active_dims]
            action_space_masked = gym.spaces.Box(low=low, high=high, dtype=env.action_space[0].dtype)
            self.action_space = gym.spaces.Tuple((action_space_masked, env.action_space[1]))

    def action(self, action):
        """
        Convert masked action back to full action space.

        Args:
            action: Action in masked space. For Tuple spaces, the first element is masked.

        Returns:
            Action in original space with masked dims set to 0.
        """
        # Determine whether we are handling a Tuple space or a Box.
        if isinstance(self.env.action_space, gym.spaces.Tuple):
            # Extract the masked component from the tuple.
            masked_action = action[0] if isinstance(action, tuple) else action
            # Create a full action for the Box element.
            full_box_action = np.zeros(self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype)
            full_box_action[self.active_dims] = masked_action
            # Return a tuple with the reconstructed Box action and the unchanged remainder.
            return (full_box_action, action[1])
        else:
            # For Box action spaces.
            masked_action = action if not isinstance(action, tuple) else action[0]
            full_action = np.zeros(self.env.action_space.shape, dtype=self.env.action_space.dtype)
            full_action[self.active_dims] = masked_action
            return full_action

    def step(self, action):
        action = self.action(action)
        obs, reward, terminated, truncated, info = self.env.step(action)
        if "action_intervention" in info and info["action_intervention"] is not None:
            if info["action_intervention"].dim() == 1:
                info["action_intervention"] = info["action_intervention"][self.active_dims]
            else:
                info["action_intervention"] = info["action_intervention"][:, self.active_dims]
        return obs, reward, terminated, truncated, info
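
# Illustrative sketch (not part of the original file): what the masking wrapper does to a
# 6-DoF action when only the first three joints are active. The values are assumptions.
def _demo_joint_masking() -> None:
    mask = [1, 1, 1, 0, 0, 0]
    active_dims = np.where(mask)[0]  # -> array([0, 1, 2])
    masked_action = np.array([0.1, -0.2, 0.3], dtype=np.float32)
    # Reconstruction mirrors JointMaskingActionSpace.action(): masked dims are set to 0.
    full_action = np.zeros(6, dtype=np.float32)
    full_action[active_dims] = masked_action
    assert np.allclose(full_action, [0.1, -0.2, 0.3, 0.0, 0.0, 0.0])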
|
||||
class TimeLimitWrapper(gym.Wrapper):
|
||||
def __init__(self, env, control_time_s, fps):
|
||||
self.env = env
|
||||
self.control_time_s = control_time_s
|
||||
self.fps = fps
|
||||
|
||||
self.last_timestamp = 0.0
|
||||
self.episode_time_in_s = 0.0
|
||||
|
||||
self.max_episode_steps = int(self.control_time_s * self.fps)
|
||||
|
||||
self.current_step = 0
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
time_since_last_step = time.perf_counter() - self.last_timestamp
|
||||
self.episode_time_in_s += time_since_last_step
|
||||
self.last_timestamp = time.perf_counter()
|
||||
self.current_step += 1
|
||||
# check if last timestep took more time than the expected fps
|
||||
if 1.0 / time_since_last_step < self.fps:
|
||||
logging.debug(f"Current timestep exceeded expected fps {self.fps}")
|
||||
|
||||
if self.episode_time_in_s > self.control_time_s:
|
||||
# if self.current_step >= self.max_episode_steps:
|
||||
# Terminated = True
|
||||
terminated = True
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
self.episode_time_in_s = 0.0
|
||||
self.last_timestamp = time.perf_counter()
|
||||
self.current_step = 0
|
||||
return self.env.reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class ImageCropResizeWrapper(gym.Wrapper):
|
||||
def __init__(self, env, crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], resize_size=None):
|
||||
super().__init__(env)
|
||||
self.env = env
|
||||
self.crop_params_dict = crop_params_dict
|
||||
print(f"obs_keys , {self.env.observation_space}")
|
||||
print(f"crop params dict {crop_params_dict.keys()}")
|
||||
for key_crop in crop_params_dict:
|
||||
if key_crop not in self.env.observation_space.keys(): # noqa: SIM118
|
||||
raise ValueError(f"Key {key_crop} not in observation space")
|
||||
for key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
new_shape = (top + height, left + width)
|
||||
self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape)
|
||||
|
||||
self.resize_size = resize_size
|
||||
if self.resize_size is None:
|
||||
self.resize_size = (128, 128)
|
||||
|
||||
def step(self, action):
|
||||
obs, reward, terminated, truncated, info = self.env.step(action)
|
||||
for k in self.crop_params_dict:
|
||||
device = obs[k].device
|
||||
|
||||
# Check for NaNs before processing
|
||||
if torch.isnan(obs[k]).any():
|
||||
logging.error(f"NaN values detected in observation {k} before crop and resize")
|
||||
|
||||
if device == torch.device("mps:0"):
|
||||
obs[k] = obs[k].cpu()
|
||||
|
||||
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
|
||||
obs[k] = F.resize(obs[k], self.resize_size)
|
||||
|
||||
# Check for NaNs after processing
|
||||
if torch.isnan(obs[k]).any():
|
||||
logging.error(f"NaN values detected in observation {k} after crop and resize")
|
||||
|
||||
obs[k] = obs[k].to(device)
|
||||
|
||||
return obs, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
obs, info = self.env.reset(seed=seed, options=options)
|
||||
for k in self.crop_params_dict:
|
||||
device = obs[k].device
|
||||
if device == torch.device("mps:0"):
|
||||
obs[k] = obs[k].cpu()
|
||||
obs[k] = F.crop(obs[k], *self.crop_params_dict[k])
|
||||
obs[k] = F.resize(obs[k], self.resize_size)
|
||||
obs[k] = obs[k].to(device)
|
||||
return obs, info
|
||||
|
||||
|
||||
class ConvertToLeRobotObservation(gym.ObservationWrapper):
|
||||
def __init__(self, env, device):
|
||||
super().__init__(env)
|
||||
|
||||
if isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
self.device = device
|
||||
|
||||
def observation(self, observation):
|
||||
observation = preprocess_observation(observation)
|
||||
|
||||
observation = {
|
||||
key: observation[key].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
for key in observation
|
||||
}
|
||||
observation = {k: torch.tensor(v, device=self.device) for k, v in observation.items()}
|
||||
return observation
|
||||
|
||||
|
||||
class KeyboardInterfaceWrapper(gym.Wrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
self.listener = None
|
||||
self.events = {
|
||||
"exit_early": False,
|
||||
"pause_policy": False,
|
||||
"reset_env": False,
|
||||
"human_intervention_step": False,
|
||||
"episode_success": False,
|
||||
}
|
||||
self.event_lock = Lock() # Thread-safe access to events
|
||||
self._init_keyboard_listener()
|
||||
|
||||
def _init_keyboard_listener(self):
|
||||
"""Initialize keyboard listener if not in headless mode"""
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
return
|
||||
try:
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
with self.event_lock:
|
||||
try:
|
||||
if key == keyboard.Key.right or key == keyboard.Key.esc:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
self.events["exit_early"] = True
|
||||
return
|
||||
if hasattr(key, "char") and key.char == "s":
|
||||
print("Key 's' pressed. Episode success triggered.")
|
||||
self.events["episode_success"] = True
|
||||
return
|
||||
if key == keyboard.Key.space and not self.events["exit_early"]:
|
||||
if not self.events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
self.events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
return
|
||||
if self.events["pause_policy"] and not self.events["human_intervention_step"]:
|
||||
self.events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
return
|
||||
if self.events["pause_policy"] and self.events["human_intervention_step"]:
|
||||
self.events["pause_policy"] = False
|
||||
self.events["human_intervention_step"] = False
|
||||
print("Space key pressed for a third time.")
|
||||
log_say("Continuing with policy actions.", play_sounds=True)
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
self.listener = keyboard.Listener(on_press=on_press)
|
||||
self.listener.start()
|
||||
except ImportError:
|
||||
logging.warning("Could not import pynput. Keyboard interface will not be available.")
|
||||
self.listener = None
|
||||
|
||||
def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]:
|
||||
is_intervention = False
|
||||
terminated_by_keyboard = False
|
||||
|
||||
# Extract policy_action if needed
|
||||
if isinstance(self.env.action_space, gym.spaces.Tuple):
|
||||
policy_action = action[0]
|
||||
|
||||
# Check the event flags without holding the lock for too long.
|
||||
with self.event_lock:
|
||||
if self.events["exit_early"]:
|
||||
terminated_by_keyboard = True
|
||||
pause_policy = self.events["pause_policy"]
|
||||
|
||||
if pause_policy:
|
||||
# Now, wait for human_intervention_step without holding the lock
|
||||
while True:
|
||||
with self.event_lock:
|
||||
if self.events["human_intervention_step"]:
|
||||
is_intervention = True
|
||||
break
|
||||
time.sleep(0.1) # Check more frequently if desired
|
||||
|
||||
# Execute the step in the underlying environment
|
||||
obs, reward, terminated, truncated, info = self.env.step((policy_action, is_intervention))
|
||||
|
||||
# Override reward and termination if episode success event triggered
|
||||
with self.event_lock:
|
||||
if self.events["episode_success"]:
|
||||
reward = 1
|
||||
terminated_by_keyboard = True
|
||||
|
||||
return obs, reward, terminated or terminated_by_keyboard, truncated, info
|
||||
|
||||
def reset(self, **kwargs) -> Tuple[Any, Dict]:
|
||||
"""
|
||||
Reset the environment and clear any pending events
|
||||
"""
|
||||
with self.event_lock:
|
||||
self.events = {k: False for k in self.events}
|
||||
return self.env.reset(**kwargs)
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Properly clean up the keyboard listener when the environment is closed
|
||||
"""
|
||||
if self.listener is not None:
|
||||
self.listener.stop()
|
||||
super().close()
|
||||
|
||||
|
||||
class ResetWrapper(gym.Wrapper):
|
||||
def __init__(
|
||||
self, env: HILSerlRobotEnv, reset_fn: Optional[Callable[[], None]] = None, reset_time_s: float = 5
|
||||
):
|
||||
super().__init__(env)
|
||||
self.reset_fn = reset_fn
|
||||
self.reset_time_s = reset_time_s
|
||||
|
||||
self.robot = self.unwrapped.robot
|
||||
self.init_pos = self.unwrapped.initial_follower_position
|
||||
|
||||
def reset(self, *, seed=None, options=None):
|
||||
if self.reset_fn is not None:
|
||||
self.reset_fn(self.env)
|
||||
else:
|
||||
log_say(f"Manually reset the environment for {self.reset_time_s} seconds.", play_sounds=True)
|
||||
start_time = time.perf_counter()
|
||||
while time.perf_counter() - start_time < self.reset_time_s:
|
||||
self.robot.teleop_step()
|
||||
|
||||
log_say("Manual reseting of the environment done.", play_sounds=True)
|
||||
return super().reset(seed=seed, options=options)
|
||||
|
||||
|
||||
class BatchCompatibleWrapper(gym.ObservationWrapper):
|
||||
def __init__(self, env):
|
||||
super().__init__(env)
|
||||
|
||||
def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
for key in observation:
|
||||
if "image" in key and observation[key].dim() == 3:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
if "state" in key and observation[key].dim() == 1:
|
||||
observation[key] = observation[key].unsqueeze(0)
|
||||
return observation
|
||||
|
||||
|
||||
# TODO: REMOVE TH
|
||||
|
||||
|
||||
def make_robot_env(
|
||||
robot,
|
||||
reward_classifier,
|
||||
cfg,
|
||||
n_envs: int = 1,
|
||||
) -> gym.vector.VectorEnv:
|
||||
"""
|
||||
Factory function to create a vectorized robot environment.
|
||||
|
||||
Args:
|
||||
robot: Robot instance to control
|
||||
reward_classifier: Classifier model for computing rewards
|
||||
cfg: Configuration object containing environment parameters
|
||||
n_envs: Number of environments to create in parallel. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
A vectorized gym environment with all the necessary wrappers applied.
|
||||
"""
|
||||
if "maniskill" in cfg.env.name:
|
||||
logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
|
||||
env = make_maniskill(
|
||||
cfg=cfg,
|
||||
n_envs=1,
|
||||
)
|
||||
return env
|
||||
# Create base environment
|
||||
env = HILSerlRobotEnv(
|
||||
robot=robot,
|
||||
display_cameras=cfg.env.wrapper.display_cameras,
|
||||
delta=cfg.env.wrapper.delta_action,
|
||||
use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
|
||||
)
|
||||
|
||||
# Add observation and image processing
|
||||
env = ConvertToLeRobotObservation(env=env, device=cfg.device)
|
||||
if cfg.env.wrapper.crop_params_dict is not None:
|
||||
env = ImageCropResizeWrapper(
|
||||
env=env, crop_params_dict=cfg.env.wrapper.crop_params_dict, resize_size=cfg.env.wrapper.resize_size
|
||||
)
|
||||
|
||||
# Add reward computation and control wrappers
|
||||
env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
|
||||
env = TimeLimitWrapper(env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps)
|
||||
env = KeyboardInterfaceWrapper(env=env)
|
||||
env = ResetWrapper(env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s)
|
||||
env = JointMaskingActionSpace(env=env, mask=cfg.env.wrapper.joint_masking_action_space)
|
||||
env = BatchCompatibleWrapper(env=env)
|
||||
|
||||
return env
|
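# Typical usage (a sketch; `robot`, `classifier`, and `cfg` are built with the
# same factories used in the __main__ block below):
#   env = make_robot_env(robot, classifier, cfg)
#   obs, info = env.reset()
#   obs, reward, terminated, truncated, info = env.step((action, False))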
||||
|
||||
# batched version of the env that returns an observation of shape (b, c)
|
||||
|
||||
|
||||
def get_classifier(pretrained_path, config_path, device="mps"):
|
||||
if pretrained_path is None or config_path is None:
|
||||
return None
|
||||
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
|
||||
cfg = init_hydra_config(config_path)
|
||||
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths
|
||||
model = Classifier(classifier_config)
|
||||
model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
|
||||
model = model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def replay_episode(env, repo_id, root=None, episode=0):
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
local_files_only = root is not None
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"][:4]
|
||||
print(action)
|
||||
env.step((action / env.unwrapped.delta, False))
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / 10 - dt_s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--fps", type=int, default=30, help="control frequency")
|
||||
parser.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-cameras", help=("Whether to display the camera feed while the rollout is happening")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-pretrained-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the pretrained classifier weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reward-classifier-config-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a yaml config file that is necessary to build the reward classifier model.",
|
||||
)
|
||||
parser.add_argument("--env-path", type=str, default=None, help="Path to the env yaml file")
|
||||
parser.add_argument("--env-overrides", type=str, default=None, help="Overrides for the env yaml file")
|
||||
parser.add_argument("--control-time-s", type=float, default=20, help="Maximum episode length in seconds")
|
||||
parser.add_argument("--reset-follower-pos", type=int, default=1, help="Reset follower between episodes")
|
||||
parser.add_argument("--replay-repo-id", type=str, default=None, help="Repo ID of the episode to replay")
|
||||
parser.add_argument("--replay-root", type=str, default=None, help="Root of the dataset to replay")
|
||||
parser.add_argument("--replay-episode", type=int, default=0, help="Episode to replay")
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
|
||||
reward_classifier = get_classifier(
|
||||
args.reward_classifier_pretrained_path, args.reward_classifier_config_file
|
||||
)
|
||||
use_relative_joint_positions = True
|
||||
|
||||
cfg = init_hydra_config(args.env_path, args.env_overrides)
|
||||
env = make_robot_env(
|
||||
robot,
|
||||
reward_classifier,
|
||||
cfg.env, # .wrapper,
|
||||
)
|
||||
|
||||
env.reset()
|
||||
|
||||
if args.replay_repo_id is not None:
|
||||
replay_episode(env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode)
|
||||
exit()
|
||||
|
||||
# Retrieve the robot's action space for joint commands.
|
||||
action_space_robot = env.action_space.spaces[0]
|
||||
|
||||
# Initialize the smoothed action as a random sample.
|
||||
smoothed_action = action_space_robot.sample()
|
||||
|
||||
# Smoothing coefficient (alpha) defines how much of the new random sample to mix in.
|
||||
# A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth.
|
||||
alpha = 0.4
|
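# In other words, the loop below applies a first-order low-pass filter,
#   smoothed_t = alpha * sample_t + (1 - alpha) * smoothed_{t-1},
# so with alpha = 0.4 each new sample contributes 40% and older samples decay geometrically.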
||||
|
||||
while True:
|
||||
start_loop_s = time.perf_counter()
|
||||
# Sample a new random action from the robot's action space.
|
||||
new_random_action = action_space_robot.sample()
|
||||
# Update the smoothed action using an exponential moving average.
|
||||
smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action
|
||||
|
||||
# Execute the step: wrap the NumPy action in a torch tensor.
|
||||
obs, reward, terminated, truncated, info = env.step((torch.from_numpy(smoothed_action), False))
|
||||
if terminated or truncated:
|
||||
env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_s
|
||||
busy_wait(1 / args.fps - dt_s)
|
||||
58
lerobot/scripts/server/hilserl.proto
Normal file
@@ -0,0 +1,58 @@
|
||||
// !/usr/bin/env python
|
||||
|
||||
// Copyright 2024 The HuggingFace Inc. team.
|
||||
// All rights reserved.
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
package hil_serl;
|
||||
|
||||
// LearnerService: the Actor calls this to push transitions.
|
||||
// The Learner implements this service.
|
||||
service LearnerService {
|
||||
// Actor -> Learner to store transitions
|
||||
rpc SendTransition(Transition) returns (Empty);
|
||||
rpc SendInteractionMessage(InteractionMessage) returns (Empty);
|
||||
}
|
||||
|
||||
// ActorService: the Learner calls this to push parameters.
|
||||
// The Actor implements this service.
|
||||
service ActorService {
|
||||
// Learner -> Actor to send new parameters
|
||||
rpc StreamTransition(Empty) returns (stream ActorInformation) {};
|
||||
rpc SendParameters(Parameters) returns (Empty);
|
||||
}
|
||||
|
||||
|
||||
message ActorInformation {
|
||||
oneof data {
|
||||
Transition transition = 1;
|
||||
InteractionMessage interaction_message = 2;
|
||||
}
|
||||
}
|
||||
|
||||
// Messages
|
||||
message Transition {
|
||||
bytes transition_bytes = 1;
|
||||
}
|
||||
|
||||
message Parameters {
|
||||
bytes parameter_bytes = 1;
|
||||
}
|
||||
|
||||
message InteractionMessage {
|
||||
bytes interaction_message_bytes = 1;
|
||||
}
|
||||
|
||||
message Empty {}
|
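// To regenerate the Python stubs that follow (hilserl_pb2.py and hilserl_pb2_grpc.py),
// the standard grpcio-tools invocation should work (a sketch, run from this directory):
//   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. hilserl.proto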
||||
48
lerobot/scripts/server/hilserl_pb2.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# NO CHECKED-IN PROTOBUF GENCODE
|
||||
# source: hilserl.proto
|
||||
# Protobuf Python Version: 5.29.0
|
||||
"""Generated protocol buffer code."""
|
||||
from google.protobuf import descriptor as _descriptor
|
||||
from google.protobuf import descriptor_pool as _descriptor_pool
|
||||
from google.protobuf import runtime_version as _runtime_version
|
||||
from google.protobuf import symbol_database as _symbol_database
|
||||
from google.protobuf.internal import builder as _builder
|
||||
_runtime_version.ValidateProtobufRuntimeVersion(
|
||||
_runtime_version.Domain.PUBLIC,
|
||||
5,
|
||||
29,
|
||||
0,
|
||||
'',
|
||||
'hilserl.proto'
|
||||
)
|
||||
# @@protoc_insertion_point(imports)
|
||||
|
||||
_sym_db = _symbol_database.Default()
|
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"\x83\x01\n\x10\x41\x63torInformation\x12*\n\ntransition\x18\x01 \x01(\x0b\x32\x14.hil_serl.TransitionH\x00\x12;\n\x13interaction_message\x18\x02 \x01(\x0b\x32\x1c.hil_serl.InteractionMessageH\x00\x42\x06\n\x04\x64\x61ta\"&\n\nTransition\x12\x18\n\x10transition_bytes\x18\x01 \x01(\x0c\"%\n\nParameters\x12\x17\n\x0fparameter_bytes\x18\x01 \x01(\x0c\"7\n\x12InteractionMessage\x12!\n\x19interaction_message_bytes\x18\x01 \x01(\x0c\"\x07\n\x05\x45mpty2\x92\x01\n\x0eLearnerService\x12\x37\n\x0eSendTransition\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty2\x8c\x01\n\x0c\x41\x63torService\x12\x43\n\x10StreamTransition\x12\x0f.hil_serl.Empty\x1a\x1a.hil_serl.ActorInformation\"\x00\x30\x01\x12\x37\n\x0eSendParameters\x12\x14.hil_serl.Parameters\x1a\x0f.hil_serl.Emptyb\x06proto3')
|
||||
|
||||
_globals = globals()
|
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals)
|
||||
if not _descriptor._USE_C_DESCRIPTORS:
|
||||
DESCRIPTOR._loaded_options = None
|
||||
_globals['_ACTORINFORMATION']._serialized_start=28
|
||||
_globals['_ACTORINFORMATION']._serialized_end=159
|
||||
_globals['_TRANSITION']._serialized_start=161
|
||||
_globals['_TRANSITION']._serialized_end=199
|
||||
_globals['_PARAMETERS']._serialized_start=201
|
||||
_globals['_PARAMETERS']._serialized_end=238
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_start=240
|
||||
_globals['_INTERACTIONMESSAGE']._serialized_end=295
|
||||
_globals['_EMPTY']._serialized_start=297
|
||||
_globals['_EMPTY']._serialized_end=304
|
||||
_globals['_LEARNERSERVICE']._serialized_start=307
|
||||
_globals['_LEARNERSERVICE']._serialized_end=453
|
||||
_globals['_ACTORSERVICE']._serialized_start=456
|
||||
_globals['_ACTORSERVICE']._serialized_end=596
|
||||
# @@protoc_insertion_point(module_scope)
|
||||
269
lerobot/scripts/server/hilserl_pb2_grpc.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
||||
"""Client and server classes corresponding to protobuf-defined services."""
|
||||
import grpc
|
||||
import warnings
|
||||
|
||||
import hilserl_pb2 as hilserl__pb2
|
||||
|
||||
GRPC_GENERATED_VERSION = '1.70.0'
|
||||
GRPC_VERSION = grpc.__version__
|
||||
_version_not_supported = False
|
||||
|
||||
try:
|
||||
from grpc._utilities import first_version_is_lower
|
||||
_version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
|
||||
except ImportError:
|
||||
_version_not_supported = True
|
||||
|
||||
if _version_not_supported:
|
||||
raise RuntimeError(
|
||||
f'The grpc package installed is at version {GRPC_VERSION},'
|
||||
+ f' but the generated code in hilserl_pb2_grpc.py depends on'
|
||||
+ f' grpcio>={GRPC_GENERATED_VERSION}.'
|
||||
+ f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
|
||||
+ f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
|
||||
)
|
||||
|
||||
|
||||
class LearnerServiceStub(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.SendTransition = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
request_serializer=hilserl__pb2.Transition.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
self.SendInteractionMessage = channel.unary_unary(
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
request_serializer=hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class LearnerServiceServicer(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
def SendTransition(self, request, context):
|
||||
"""Actor -> Learner to store transitions
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendInteractionMessage(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_LearnerServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'SendTransition': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendTransition,
|
||||
request_deserializer=hilserl__pb2.Transition.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
'SendInteractionMessage': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendInteractionMessage,
|
||||
request_deserializer=hilserl__pb2.InteractionMessage.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.LearnerService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class LearnerService(object):
|
||||
"""LearnerService: the Actor calls this to push transitions.
|
||||
The Learner implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def SendTransition(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendTransition',
|
||||
hilserl__pb2.Transition.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendInteractionMessage(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.LearnerService/SendInteractionMessage',
|
||||
hilserl__pb2.InteractionMessage.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceStub(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def __init__(self, channel):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
channel: A grpc.Channel.
|
||||
"""
|
||||
self.StreamTransition = channel.unary_stream(
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
request_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.ActorInformation.FromString,
|
||||
_registered_method=True)
|
||||
self.SendParameters = channel.unary_unary(
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
request_serializer=hilserl__pb2.Parameters.SerializeToString,
|
||||
response_deserializer=hilserl__pb2.Empty.FromString,
|
||||
_registered_method=True)
|
||||
|
||||
|
||||
class ActorServiceServicer(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
def StreamTransition(self, request, context):
|
||||
"""Learner -> Actor to send new parameters
|
||||
"""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
def SendParameters(self, request, context):
|
||||
"""Missing associated documentation comment in .proto file."""
|
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
||||
context.set_details('Method not implemented!')
|
||||
raise NotImplementedError('Method not implemented!')
|
||||
|
||||
|
||||
def add_ActorServiceServicer_to_server(servicer, server):
|
||||
rpc_method_handlers = {
|
||||
'StreamTransition': grpc.unary_stream_rpc_method_handler(
|
||||
servicer.StreamTransition,
|
||||
request_deserializer=hilserl__pb2.Empty.FromString,
|
||||
response_serializer=hilserl__pb2.ActorInformation.SerializeToString,
|
||||
),
|
||||
'SendParameters': grpc.unary_unary_rpc_method_handler(
|
||||
servicer.SendParameters,
|
||||
request_deserializer=hilserl__pb2.Parameters.FromString,
|
||||
response_serializer=hilserl__pb2.Empty.SerializeToString,
|
||||
),
|
||||
}
|
||||
generic_handler = grpc.method_handlers_generic_handler(
|
||||
'hil_serl.ActorService', rpc_method_handlers)
|
||||
server.add_generic_rpc_handlers((generic_handler,))
|
||||
server.add_registered_method_handlers('hil_serl.ActorService', rpc_method_handlers)
|
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API.
|
||||
class ActorService(object):
|
||||
"""ActorService: the Learner calls this to push parameters.
|
||||
The Actor implements this service.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def StreamTransition(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_stream(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/StreamTransition',
|
||||
hilserl__pb2.Empty.SerializeToString,
|
||||
hilserl__pb2.ActorInformation.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
|
||||
@staticmethod
|
||||
def SendParameters(request,
|
||||
target,
|
||||
options=(),
|
||||
channel_credentials=None,
|
||||
call_credentials=None,
|
||||
insecure=False,
|
||||
compression=None,
|
||||
wait_for_ready=None,
|
||||
timeout=None,
|
||||
metadata=None):
|
||||
return grpc.experimental.unary_unary(
|
||||
request,
|
||||
target,
|
||||
'/hil_serl.ActorService/SendParameters',
|
||||
hilserl__pb2.Parameters.SerializeToString,
|
||||
hilserl__pb2.Empty.FromString,
|
||||
options,
|
||||
channel_credentials,
|
||||
insecure,
|
||||
call_credentials,
|
||||
compression,
|
||||
wait_for_ready,
|
||||
timeout,
|
||||
metadata,
|
||||
_registered_method=True)
|
||||
680
lerobot/scripts/server/learner_server.py
Normal file
@@ -0,0 +1,680 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import shutil
|
||||
import time
|
||||
from pprint import pformat
|
||||
from threading import Lock, Thread
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
# For profiling only
|
||||
from torch.profiler import record_function
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_global_random_state,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
transition_queue = queue.Queue()
|
||||
interaction_message_queue = queue.Queue()
|
||||
|
||||
|
||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
if not cfg.resume:
|
||||
if Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
|
||||
"Use `resume=true` to resume training."
|
||||
)
|
||||
return cfg
|
||||
|
||||
# if resume == True
|
||||
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
|
||||
if not checkpoint_dir.exists():
|
||||
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
|
||||
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"Resume=True detected, resuming previous run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
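# For example, flipping only `resume` produces a diff like
# {"values_changed": {"root['resume']": {"new_value": True, "old_value": False}}},
# which is stripped here so that a bare resume does not trigger the warning below.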
||||
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
|
||||
"Checkpoint configuration takes precedence."
|
||||
)
|
||||
|
||||
checkpoint_cfg.resume = True
|
||||
return checkpoint_cfg
|
||||
|
||||
|
||||
def load_training_state(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
optimizers: Optimizer | dict,
|
||||
):
|
||||
if not cfg.resume:
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizers[k].load_state_dict(v)
|
||||
else:
|
||||
optimizers.load_state_dict(training_state["optimizer"])
|
||||
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], training_state["interaction_step"]
|
||||
|
||||
|
||||
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
|
||||
if not cfg.resume:
|
||||
return ReplayBuffer(
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
storage_device=device
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
|
||||
)
|
||||
return ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=dataset,
|
||||
capacity=cfg.training.online_buffer_capacity,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
)
|
||||
|
||||
|
||||
def start_learner_threads(
|
||||
cfg: DictConfig,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict,
|
||||
policy: SACPolicy,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
) -> None:
|
||||
actor_ip = cfg.actor_learner_config.actor_ip
|
||||
port = cfg.actor_learner_config.port
|
||||
|
||||
server_thread = Thread(
|
||||
target=stream_transitions_from_actor,
|
||||
args=(
|
||||
actor_ip,
|
||||
port,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transition_thread = Thread(
|
||||
target=add_actor_information_and_train,
|
||||
daemon=True,
|
||||
args=(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
),
|
||||
)
|
||||
|
||||
param_push_thread = Thread(
|
||||
target=learner_push_parameters,
|
||||
args=(policy, policy_lock, actor_ip, port, 15),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
server_thread.start()
|
||||
transition_thread.start()
|
||||
param_push_thread.start()
|
||||
param_push_thread.join()
|
||||
transition_thread.join()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
|
||||
"""
|
||||
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
|
||||
|
||||
This function establishes a gRPC connection with the given `host` and `port`, then continuously
|
||||
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
|
||||
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
|
||||
and stored in a separate queue (`interaction_message_queue`).
|
||||
|
||||
Args:
|
||||
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
|
||||
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
|
||||
|
||||
"""
|
||||
# NOTE: This is waiting for the handshake to be done
|
||||
# In the future we will do it in a canonical way with a proper handshake
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
for response in stub.StreamTransition(hilserl_pb2.Empty()):
|
||||
if response.HasField("transition"):
|
||||
buffer = io.BytesIO(response.transition.transition_bytes)
|
||||
transition = torch.load(buffer)
|
||||
transition_queue.put(transition)
|
||||
if response.HasField("interaction_message"):
|
||||
content = pickle.loads(response.interaction_message.interaction_message_bytes)
|
||||
interaction_message_queue.put(content)
|
||||
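# For reference, a minimal sketch of the Actor-side servicer that this client
# consumes (illustrative only; the real actor server lives in a separate script,
# and `outgoing` is an assumed queue of already-serialized transitions):
class _ExampleActorServicer(hilserl_pb2_grpc.ActorServiceServicer):
    def __init__(self, outgoing: queue.Queue):
        self.outgoing = outgoing

    def StreamTransition(self, request, context):
        # Server-streaming RPC: emit one ActorInformation per queued payload.
        while True:
            payload = self.outgoing.get()
            yield hilserl_pb2.ActorInformation(
                transition=hilserl_pb2.Transition(transition_bytes=payload)
            )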
|
||||
|
||||
def learner_push_parameters(
|
||||
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
|
||||
):
|
||||
"""
|
||||
As a client, connect to the Actor's gRPC server (ActorService)
|
||||
and periodically push new parameters.
|
||||
"""
|
||||
time.sleep(10)
|
||||
channel = grpc.insecure_channel(
|
||||
f"{actor_host}:{actor_port}",
|
||||
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
|
||||
)
|
||||
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
|
||||
|
||||
while True:
|
||||
with policy_lock:
|
||||
params_dict = policy.actor.state_dict()
|
||||
# if policy.config.vision_encoder_name is not None:
|
||||
# if policy.config.freeze_vision_encoder:
|
||||
# params_dict: dict[str, torch.Tensor] = {
|
||||
# k: v for k, v in params_dict.items() if not k.startswith("encoder.")
|
||||
# }
|
||||
# else:
|
||||
# raise NotImplementedError(
|
||||
# "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
|
||||
# )
|
||||
|
||||
params_dict = move_state_dict_to_device(params_dict, device="cpu")
|
||||
# Serialize
|
||||
buf = io.BytesIO()
|
||||
torch.save(params_dict, buf)
|
||||
params_bytes = buf.getvalue()
|
||||
|
||||
# Push them to the Actor's "SendParameters" method
|
||||
logging.info("[LEARNER] Publishing parameters to the Actor")
|
||||
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
|
||||
time.sleep(seconds_between_pushes)
|
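# Sketch of the Actor-side inverse of the push above (illustrative only): the
# SendParameters handler rebuilds the state dict from the raw bytes and loads
# it into the local copy of the actor network.
def _example_apply_pushed_parameters(policy: nn.Module, parameter_bytes: bytes) -> None:
    state_dict = torch.load(io.BytesIO(parameter_bytes))
    policy.actor.load_state_dict(state_dict)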
||||
|
||||
|
||||
def check_nan_in_transition(observations: dict[str, torch.Tensor], actions: torch.Tensor, next_state: dict[str, torch.Tensor]):
|
||||
for k in observations:
|
||||
if torch.isnan(observations[k]).any():
|
||||
logging.error(f"observations[{k}] contains NaN values")
|
||||
for k in next_state:
|
||||
if torch.isnan(next_state[k]).any():
|
||||
logging.error(f"next_state[{k}] contains NaN values")
|
||||
if torch.isnan(actions).any():
|
||||
logging.error("actions contains NaN values")
|
||||
|
||||
|
||||
def add_actor_information_and_train(
|
||||
cfg,
|
||||
device: str,
|
||||
replay_buffer: ReplayBuffer,
|
||||
offline_replay_buffer: ReplayBuffer,
|
||||
batch_size: int,
|
||||
optimizers: dict[str, torch.optim.Optimizer],
|
||||
policy: nn.Module,
|
||||
policy_lock: Lock,
|
||||
logger: Logger,
|
||||
resume_optimization_step: int | None = None,
|
||||
resume_interaction_step: int | None = None,
|
||||
):
|
||||
"""
|
||||
Handles data transfer from the actor to the learner, manages training updates,
|
||||
and logs training progress in an online reinforcement learning setup.
|
||||
|
||||
This function continuously:
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
**NOTE:**
|
||||
- This function performs multiple responsibilities (data transfer, training, and logging).
|
||||
It should ideally be split into smaller functions in the future.
|
||||
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
|
||||
significantly reduces performance. Instead, this function executes all operations in a single thread.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
device (str): The computing device (`"cpu"` or `"cuda"`).
|
||||
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
|
||||
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
|
||||
batch_size (int): The number of transitions to sample per training step.
|
||||
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
|
||||
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
|
||||
policy_lock (Lock): A threading lock to ensure safe policy updates.
|
||||
logger (Logger): Logger instance for tracking training progress.
|
||||
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
|
||||
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
|
||||
"""
|
||||
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
# in the future. We kept it monolithic because of Python's GIL: with separate
# threads, performance dropped by roughly 200x, so a single thread does all the work.
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message, transition = None, None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
while True:
|
||||
while not transition_queue.empty():
|
||||
transition_list = transition_queue.get()
|
||||
for transition in transition_list:
|
||||
transition = move_transition_to_device(transition, device=device)
|
||||
replay_buffer.add(**transition)
|
||||
|
||||
if transition.get("complementary_info", {}).get("is_intervention"):
|
||||
offline_replay_buffer.add(**transition)
|
||||
|
||||
while not interaction_message_queue.empty():
|
||||
interaction_message = interaction_message_queue.get()
|
||||
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
|
||||
interaction_message["Interaction step"] += interaction_step_shift
|
||||
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
|
||||
# logging.info(f"Interaction message: {interaction_message}")
|
||||
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
continue
|
||||
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
image_features, next_image_features = None, None
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
with record_function("encoder_forward"):
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
training_infos = {}
|
||||
training_infos["loss_critic"] = loss_critic.item()
|
||||
|
||||
if optimization_step % cfg.training.policy_update_freq == 0:
|
||||
for _ in range(cfg.training.policy_update_freq):
|
||||
with policy_lock:
|
||||
loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)
|
||||
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
|
||||
loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
|
||||
policy.update_target_networks()
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
training_infos["Optimization step"] = optimization_step
|
||||
logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step")
|
||||
# logging.info(f"Training infos: {training_infos}")
|
||||
|
||||
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
|
||||
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
|
||||
|
||||
logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
|
||||
|
||||
logger.log_dict(
|
||||
{
|
||||
"Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
|
||||
"Optimization step": optimization_step,
|
||||
},
|
||||
mode="train",
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % cfg.training.log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
if cfg.training.save_checkpoint and (
|
||||
optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
|
||||
):
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
|
||||
# needed (choose 6 as a minimum for consistency without being overkill).
|
||||
_num_digits = max(6, len(str(cfg.training.online_steps)))
|
||||
step_identifier = f"{optimization_step:0{_num_digits}d}"
|
||||
interaction_step = (
|
||||
interaction_message["Interaction step"] if interaction_message is not None else 0
|
||||
)
|
||||
logger.save_checkpoint(
|
||||
optimization_step,
|
||||
policy,
|
||||
optimizers,
|
||||
scheduler=None,
|
||||
identifier=step_identifier,
|
||||
interaction_step=interaction_step,
|
||||
)
|
||||
|
||||
# TODO: temporarily save the replay buffer here; remove later when on the robot
|
||||
# We want to control this with the keyboard inputs
|
||||
dataset_dir = logger.log_dir / "dataset"
|
||||
if dataset_dir.exists() and dataset_dir.is_dir():
|
||||
shutil.rmtree(
|
||||
dataset_dir,
|
||||
)
|
||||
replay_buffer.to_lerobot_dataset(
|
||||
cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
|
||||
)
|
||||
|
||||
logging.info("Resume training")
|
||||
|
||||
# NOTE: leftover profiling hooks; `profiler` is never defined in this file,
# so they are disabled here to avoid a NameError at runtime.
# profiler.step()
# if optimization_step >= 50:  # profile for 50 steps, then stop
#     profiler.stop()
#     break
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
**NOTE:**
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
|
||||
params=policy.actor.parameters_to_optimize,
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
if job_name is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
init_logging()
|
||||
logging.info(pformat(OmegaConf.to_container(cfg)))
|
||||
|
||||
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
|
||||
cfg = handle_resume_logic(cfg, out_dir)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object over the network, we create a policy instance
|
||||
### on both sides; the learner then sends updated parameters every n steps to refresh the actor's copy
|
||||
# TODO: At some point we should just need make sac policy
|
||||
|
||||
policy_lock = Lock()
|
||||
policy: SACPolicy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
# compile policy
|
||||
# policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
|
||||
|
||||
log_training_info(cfg, out_dir, policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, logger, device)
|
||||
batch_size = cfg.training.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
logging.info("make_dataset offline buffer")
|
||||
offline_dataset = make_dataset(cfg)
|
||||
logging.info("Convertion to a offline replay buffer")
|
||||
active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
offline_dataset,
|
||||
device=device,
|
||||
state_keys=cfg.policy.input_shapes.keys(),
|
||||
action_mask=active_action_dims,
|
||||
action_delta=cfg.env.wrapper.delta_action,
|
||||
)
|
||||
batch_size: int = batch_size // 2  # We sample from both replay buffers
|
||||
|
||||
start_learner_threads(
|
||||
cfg,
|
||||
device,
|
||||
replay_buffer,
|
||||
offline_replay_buffer,
|
||||
batch_size,
|
||||
optimizers,
|
||||
policy,
|
||||
policy_lock,
|
||||
logger,
|
||||
resume_optimization_step,
|
||||
resume_interaction_step,
|
||||
)
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
|
||||
def train_cli(cfg: dict):
|
||||
train(
|
||||
cfg,
|
||||
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
|
||||
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_cli()
|
||||
767
lerobot/scripts/server/learner_server_mp.py
Normal file
@@ -0,0 +1,767 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import io
|
||||
import logging
|
||||
import pickle
|
||||
import queue
|
||||
import shutil
|
||||
import time
|
||||
from pprint import pformat
|
||||
from multiprocessing import Process, Event
|
||||
from torch.multiprocessing import Queue, Lock, set_start_method
|
||||
import logging.handlers
|
||||
from pathlib import Path
|
||||
|
||||
import grpc
|
||||
|
||||
# Import generated stubs
|
||||
import hilserl_pb2 # type: ignore
|
||||
import hilserl_pb2_grpc # type: ignore
|
||||
import hydra
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import nn
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger, log_output_dir
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_global_random_state,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server.buffer import (
|
||||
ReplayBuffer,
|
||||
concatenate_batch_transitions,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
# Initialize these in the main process
|
||||
# transition_queue = Queue(maxsize=1_000_000) # Set a maximum size
|
||||
# interaction_message_queue = Queue(maxsize=1_000_000) # Set a maximum size
|
||||
policy_lock = Lock()
|
||||
replay_buffer_lock = Lock()
|
||||
offline_replay_buffer_lock = Lock()
|
||||
# logging_queue = Queue(maxsize=1_000_000) # Set a maximum size
|
||||
|
||||
def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
|
||||
if not cfg.resume:
|
||||
if Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. "
|
||||
"Use `resume=true` to resume training."
|
||||
)
|
||||
return cfg
|
||||
|
||||
# if resume == True
|
||||
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
|
||||
if not checkpoint_dir.exists():
|
||||
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
|
||||
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"Resume=True detected, resuming previous run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n"
|
||||
"Checkpoint configuration takes precedence."
|
||||
)
|
||||
|
||||
checkpoint_cfg.resume = True
|
||||
return checkpoint_cfg
|
||||
|
||||
|
||||
def load_training_state(
|
||||
cfg: DictConfig,
|
||||
logger: Logger,
|
||||
optimizers: Optimizer | dict,
|
||||
):
|
||||
if not cfg.resume:
|
||||
return None, None
|
||||
|
||||
training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
|
||||
|
||||
if isinstance(training_state["optimizer"], dict):
|
||||
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
|
||||
for k, v in training_state["optimizer"].items():
|
||||
optimizers[k].load_state_dict(v)
|
||||
else:
|
||||
optimizers.load_state_dict(training_state["optimizer"])
|
||||
|
||||
set_global_random_state({k: training_state[k] for k in get_global_random_state()})
|
||||
return training_state["step"], training_state["interaction_step"]
|
||||
|
||||
|
||||
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
|
||||
log_output_dir(out_dir)
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
|
||||
|
||||
def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
    if not cfg.resume:
        return ReplayBuffer(
            capacity=cfg.training.online_buffer_capacity,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
            storage_device=device,
            use_shared_memory=True,
        )

    dataset = LeRobotDataset(
        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
    )
    return ReplayBuffer.from_lerobot_dataset(
        lerobot_dataset=dataset,
        capacity=cfg.training.online_buffer_capacity,
        device=device,
        state_keys=cfg.policy.input_shapes.keys(),
        use_shared_memory=True,
    )

def start_learner_threads(
    cfg: DictConfig,
    device: str,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict,
    policy: SACPolicy,
    log_dir: Path,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    logging_queue: Queue,
    resume_optimization_step: int | None = None,
    resume_interaction_step: int | None = None,
) -> tuple[Process, Process, Process, Process]:
    actor_ip = cfg.actor_learner_config.actor_ip
    port = cfg.actor_learner_config.port

    # Move policy to shared memory
    policy.share_memory()

    server_process = Process(
        target=stream_transitions_from_actor,
        args=(
            transition_queue,
            interaction_message_queue,
            actor_ip,
            port,
        ),
        daemon=True,
    )

    transition_process = Process(
        target=train_offpolicy_rl,
        daemon=True,
        args=(
            cfg,
            replay_buffer,
            offline_replay_buffer,
            batch_size,
            optimizers,
            policy,
            log_dir,
            logging_queue,
            resume_optimization_step,
        ),
    )

    param_push_process = Process(
        target=learner_push_parameters,
        args=(
            policy,
            actor_ip,
            port,
            15,
        ),
        daemon=True,
    )

    fill_replay_buffers_process = Process(
        target=fill_replay_buffers,
        args=(
            replay_buffer,
            offline_replay_buffer,
            transition_queue,
            interaction_message_queue,
            logging_queue,
            resume_interaction_step,
            device,
        ),
    )

    return server_process, transition_process, param_push_process, fill_replay_buffers_process


def stream_transitions_from_actor(
    transition_queue: Queue,
    interaction_message_queue: Queue,
    host: str,
    port: int,
):
    """
    Runs a gRPC client that listens for transition and interaction messages from an Actor service.

    This function establishes a gRPC connection with the given `host` and `port`, then continuously
    streams transition data from the `ActorServiceStub`. The received transition data is deserialized
    and stored in a queue (`transition_queue`). Similarly, interaction messages are deserialized
    and stored in a separate queue (`interaction_message_queue`).

    Args:
        transition_queue (Queue): Queue that receives deserialized transitions.
        interaction_message_queue (Queue): Queue that receives deserialized interaction messages.
        host (str): The IP address or hostname of the gRPC server.
        port (int): The port number on which the gRPC server is running.
    """
    # NOTE: This wait stands in for the handshake with the actor.
    # In the future we will do it in a canonical way with a proper handshake.
    time.sleep(10)
    channel = grpc.insecure_channel(
        f"{host}:{port}",
        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
    )
    stub = hilserl_pb2_grpc.ActorServiceStub(channel)
    while True:
        try:
            for response in stub.StreamTransition(hilserl_pb2.Empty()):
                if response.HasField("transition"):
                    buffer = io.BytesIO(response.transition.transition_bytes)
                    transition = torch.load(buffer)
                    transition_queue.put(transition)
                if response.HasField("interaction_message"):
                    content = pickle.loads(response.interaction_message.interaction_message_bytes)
                    interaction_message_queue.put(content)
        except grpc.RpcError:
            time.sleep(2)  # Retry connection
            continue

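# Actor-side framing this loop expects (a sketch; the wrapper message name is
# hypothetical, but the payloads mirror the HasField checks above, with
# transitions serialized via torch.save and interaction messages via pickle):
#
#     buf = io.BytesIO()
#     torch.save(transition_list, buf)
#     transition_msg = hilserl_pb2.Transition(transition_bytes=buf.getvalue())
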
def learner_push_parameters(
    policy: nn.Module,
    actor_host="127.0.0.1",
    actor_port=50052,
    seconds_between_pushes=5,
):
    """
    As a client, connect to the Actor's gRPC server (ActorService)
    and periodically push new parameters.
    """
    time.sleep(10)
    channel = grpc.insecure_channel(
        f"{actor_host}:{actor_port}",
        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
    )
    actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)

    while True:
        with policy_lock:
            params_dict = policy.actor.state_dict()
        # if policy.config.vision_encoder_name is not None:
        #     if policy.config.freeze_vision_encoder:
        #         params_dict: dict[str, torch.Tensor] = {
        #             k: v for k, v in params_dict.items() if not k.startswith("encoder.")
        #         }
        #     else:
        #         raise NotImplementedError(
        #             "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
        #         )

        params_dict = move_state_dict_to_device(params_dict, device="cpu")
        # Serialize
        buf = io.BytesIO()
        torch.save(params_dict, buf)
        params_bytes = buf.getvalue()

        # Push them to the Actor's "SendParameters" method
        logging.info("[LEARNER] Publishing parameters to the Actor")
        response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
        time.sleep(seconds_between_pushes)


def fill_replay_buffers(
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    transition_queue: Queue,
    interaction_message_queue: Queue,
    logger_queue: Queue,
    resume_interaction_step: int | None,
    device: str,
):
    while True:
        while not transition_queue.empty():
            transition_list = transition_queue.get()
            for transition in transition_list:
                transition = move_transition_to_device(transition, device=device)
                with replay_buffer_lock:
                    replay_buffer.add(**transition)

                if transition.get("complementary_info", {}).get("is_intervention"):
                    with offline_replay_buffer_lock:
                        offline_replay_buffer.add(**transition)

        while not interaction_message_queue.empty():
            interaction_message = interaction_message_queue.get()
            # If cfg.resume, shift the interaction step by the last checkpointed step
            # so that logging is not broken.
            if resume_interaction_step is not None:
                interaction_message["Interaction step"] += resume_interaction_step
            logger_queue.put({
                "info": interaction_message,
                "step_key": "Interaction step",
            })

def check_nan_in_transition(
    observations: dict[str, torch.Tensor], actions: torch.Tensor, next_state: dict[str, torch.Tensor]
):
    for k in observations:
        if torch.isnan(observations[k]).any():
            logging.error(f"observations[{k}] contains NaN values")
    for k in next_state:
        if torch.isnan(next_state[k]).any():
            logging.error(f"next_state[{k}] contains NaN values")
    if torch.isnan(actions).any():
        logging.error("actions contains NaN values")


def train_offpolicy_rl(
    cfg,
    replay_buffer: ReplayBuffer,
    offline_replay_buffer: ReplayBuffer,
    batch_size: int,
    optimizers: dict[str, torch.optim.Optimizer],
    policy: nn.Module,
    log_dir: Path,
    logging_queue: Queue,
    resume_optimization_step: int | None = None,
):
    """
    Handles data transfer from the actor to the learner, manages training updates,
    and logs training progress in an online reinforcement learning setup.

    This function continuously:
    - Transfers transitions from the actor to the replay buffer.
    - Logs received interaction messages.
    - Ensures training begins only when the replay buffer has a sufficient number of transitions.
    - Samples batches from the replay buffer and performs multiple critic updates.
    - Periodically updates the actor, critic, and temperature optimizers.
    - Logs training statistics, including loss values and optimization frequency.

    **NOTE:**
    - This function performs multiple responsibilities (data transfer, training, and logging).
      It should ideally be split into smaller functions in the future.
    - Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
      significantly reduces performance. Instead, this function executes all operations in a single thread.

    Args:
        cfg: Configuration object containing hyperparameters.
        replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
        offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
        batch_size (int): The number of transitions to sample per training step.
        optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
        policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
        log_dir (Path): The directory to save the log files.
        logging_queue (Queue): Queue used to forward training statistics and checkpoint requests to the main process.
        resume_optimization_step (int | None): In the case of resumed training, start from the last optimization step reached.
    """
    # NOTE: This function doesn't have a single responsibility; it should be split into multiple
    # functions in the future. The reason we did not is Python's GIL: with separate threads the
    # performance dropped by roughly two orders of magnitude, so a single thread does all the work.
logging.info("Starting learner thread")
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
|
||||
# Wait for stream process to be ready
|
||||
while True:
|
||||
|
||||
with replay_buffer_lock:
|
||||
logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
if len(replay_buffer) < cfg.training.online_step_before_learning:
|
||||
time.sleep(1)
|
||||
continue
|
||||
|
||||
# logging.info(f"Size of replay buffer: {len(replay_buffer)}")
|
||||
# logging.info(f"Size of offline replay buffer: {len(offline_replay_buffer)}")
|
||||
|
||||
image_features, next_image_features = None, None
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(cfg.policy.utd_ratio - 1):
|
||||
with replay_buffer_lock:
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
with offline_replay_buffer_lock:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(batch, batch_offline)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
with replay_buffer_lock:
|
||||
batch = replay_buffer.sample(batch_size)
|
||||
|
||||
if cfg.dataset_repo_id is not None:
|
||||
with offline_replay_buffer_lock:
|
||||
batch_offline = offline_replay_buffer.sample(batch_size)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch["action"]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
# Precompute encoder features from the frozen vision encoder if enabled
|
||||
if policy.config.vision_encoder_name is not None and policy.config.freeze_vision_encoder:
|
||||
with torch.no_grad():
|
||||
image_features = (
|
||||
policy.actor.encoder(observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
next_image_features = (
|
||||
policy.actor.encoder(next_observations)
|
||||
if policy.actor.encoder is not None
|
||||
else None
|
||||
)
|
||||
with policy_lock:
|
||||
loss_critic = policy.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
image_features=image_features,
|
||||
next_image_features=next_image_features,
|
||||
)
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
        training_infos = {}
        training_infos["loss_critic"] = loss_critic.item()

        if optimization_step % cfg.training.policy_update_freq == 0:
            for _ in range(cfg.training.policy_update_freq):
                with policy_lock:
                    loss_actor = policy.compute_loss_actor(observations=observations, image_features=image_features)

                optimizers["actor"].zero_grad()
                loss_actor.backward()
                optimizers["actor"].step()

                training_infos["loss_actor"] = loss_actor.item()

                loss_temperature = policy.compute_loss_temperature(observations=observations, image_features=image_features)
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
                optimizers["temperature"].step()

                training_infos["loss_temperature"] = loss_temperature.item()

        policy.update_target_networks()
        if optimization_step % cfg.training.log_freq == 0:
            training_infos["Optimization step"] = optimization_step
            logging_queue.put({
                "info": training_infos,
                "step_key": "Optimization step",
            })

        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
        frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)

        logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")

        optimization_step += 1
        if optimization_step % cfg.training.log_freq == 0:
            logging.info(f"[LEARNER] Number of optimization steps: {optimization_step}")

        if cfg.training.save_checkpoint and (
            optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
        ):
            logging.info(f"Checkpoint policy after step {optimization_step}")
            # Note: Save with the step as the identifier, and format it to have at least 6 digits but more if
            # needed (choose 6 as a minimum for consistency without being overkill).
            _num_digits = max(6, len(str(cfg.training.online_steps)))
            step_identifier = f"{optimization_step:0{_num_digits}d}"
            logging_queue.put({
                "checkpoint": {
                    "step": optimization_step,
                    "identifier": step_identifier,
                }
            })

            # TODO: temporarily save the replay buffer here; remove later when on the robot.
            # We want to control this with the keyboard inputs.
            dataset_dir = log_dir / "dataset"
            if dataset_dir.exists() and dataset_dir.is_dir():
                shutil.rmtree(dataset_dir)
            with replay_buffer_lock:
                replay_buffer.to_lerobot_dataset(
                    cfg.dataset_repo_id, fps=cfg.fps, root=dataset_dir
                )

            logging.info("Resuming training")

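# Note on the structure above: per outer iteration the critic is updated
# `utd_ratio` times in total (the `utd_ratio - 1` loop plus the final, logged
# update), while the actor and temperature are only updated every
# `policy_update_freq` optimization steps. This is the update-to-data (UTD)
# scheme common in sample-efficient off-policy RL.
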
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
    """
    Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.

    This function sets up Adam optimizers for:
    - The **actor network**, ensuring that only relevant parameters are optimized.
    - The **critic ensemble**, which evaluates the value function.
    - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.

    It also initializes a learning rate scheduler, though it is currently set to `None`.

    **NOTE:**
    - If the encoder is shared, its parameters are excluded from the actor's optimization process.
    - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.

    Args:
        cfg: Configuration object containing hyperparameters.
        policy (nn.Module): The policy model containing the actor, critic, and temperature components.

    Returns:
        Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
            A tuple containing:
            - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
            - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
    """
    optimizer_actor = torch.optim.Adam(
        # NOTE: Handle the case of a shared encoder, where the encoder weights are not optimized with the gradient of the actor
        params=policy.actor.parameters_to_optimize,
        lr=policy.config.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(
        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
    )
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
    lr_scheduler = None
    optimizers = {
        "actor": optimizer_actor,
        "critic": optimizer_critic,
        "temperature": optimizer_temperature,
    }
    return optimizers, lr_scheduler

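# Why `[policy.log_alpha]` above: torch optimizers take an iterable of tensors
# or parameter groups, and `log_alpha` is a bare tensor rather than a module.
# Minimal standalone sketch:
#
#     log_alpha = torch.zeros(1, requires_grad=True)
#     opt = torch.optim.Adam([log_alpha], lr=3e-4)
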
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    # Initialize multiprocessing with the spawn method for better compatibility
    set_start_method("spawn", force=True)

    if out_dir is None:
        raise NotImplementedError()
    if job_name is None:
        raise NotImplementedError()

    init_logging()
    logging.info(pformat(OmegaConf.to_container(cfg)))

    # Create our logger instance in the main process
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
    cfg = handle_resume_logic(cfg, out_dir)

    set_global_seed(cfg.seed)

    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")

    ### Instantiate the policy in both the actor and learner processes.
    ### To avoid sending a SACPolicy object through the port, we create a policy instance
    ### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters.
    # TODO: At some point we should just need to make the SAC policy

    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
    )
    # compile policy
    # policy = torch.compile(policy)
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)

    log_training_info(cfg, out_dir, policy)

    replay_buffer = initialize_replay_buffer(cfg, logger, device)
    batch_size = cfg.training.batch_size
    offline_replay_buffer = None

    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Conversion to an offline replay buffer")
        active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset,
            device=device,
            state_keys=cfg.policy.input_shapes.keys(),
            action_mask=active_action_dims,
            action_delta=cfg.env.wrapper.delta_action,
            use_shared_memory=True,
        )
        # Half the batch comes from the online buffer, half from the offline buffer.
        batch_size = batch_size // 2

    transition_queue = Queue(maxsize=1_000_000)  # Set a maximum size
    interaction_message_queue = Queue(maxsize=1_000_000)  # Set a maximum size
    logging_queue = Queue(maxsize=1_000_000)  # Set a maximum size

    processes = start_learner_threads(
        cfg,
        device,
        replay_buffer,
        offline_replay_buffer,
        batch_size,
        optimizers,
        policy,
        logger.log_dir,
        transition_queue,
        interaction_message_queue,
        logging_queue,
        resume_optimization_step,
        resume_interaction_step,
    )

    for p in processes:
        p.start()

    # Consume log messages from the logging_queue in the main process
    latest_interaction_step = resume_interaction_step if resume_interaction_step is not None else 0
    while True:
        try:
            message = logging_queue.get(timeout=1)
            if "checkpoint" in message:
                ckpt = message["checkpoint"]
                logger.save_checkpoint(
                    ckpt["step"],
                    policy,
                    optimizers,
                    scheduler=None,
                    identifier=ckpt["identifier"],
                    interaction_step=latest_interaction_step,
                )
            else:
                if "Interaction step" in message["info"]:
                    latest_interaction_step = message["info"]["Interaction step"]
                logger.log_dict(
                    message["info"],
                    mode="train",
                    custom_step_key=message["step_key"],
                )
        except queue.Empty:
            continue
        except KeyboardInterrupt:
            # Clean up any remaining processes
            for p in processes:
                if p.is_alive():
                    p.terminate()
                p.join()
            break

@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
def train_cli(cfg: dict):
    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )


if __name__ == "__main__":
    train_cli()

174  lerobot/scripts/server/maniskill_manipulator.py  Normal file
@@ -0,0 +1,174 @@
import einops
import numpy as np
import gymnasium as gym
import torch

from omegaconf import DictConfig
from typing import Any

"""Make ManiSkill3 gym environment"""
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv


def preprocess_maniskill_observation(observations: dict[str, np.ndarray]) -> dict[str, torch.Tensor]:
    """Convert an environment observation to a LeRobot format observation.

    Args:
        observations: Dictionary of observation batches from a Gym vector environment.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
    """
    # Map to the expected inputs for the policy
    return_observations = {}
    # TODO: You have to merge all tensors from the agent key and the extra key.
    # You don't keep the sensor param key in the observation,
    # and you keep the sensor data rgb.
    q_pos = observations["agent"]["qpos"]
    q_vel = observations["agent"]["qvel"]
    tcp_pos = observations["extra"]["tcp_pose"]
    img = observations["sensor_data"]["base_camera"]["rgb"]

    _, h, w, c = img.shape
    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"

    # Sanity check that images are uint8
    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

    # Convert to channel-first float32 in range [0, 1]
    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
    img = img.type(torch.float32)
    img /= 255

    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)

    return_observations["observation.image"] = img
    return_observations["observation.state"] = state
    return return_observations

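# Minimal input/output sketch for the function above (shapes hypothetical;
# ManiSkill batches observations over the env dimension):
#
#     obs = {
#         "agent": {"qpos": torch.zeros(1, 9), "qvel": torch.zeros(1, 9)},
#         "extra": {"tcp_pose": torch.zeros(1, 7)},
#         "sensor_data": {"base_camera": {"rgb": torch.zeros(1, 128, 128, 3, dtype=torch.uint8)}},
#     }
#     out = preprocess_maniskill_observation(obs)
#     # out["observation.image"]: (1, 3, 128, 128), float32 in [0, 1]
#     # out["observation.state"]: (1, 25)  -> qpos (9) + qvel (9) + tcp_pose (7)
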
class ManiSkillObservationWrapper(gym.ObservationWrapper):
    def __init__(self, env, device: torch.device = "cuda"):
        super().__init__(env)
        self.device = device

    def observation(self, observation):
        observation = preprocess_maniskill_observation(observation)
        observation = {k: v.to(self.device) for k, v in observation.items()}
        return observation


class ManiSkillCompat(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        new_action_space_shape = env.action_space.shape[-1]
        new_low = np.squeeze(env.action_space.low, axis=0)
        new_high = np.squeeze(env.action_space.high, axis=0)
        self.action_space = gym.spaces.Box(low=new_low, high=new_high, shape=(new_action_space_shape,))

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[Any, dict[str, Any]]:
        # Drop any incoming options; the underlying env is always reset with defaults.
        options = {}
        return super().reset(seed=seed, options=options)

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Unbatch the scalar outputs of the single-env vectorized env
        reward = reward.item()
        terminated = terminated.item()
        truncated = truncated.item()
        return obs, reward, terminated, truncated, info


class ManiSkillActionWrapper(gym.ActionWrapper):
    def __init__(self, env):
        super().__init__(env)
        self.action_space = gym.spaces.Tuple(spaces=(env.action_space, gym.spaces.Discrete(2)))

    def action(self, action):
        # The second element of the tuple is the teleoperation flag; only the
        # continuous action is forwarded to the environment.
        action, telop = action
        return action


class ManiSkillMultiplyActionWrapper(gym.Wrapper):
    def __init__(self, env, multiply_factor: float = 1):
        super().__init__(env)
        self.multiply_factor = multiply_factor
        action_space_agent: gym.spaces.Box = env.action_space[0]
        action_space_agent.low = action_space_agent.low * multiply_factor
        action_space_agent.high = action_space_agent.high * multiply_factor
        self.action_space = gym.spaces.Tuple(spaces=(action_space_agent, gym.spaces.Discrete(2)))

    def step(self, action):
        if isinstance(action, tuple):
            action, telop = action
        else:
            telop = 0
        action = action / self.multiply_factor
        obs, reward, terminated, truncated, info = self.env.step((action, telop))
        return obs, reward, terminated, truncated, info

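# Note on the scaling above: the agent sees an action space whose bounds are
# scaled up by `multiply_factor`, and `step` divides the action back down
# before it reaches the underlying environment, so the executed command keeps
# its original magnitude while the policy operates on a wider numeric range.
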
def make_maniskill(
    cfg: DictConfig,
    n_envs: int | None = None,
) -> gym.Env:
    """
    Factory function to create a ManiSkill environment with standard wrappers.

    Args:
        cfg: Hydra config with the environment parameters (task name, observation mode,
            control mode, render mode, image size, device).
        n_envs: Number of parallel environments.

    Returns:
        A wrapped ManiSkill environment.
    """
    env = gym.make(
        cfg.env.task,
        obs_mode=cfg.env.obs,
        control_mode=cfg.env.control_mode,
        render_mode=cfg.env.render_mode,
        sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
        num_envs=n_envs,
    )

    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
    env._max_episode_steps = env.max_episode_steps = 50  # gym_utils.find_max_episode_steps_value(env)
    env.unwrapped.metadata["render_fps"] = 20
    env = ManiSkillCompat(env)
    env = ManiSkillActionWrapper(env)
    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=10.0)

    return env


if __name__ == "__main__":
    import argparse

    import hydra
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="lerobot/configs/env/maniskill_example.yaml")
    args = parser.parse_args()

    # Initialize config
    # NOTE: args.config is currently unused; the config name is hardcoded below.
    with hydra.initialize(version_base=None, config_path="../../configs"):
        cfg = hydra.compose(config_name="env/maniskill_example.yaml")

    env = make_maniskill(cfg)

    print("env done")
    obs, info = env.reset()
    random_action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(random_action)

@@ -93,6 +93,17 @@ def make_optimizer_and_scheduler(cfg, policy):
    elif policy.name == "tdmpc":
        optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
        lr_scheduler = None

    elif policy.name == "sac":
        optimizer = torch.optim.Adam(
            [
                {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
            ]
        )
        lr_scheduler = None

    elif cfg.policy.name == "vqbet":
        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler

@@ -311,6 +322,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No

    logging.info("make_dataset")
    offline_dataset = make_dataset(cfg)
    # TODO (michel-aractingi): temporary fix to avoid datasets with a task_index key that doesn't exist in the online environment,
    # i.e., pusht
    if "task_index" in offline_dataset.hf_dataset[0]:
        offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"])

    if isinstance(offline_dataset, MultiLeRobotDataset):
        logging.info(
            "Multiple datasets were provided. Applied the following index mapping to the provided datasets: "

408  lerobot/scripts/train_hilserl_classifier.py  Normal file
@@ -0,0 +1,408 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat

import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm

from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    set_global_seed,
)
from lerobot.scripts.server.buffer import random_shift


def get_model(cfg, logger):  # noqa I001
    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
    model = Classifier(classifier_config)
    if cfg.resume:
        model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
    return model


def create_balanced_sampler(dataset, cfg):
    # Creates a weighted sampler to handle class imbalance

    labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
    _, counts = torch.unique(labels, return_counts=True)
    class_weights = 1.0 / counts.float()
    sample_weights = class_weights[labels]

    return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

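# Worked example of the weighting above: with labels [0, 0, 0, 1],
# counts = [3, 1], class_weights = [1/3, 1], and
# sample_weights = [1/3, 1/3, 1/3, 1], so the minority-class sample is drawn
# three times more often and batches are class-balanced in expectation.
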
def support_amp(device: torch.device, cfg: DictConfig) -> bool:
    # Check if the device supports AMP.
    # MPS is excluded because it does not support AMP properly.
    return cfg.training.use_amp and device.type in ("cuda", "cpu")


def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
    # Single-epoch training loop with AMP support and progress tracking
    model.train()
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc="Training")
    for batch_idx, batch in enumerate(pbar):
        start_time = time.perf_counter()
        images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
        images = [random_shift(img, 4) for img in images]
        labels = batch[cfg.training.label_key].float().to(device)

        # Forward pass with optional AMP
        with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext():
            outputs = model(images)
            loss = criterion(outputs.logits, labels)

        # Backward pass with gradient scaling if AMP is enabled
        optimizer.zero_grad()
        if cfg.training.use_amp:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Track metrics
        if model.config.num_classes == 2:
            predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
        else:
            predictions = torch.argmax(outputs.logits, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

        current_acc = 100 * correct / total
        train_info = {
            "loss": loss.item(),
            "accuracy": current_acc,
            "dataloading_s": time.perf_counter() - start_time,
        }

        logger.log_dict(train_info, step + batch_idx, mode="train")
        pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})

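# The AMP pattern used above, in isolation (a minimal sketch; `use_amp`,
# `model`, `x`, and `y` stand in for the config flag and batch objects):
#
#     scaler = GradScaler(enabled=use_amp)
#     with torch.autocast(device_type=device.type) if use_amp else nullcontext():
#         loss = criterion(model(x), y)
#     optimizer.zero_grad()
#     scaler.scale(loss).backward()   # scale the loss to avoid fp16 underflow
#     scaler.step(optimizer)          # unscale gradients, then step
#     scaler.update()                 # adjust the scale factor for the next step
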
def validate(model, val_loader, criterion, device, logger, cfg):
    # Validation loop with metric tracking and sample logging
    model.eval()
    correct = 0
    total = 0
    batch_start_time = time.perf_counter()
    samples = []
    running_loss = 0
    inference_times = []

    with (
        torch.no_grad(),
        torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(),
    ):
        for batch in tqdm(val_loader, desc="Validation"):
            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
            labels = batch[cfg.training.label_key].float().to(device)

            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
                with (
                    profiler.profile(record_shapes=True) as prof,
                    profiler.record_function("model_inference"),
                ):
                    outputs = model(images)
                inference_times.append(
                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
                )
            else:
                outputs = model(images)

            loss = criterion(outputs.logits, labels)

            # Track metrics
            if model.config.num_classes == 2:
                predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
            else:
                predictions = torch.argmax(outputs.logits, dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
            running_loss += loss.item()

            # Log sample predictions for visualization
            if len(samples) < cfg.eval.num_samples_to_log:
                for i in range(min(cfg.eval.num_samples_to_log - len(samples), len(images))):
                    if model.config.num_classes == 2:
                        confidence = round(outputs.probabilities[i].item(), 3)
                    else:
                        confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
                    samples.append(
                        {
                            **{
                                f"image_{img_key}": wandb.Image(images[img_idx][i].cpu())
                                for img_idx, img_key in enumerate(cfg.training.image_keys)
                            },
                            "true_label": labels[i].item(),
                            "predicted": predictions[i].item(),
                            "confidence": confidence,
                        }
                    )

    accuracy = 100 * correct / total
    avg_loss = running_loss / len(val_loader)
    print(f"Average validation loss {avg_loss}, and accuracy {accuracy}")

    eval_info = {
        "loss": avg_loss,
        "accuracy": accuracy,
        "eval_s": time.perf_counter() - batch_start_time,
        "eval/prediction_samples": wandb.Table(
            data=[list(s.values()) for s in samples],
            columns=list(samples[0].keys()),
        )
        if logger._cfg.wandb.enable
        else None,
    }

    if len(inference_times) > 0:
        eval_info["inference_time_avg"] = np.mean(inference_times)
        eval_info["inference_time_median"] = np.median(inference_times)
        eval_info["inference_time_std"] = np.std(inference_times)
        eval_info["inference_time_batch_size"] = val_loader.batch_size

        print(
            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
        )

    return accuracy, eval_info


def benchmark_inference_time(model, dataset, logger, cfg, device, step):
    if not cfg.training.profile_inference_time:
        return

    iters = cfg.training.profile_inference_time_iters
    inference_times = []

    loader = DataLoader(
        dataset,
        batch_size=1,
        num_workers=cfg.training.num_workers,
        sampler=RandomSampler(dataset),
        pin_memory=True,
    )

    model.eval()
    with torch.no_grad():
        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
            x = next(iter(loader))
            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

            # Warm up
            for _ in range(10):
                _ = model(x)

            # Sync the device
            if device.type == "cuda":
                torch.cuda.synchronize()
            elif device.type == "mps":
                torch.mps.synchronize()

            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
                _ = model(x)

            inference_times.append(
                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
            )

    inference_times = np.array(inference_times)
    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
    print(
        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
    )
    if logger._cfg.wandb.enable:
        logger.log_dict(
            {
                "inference_time_benchmark_avg": avg,
                "inference_time_benchmark_median": median,
                "inference_time_benchmark_std": std,
            },
            step + 1,
            mode="eval",
        )

    return avg, median, std


@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
    # Main training pipeline with support for resuming training
    logging.info(OmegaConf.to_yaml(cfg))

    # Initialize training environment
    device = get_safe_torch_device(cfg.device, log=True)
    set_global_seed(cfg.seed)

    out_dir = hydra.core.hydra_config.HydraConfig.get().run.dir + "frozen_resnet10_2"
    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)

    # Setup dataset and dataloaders
    dataset = LeRobotDataset(
        cfg.dataset_repo_id, root=cfg.dataset_root, local_files_only=cfg.local_files_only
    )
    logging.info(f"Dataset size: {len(dataset)}")

    n_total = len(dataset)
    n_train = int(cfg.train_split_proportion * len(dataset))
    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))

    sampler = create_balanced_sampler(train_dataset, cfg)
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.training.batch_size,
        num_workers=cfg.training.num_workers,
        sampler=sampler,
        pin_memory=device.type == "cuda",
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=cfg.eval.batch_size,
        shuffle=False,
        num_workers=cfg.training.num_workers,
        pin_memory=device.type == "cuda",
    )

    # Resume training if requested
    step = 0
    best_val_acc = 0

    if cfg.resume:
        if not Logger.get_last_checkpoint_dir(out_dir).exists():
            raise RuntimeError(
                "You have set resume=True, but there is no model checkpoint in "
                f"{Logger.get_last_checkpoint_dir(out_dir)}"
            )
        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
        logging.info(
            colored(
                "You have set resume=True, indicating that you wish to resume a run",
                color="yellow",
                attrs=["bold"],
            )
        )
        # Load and validate checkpoint configuration
        checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
        # Check for differences between the checkpoint configuration and the provided configuration.
        # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
        resolve_delta_timestamps(cfg)
        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
        # Ignore the `resume` parameter.
        if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
            del diff["values_changed"]["root['resume']"]
        if len(diff) > 0:
            logging.warning(
                "At least one difference was detected between the checkpoint configuration and "
                f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
                "takes precedence.",
            )
        # Use the checkpoint config instead of the provided config (but keep the `resume` parameter).
        cfg = checkpoint_cfg
        cfg.resume = True

    # Initialize model and training components
    model = get_model(cfg=cfg, logger=logger).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
    # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
    grad_scaler = GradScaler(enabled=cfg.training.use_amp)

    # Log model parameters
    num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in model.parameters())
    logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
    logging.info(f"Total parameters: {format_big_number(num_total_params)}")

    if cfg.resume:
        step = logger.load_last_training_state(optimizer, None)

    # Training loop with validation and checkpointing
    for epoch in range(cfg.training.num_epochs):
        logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")

        train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)

        # Periodic validation
        if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
            val_acc, eval_info = validate(
                model,
                val_loader,
                criterion,
                device,
                logger,
                cfg,
            )
            logger.log_dict(eval_info, step + len(train_loader), mode="eval")

            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                logger.save_checkpoint(
                    train_step=step + len(train_loader),
                    policy=model,
                    optimizer=optimizer,
                    scheduler=None,
                    identifier="best",
                )

        # Periodic checkpointing
        if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
            logger.save_checkpoint(
                train_step=step + len(train_loader),
                policy=model,
                optimizer=optimizer,
                scheduler=None,
                identifier=f"{epoch+1:06d}",
            )

        step += len(train_loader)

    benchmark_inference_time(model, dataset, logger, cfg, device, step)

    logging.info("Training completed")


if __name__ == "__main__":
    train()

586  lerobot/scripts/train_sac.py  Normal file
@@ -0,0 +1,586 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import random
from pprint import pformat
from typing import Callable, Optional, Sequence, TypedDict

import hydra
import torch
import torch.nn.functional as F
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from torch import nn
from tqdm import tqdm

# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.envs.factory import make_env, make_maniskill_env
from lerobot.common.envs.utils import preprocess_maniskill_observation, preprocess_observation
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
    format_big_number,
    get_safe_torch_device,
    init_hydra_config,
    init_logging,
    set_global_seed,
)
from lerobot.scripts.eval import eval_policy


def make_optimizers_and_scheduler(cfg, policy):
    optimizer_actor = torch.optim.Adam(
        # NOTE: Handle the case of a shared encoder, where the encoder weights are not optimized with the gradient of the actor
        params=policy.actor.parameters_to_optimize,
        lr=policy.config.actor_lr,
    )
    optimizer_critic = torch.optim.Adam(
        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
    )
    # We wrap the policy log temperature in a list because it is a torch tensor and not an nn.Module
    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
    lr_scheduler = None
    optimizers = {
        "actor": optimizer_actor,
        "critic": optimizer_critic,
        "temperature": optimizer_temperature,
    }
    return optimizers, lr_scheduler


class Transition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: float
    next_state: dict[str, torch.Tensor]
    done: bool
    # TypedDict fields cannot take defaults; this key may be None at runtime.
    complementary_info: Optional[dict[str, torch.Tensor]]


class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    done: torch.Tensor


def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
    """
    Perform a per-image random crop over a batch of images in a vectorized way.
    """
    B, C, H, W = images.shape
    crop_h, crop_w = output_size

    if crop_h > H or crop_w > W:
        raise ValueError(
            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
        )

    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)

    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)

    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)

    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)

    # Gather pixels
    cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :]
    # cropped_hwcn => (B, crop_h, crop_w, C)

    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
    return cropped


def random_shift(images: torch.Tensor, pad: int = 4):
    """Vectorized random shift, images: (B, C, H, W), pad: #pixels"""
    _, _, h, w = images.shape
    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
    return random_crop_vectorized(images=images, output_size=(h, w))

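# Together these implement DrQ-style augmentation: replicate-pad by `pad`
# pixels, then take a per-image random (h, w) crop, which is equivalent to a
# random shift of up to ±pad pixels. Minimal sketch:
#
#     imgs = torch.rand(8, 3, 84, 84)
#     out = random_shift(imgs, pad=4)
#     assert out.shape == imgs.shape
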
class ReplayBuffer:
    def __init__(
        self,
        capacity: int,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
        image_augmentation_function: Optional[Callable] = None,
        use_drq: bool = True,
    ):
        """
        Args:
            capacity (int): Maximum number of transitions to store in the buffer.
            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
                and returns a batch of augmented images. If None, a default augmentation function is used.
            use_drq (bool): Whether to apply the DrQ-style image augmentation when sampling from the buffer.
        """
        self.capacity = capacity
        self.device = device
        self.memory: list[Transition] = []
        self.position = 0

        # If no state_keys are provided, default to an empty list
        # (you can handle this differently if needed)
        self.state_keys = state_keys if state_keys is not None else []
        if image_augmentation_function is None:
            image_augmentation_function = functools.partial(random_shift, pad=4)
        self.image_augmentation_function = image_augmentation_function
        self.use_drq = use_drq

    def add(
        self,
        state: dict[str, torch.Tensor],
        action: torch.Tensor,
        reward: float,
        next_state: dict[str, torch.Tensor],
        done: bool,
        complementary_info: Optional[dict[str, torch.Tensor]] = None,
    ):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        # Create and store the Transition
        self.memory[self.position] = Transition(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            complementary_info=complementary_info,
        )
        self.position: int = (self.position + 1) % self.capacity

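    # Ring-buffer behavior of `add` (illustrative): with capacity=2,
    #     add(t0) -> memory=[t0],     position=1
    #     add(t1) -> memory=[t0, t1], position=0
    #     add(t2) -> memory=[t2, t1], position=1   (t0 overwritten)
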
    # TODO: Add image_augmentation and use_drq arguments in this function in order to instantiate the class with them
    @classmethod
    def from_lerobot_dataset(
        cls,
        lerobot_dataset: LeRobotDataset,
        device: str = "cuda:0",
        state_keys: Optional[Sequence[str]] = None,
    ) -> "ReplayBuffer":
        """
        Convert a LeRobotDataset into a ReplayBuffer.

        Args:
            lerobot_dataset (LeRobotDataset): The dataset to convert.
            device (str): The device to move tensors to. Defaults to "cuda:0".
            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
                Defaults to None.

        Returns:
            ReplayBuffer: The replay buffer with offline dataset transitions.
        """
        # We convert the LeRobotDataset into a replay buffer because it is more efficient
        # to sample from a replay buffer than from a LeRobotDataset.
        replay_buffer = cls(capacity=len(lerobot_dataset), device=device, state_keys=state_keys)
        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
        # Fill the replay buffer with the lerobot dataset transitions
        for data in list_transition:
            replay_buffer.add(
                state=data["state"],
                action=data["action"],
                reward=data["reward"],
                next_state=data["next_state"],
                done=data["done"],
            )
        return replay_buffer

    @staticmethod
    def _lerobotdataset_to_transitions(
        dataset: LeRobotDataset,
        state_keys: Optional[Sequence[str]] = None,
    ) -> list[Transition]:
        """
        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.

        Args:
            dataset (LeRobotDataset):
                The dataset to convert. Each item in the dataset is expected to have
                at least the following keys:
                {
                    "action": ...
                    "next.reward": ...
                    "next.done": ...
                    "episode_index": ...
                }
                plus whatever your 'state_keys' specify.

            state_keys (Optional[Sequence[str]]):
                The dataset keys to include in 'state' and 'next_state'. Their names
                will be kept as-is in the output transitions. E.g.
                ["observation.state", "observation.environment_state"].
                If None, you must handle or define default keys.

        Returns:
            transitions (List[Transition]):
                A list of Transition dictionaries with the same length as `dataset`.
        """
        # If not provided, you can either raise an error or define a default:
        if state_keys is None:
            raise ValueError("You must provide a list of keys in `state_keys` that define your 'state'.")

        transitions: list[Transition] = []
        num_frames = len(dataset)

        for i in tqdm(range(num_frames)):
            current_sample = dataset[i]

            # ----- 1) Current state -----
            current_state: dict[str, torch.Tensor] = {}
            for key in state_keys:
                val = current_sample[key]
                current_state[key] = val.unsqueeze(0)  # Add batch dimension

            # ----- 2) Action -----
            action = current_sample["action"].unsqueeze(0)  # Add batch dimension

            # ----- 3) Reward and done -----
            reward = float(current_sample["next.reward"].item())  # ensure float
            done = bool(current_sample["next.done"].item())  # ensure bool

            # ----- 4) Next state -----
            # If not done and the next sample is in the same episode, we pull the next sample's state.
            # Otherwise (done=True or the next sample crosses to a new episode), next_state = current_state.
            next_state = current_state  # default
            if not done and (i < num_frames - 1):
                next_sample = dataset[i + 1]
                if next_sample["episode_index"] == current_sample["episode_index"]:
                    # Build next_state from the same keys
                    next_state_data: dict[str, torch.Tensor] = {}
                    for key in state_keys:
                        val = next_sample[key]
                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
                    next_state = next_state_data

            # ----- Construct the Transition -----
            transition = Transition(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=done,
            )
            transitions.append(transition)

        return transitions

    def sample(self, batch_size: int) -> BatchTransition:
        """Sample a random batch of transitions and collate them into batched tensors."""
        list_of_transitions = random.sample(self.memory, batch_size)

        # -- Build batched states --
        batch_state = {}
        for key in self.state_keys:
            batch_state[key] = torch.cat([t["state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_state[key] = self.image_augmentation_function(batch_state[key])

        # -- Build batched actions --
        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(self.device)

        # -- Build batched rewards --
        batch_rewards = torch.tensor([t["reward"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # -- Build batched next states --
        batch_next_state = {}
        for key in self.state_keys:
            batch_next_state[key] = torch.cat([t["next_state"][key] for t in list_of_transitions], dim=0).to(
                self.device
            )
            if key.startswith("observation.image") and self.use_drq:
                batch_next_state[key] = self.image_augmentation_function(batch_next_state[key])

        # -- Build batched dones --
        batch_dones = torch.tensor([t["done"] for t in list_of_transitions], dtype=torch.float32).to(
            self.device
        )

        # Return a BatchTransition typed dict
        return BatchTransition(
            state=batch_state,
            action=batch_actions,
            reward=batch_rewards,
            next_state=batch_next_state,
            done=batch_dones,
        )


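# --- Illustrative round trip through the buffer (a sketch; the state and
# action shapes below are assumptions made for the example) ---
# Each transition is stored with a leading batch dimension of 1, and `sample`
# collates them into (batch_size, ...) tensors:
#
#   buffer = ReplayBuffer(capacity=10_000, device="cpu", state_keys=["observation.state"])
#   for _ in range(64):
#       buffer.add(
#           state={"observation.state": torch.zeros(1, 4)},
#           action=torch.zeros(1, 2),
#           reward=0.0,
#           next_state={"observation.state": torch.zeros(1, 4)},
#           done=False,
#       )
#   batch = buffer.sample(32)  # batch["action"].shape == (32, 2)
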
def concatenate_batch_transitions(
    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
) -> BatchTransition:
    """NOTE: Be careful, this modifies `left_batch_transitions` in place."""
    left_batch_transitions["state"] = {
        key: torch.cat([left_batch_transitions["state"][key], right_batch_transition["state"][key]], dim=0)
        for key in left_batch_transitions["state"]
    }
    left_batch_transitions["action"] = torch.cat(
        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
    )
    left_batch_transitions["reward"] = torch.cat(
        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
    )
    left_batch_transitions["next_state"] = {
        key: torch.cat(
            [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], dim=0
        )
        for key in left_batch_transitions["next_state"]
    }
    left_batch_transitions["done"] = torch.cat(
        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
    )
    return left_batch_transitions


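# --- Illustrative usage (sketch) ---
# The training loop below mixes online and offline data this way: each buffer
# contributes `batch_size` transitions, giving one BatchTransition of size
# 2 * batch_size. Mind the in-place caveat in the docstring:
#
#   online_batch = replay_buffer.sample(batch_size)
#   offline_batch = offline_replay_buffer.sample(batch_size)
#   mixed = concatenate_batch_transitions(online_batch, offline_batch)
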
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
    if out_dir is None:
        raise NotImplementedError()
    if job_name is None:
        raise NotImplementedError()

    init_logging()
    logging.info(pformat(OmegaConf.to_container(cfg)))

    # Create an env dedicated to online episodes collection from policy rollout.
    # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
    # NOTE: Off-policy algorithms are efficient enough to use a single environment.
    logging.info("make_env online")
    # online_env = make_env(cfg, n_envs=1)
    # TODO: Remove the import of maniskill and unify it with make_env
    online_env = make_maniskill_env(cfg, n_envs=1)
    if cfg.training.eval_freq > 0:
        logging.info("make_env eval")
        # eval_env = make_env(cfg, n_envs=1)
        # TODO: Remove the import of maniskill and unify it with make_env
        eval_env = make_maniskill_env(cfg, n_envs=1)

    # TODO: Add a way to resume training

    # log metrics to terminal and wandb
    logger = Logger(cfg, out_dir, wandb_job_name=job_name)

    set_global_seed(cfg.seed)

    # Check device is available
    device = get_safe_torch_device(cfg.device, log=True)

    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    logging.info("make_policy")
    # TODO: At some point we should only need to call a make_sac_policy here
    policy: SACPolicy = make_policy(
        hydra_cfg=cfg,
        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
        # Hack: if we do online training, we do not need dataset_stats
        dataset_stats=None,
        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
        device=device,
    )
    assert isinstance(policy, nn.Module)

    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)

    # TODO: Handle resume

    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
    num_total_params = sum(p.numel() for p in policy.parameters())

    log_output_dir(out_dir)
    logging.info(f"{cfg.env.task=}")
    logging.info(f"{cfg.training.online_steps=}")
    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")

    obs, info = online_env.reset()

    # HACK for maniskill
    # obs = preprocess_observation(obs)
    obs = preprocess_maniskill_observation(obs)
    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}

    replay_buffer = ReplayBuffer(
        capacity=cfg.training.online_buffer_capacity, device=device, state_keys=cfg.policy.input_shapes.keys()
    )

    batch_size = cfg.training.batch_size

    if cfg.dataset_repo_id is not None:
        logging.info("make_dataset offline buffer")
        offline_dataset = make_dataset(cfg)
        logging.info("Conversion to an offline replay buffer")
        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
        )
        batch_size: int = batch_size // 2  # We will sample from both replay buffers

    # NOTE: For the moment we solely handle the case of a single environment
    sum_reward_episode = 0

    for interaction_step in range(cfg.training.online_steps):
        # NOTE: At some point we should use a wrapper to handle the observation

        if interaction_step >= cfg.training.online_step_before_learning:
            action = policy.select_action(batch=obs)
            next_obs, reward, done, truncated, info = online_env.step(action.cpu().numpy())
        else:
            action = online_env.action_space.sample()
            next_obs, reward, done, truncated, info = online_env.step(action)
            # HACK
            action = torch.tensor(action, dtype=torch.float32).to(device, non_blocking=True)

        # HACK: For maniskill
        # next_obs = preprocess_observation(next_obs)
        next_obs = preprocess_maniskill_observation(next_obs)
        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in next_obs}
        sum_reward_episode += float(reward[0])
        # Because we are using a single environment
        # we can safely assume that the episode is done
        if done[0] or truncated[0]:
            logging.info(f"Global step {interaction_step}: Episode reward: {sum_reward_episode}")
            logger.log_dict({"Sum episode reward": sum_reward_episode}, interaction_step)
            sum_reward_episode = 0
            # HACK: This is for maniskill
            logging.info(
                f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
            )
            logger.log_dict({"Episode success": info["success"].float().item()}, interaction_step)

        replay_buffer.add(
            state=obs,
            action=action,
            reward=float(reward[0]),
            next_state=next_obs,
            done=done[0],
        )
        obs = next_obs

        if interaction_step < cfg.training.online_step_before_learning:
            continue
        for _ in range(cfg.policy.utd_ratio - 1):
            batch = replay_buffer.sample(batch_size)
            if cfg.dataset_repo_id is not None:
                batch_offline = offline_replay_buffer.sample(batch_size)
                batch = concatenate_batch_transitions(batch, batch_offline)

            actions = batch["action"]
            rewards = batch["reward"]
            observations = batch["state"]
            next_observations = batch["next_state"]
            done = batch["done"]

            loss_critic = policy.compute_loss_critic(
                observations=observations,
                actions=actions,
                rewards=rewards,
                next_observations=next_observations,
                done=done,
            )
            optimizers["critic"].zero_grad()
            loss_critic.backward()
            optimizers["critic"].step()

        batch = replay_buffer.sample(batch_size)
        if cfg.dataset_repo_id is not None:
            batch_offline = offline_replay_buffer.sample(batch_size)
            batch = concatenate_batch_transitions(
                left_batch_transitions=batch, right_batch_transition=batch_offline
            )

        actions = batch["action"]
        rewards = batch["reward"]
        observations = batch["state"]
        next_observations = batch["next_state"]
        done = batch["done"]

        loss_critic = policy.compute_loss_critic(
            observations=observations,
            actions=actions,
            rewards=rewards,
            next_observations=next_observations,
            done=done,
        )
        optimizers["critic"].zero_grad()
        loss_critic.backward()
        optimizers["critic"].step()

        training_infos = {}
        training_infos["loss_critic"] = loss_critic.item()

        if interaction_step % cfg.training.policy_update_freq == 0:
            # TD3 trick: delay the policy (actor) updates relative to the critic updates
            for _ in range(cfg.training.policy_update_freq):
                loss_actor = policy.compute_loss_actor(observations=observations)

                optimizers["actor"].zero_grad()
                loss_actor.backward()
                optimizers["actor"].step()

                training_infos["loss_actor"] = loss_actor.item()

                loss_temperature = policy.compute_loss_temperature(observations=observations)
                optimizers["temperature"].zero_grad()
                loss_temperature.backward()
                optimizers["temperature"].step()

                training_infos["loss_temperature"] = loss_temperature.item()

        if interaction_step % cfg.training.log_freq == 0:
            logger.log_dict(training_infos, interaction_step, mode="train")

        policy.update_target_networks()


@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
    train(
        cfg,
        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
    )


def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
    from hydra import compose, initialize

    hydra.core.global_hydra.GlobalHydra.instance().clear()
    initialize(config_path=config_path)
    cfg = compose(config_name=config_name)
    train(cfg, out_dir=out_dir, job_name=job_name)


if __name__ == "__main__":
    train_cli()
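# --- Illustrative launch (sketch) ---
# The script path below is hypothetical; only the override keys (device,
# dataset_repo_id, training.online_steps) are taken from the config accesses
# in the code above:
#
#   python lerobot/scripts/train_hilserl_maniskill.py device=cuda dataset_repo_id=null training.online_steps=100000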
@@ -53,20 +53,29 @@ python lerobot/scripts/visualize_dataset_html.py \
 """

 import argparse
+import csv
+import json
 import logging
+import re
 import shutil
+import tempfile
+from io import StringIO
 from pathlib import Path

 import tqdm
-from flask import Flask, redirect, render_template, url_for
+import numpy as np
+import pandas as pd
+import requests
+from flask import Flask, redirect, render_template, request, url_for

 from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import IterableNamespace
 from lerobot.common.utils.utils import init_logging


 def run_server(
-    dataset: LeRobotDataset,
-    episodes: list[int],
+    dataset: LeRobotDataset | IterableNamespace | None,
+    episodes: list[int] | None,
     host: str,
     port: str,
     static_folder: Path,
@@ -76,10 +85,50 @@ def run_server(
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache

     @app.route("/")
-    def index():
-        # home page redirects to the first episode page
-        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
-        first_episode_id = episodes[0]
+    def homepage(dataset=dataset):
+        if dataset:
+            dataset_namespace, dataset_name = dataset.repo_id.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=0,
+                )
+            )
+
+        dataset_param, episode_param = None, None
+        all_params = request.args
+        if "dataset" in all_params:
+            dataset_param = all_params["dataset"]
+        if "episode" in all_params:
+            episode_param = int(all_params["episode"])
+
+        if dataset_param:
+            dataset_namespace, dataset_name = dataset_param.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=episode_param if episode_param is not None else 0,
+                )
+            )
+
+        featured_datasets = [
+            "lerobot/aloha_static_cups_open",
+            "lerobot/columbia_cairlab_pusht_real",
+            "lerobot/taco_play",
+        ]
+        return render_template(
+            "visualize_dataset_homepage.html",
+            featured_datasets=featured_datasets,
+            lerobot_datasets=available_datasets,
+        )
+
+    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
+    def show_first_episode(dataset_namespace, dataset_name):
+        first_episode_id = 0
         return redirect(
             url_for(
                 "show_episode",
@@ -90,30 +139,85 @@ def run_server(
         )

     @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id):
+    def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
+        repo_id = f"{dataset_namespace}/{dataset_name}"
+        try:
+            if dataset is None:
+                dataset = get_dataset_info(repo_id)
+        except FileNotFoundError:
+            return (
+                "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
+                400,
+            )
+        dataset_version = (
+            dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
+        )
+        match = re.search(r"v(\d+)\.", dataset_version)
+        if match:
+            major_version = int(match.group(1))
+            if major_version < 2:
+                return "Make sure to convert your LeRobotDataset to v2 & above."
+
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
         dataset_info = {
-            "repo_id": dataset.repo_id,
-            "num_samples": dataset.num_frames,
-            "num_episodes": dataset.num_episodes,
+            "repo_id": f"{dataset_namespace}/{dataset_name}",
+            "num_samples": dataset.num_frames
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_frames,
+            "num_episodes": dataset.num_episodes
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_episodes,
             "fps": dataset.fps,
         }
-        video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
-        tasks = dataset.meta.episodes[episode_id]["tasks"]
-        videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": video_path.name}
-            for video_path in video_paths
-        ]
+        if isinstance(dataset, LeRobotDataset):
+            video_paths = [
+                dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
+            ]
+            videos_info = [
+                {"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
+                for video_path in video_paths
+            ]
+            tasks = dataset.meta.episodes[episode_id]["tasks"]
+        else:
+            video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
+            videos_info = [
+                {
+                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+                    + dataset.video_path.format(
+                        episode_chunk=int(episode_id) // dataset.chunks_size,
+                        video_key=video_key,
+                        episode_index=episode_id,
+                    ),
+                    "filename": video_key,
+                }
+                for video_key in video_keys
+            ]
+
+            response = requests.get(
+                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
+            )
+            response.raise_for_status()
+            # Split into lines and parse each line as JSON
+            tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
+
+            filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
+            tasks = filtered_tasks_jsonl[0]["tasks"]
+
+        videos_info[0]["language_instruction"] = tasks

-        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
+        if episodes is None:
+            episodes = list(
+                range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
+            )

         return render_template(
             "visualize_dataset_template.html",
             episode_id=episode_id,
             episodes=episodes,
             dataset_info=dataset_info,
             videos_info=videos_info,
-            ep_csv_url=ep_csv_url,
-            has_policy=False,
+            episode_data_csv_str=episode_data_csv_str,
+            columns=columns,
         )

     app.run(host=host, port=port)
@@ -124,46 +228,69 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname


-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
-    """Write a csv file containg timeseries data of an episode (e.g. state and action).
+def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+    """Get a csv str containing timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
-    from_idx = dataset.episode_data_index["from"][episode_index]
-    to_idx = dataset.episode_data_index["to"][episode_index]
+    columns = []

-    has_state = "observation.state" in dataset.features
-    has_action = "action" in dataset.features
+    selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
+    selected_columns.remove("timestamp")

     # init header of csv with state and action names
     header = ["timestamp"]
-    if has_state:
-        dim_state = dataset.meta.shapes["observation.state"][0]
-        header += [f"state_{i}" for i in range(dim_state)]
-    if has_action:
-        dim_action = dataset.meta.shapes["action"][0]
-        header += [f"action_{i}" for i in range(dim_action)]
-
-    columns = ["timestamp"]
-    if has_state:
-        columns += ["observation.state"]
-    if has_action:
-        columns += ["action"]
+    for column_name in selected_columns:
+        dim_state = (
+            dataset.meta.shapes[column_name][0]
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.features[column_name].shape[0]
+        )
+        header += [f"{column_name}_{i}" for i in range(dim_state)]

-    rows = []
-    data = dataset.hf_dataset.select_columns(columns)
-    for i in range(from_idx, to_idx):
-        row = [data[i]["timestamp"].item()]
-        if has_state:
-            row += data[i]["observation.state"].tolist()
-        if has_action:
-            row += data[i]["action"].tolist()
-        rows.append(row)
+        if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
+            column_names = dataset.features[column_name]["names"]
+            while not isinstance(column_names, list):
+                column_names = list(column_names.values())[0]
+        else:
+            column_names = [f"motor_{i}" for i in range(dim_state)]
+        columns.append({"key": column_name, "value": column_names})

-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w") as f:
-        f.write(",".join(header) + "\n")
-        for row in rows:
-            row_str = [str(col) for col in row]
-            f.write(",".join(row_str) + "\n")
+    selected_columns.insert(0, "timestamp")
+
+    if isinstance(dataset, LeRobotDataset):
+        from_idx = dataset.episode_data_index["from"][episode_index]
+        to_idx = dataset.episode_data_index["to"][episode_index]
+        data = (
+            dataset.hf_dataset.select(range(from_idx, to_idx))
+            .select_columns(selected_columns)
+            .with_format("pandas")
+        )
+    else:
+        repo_id = dataset.repo_id
+
+        url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
+            episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
+        )
+        df = pd.read_parquet(url)
+        data = df[selected_columns]  # Select specific columns
+
+    rows = np.hstack(
+        (
+            np.expand_dims(data["timestamp"], axis=1),
+            *[np.vstack(data[col]) for col in selected_columns[1:]],
+        )
+    ).tolist()
+
+    # Convert data to CSV string
+    csv_buffer = StringIO()
+    csv_writer = csv.writer(csv_buffer)
+    # Write header
+    csv_writer.writerow(header)
+    # Write data rows
+    csv_writer.writerows(rows)
+    csv_string = csv_buffer.getvalue()
+
+    return csv_string, columns


 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -175,9 +302,31 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
     ]


+def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
+    # check if the dataset has language instructions
+    if "language_instruction" not in dataset.features:
+        return None
+
+    # get first frame index
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+
+    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
+    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
+    # with the tf.tensor appearing in the string
+    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
+
+
+def get_dataset_info(repo_id: str) -> IterableNamespace:
+    response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
+    response.raise_for_status()  # Raises an HTTPError for bad responses
+    dataset_info = response.json()
+    dataset_info["repo_id"] = repo_id
+    return IterableNamespace(dataset_info)
+
+
 def visualize_dataset_html(
-    dataset: LeRobotDataset,
-    episodes: list[int] = None,
+    dataset: LeRobotDataset | None,
+    episodes: list[int] | None = None,
     output_dir: Path | None = None,
     serve: bool = True,
     host: str = "127.0.0.1",
@@ -186,11 +335,11 @@ def visualize_dataset_html(
 ) -> Path | None:
     init_logging()

-    if len(dataset.meta.image_keys) > 0:
-        raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
+    template_dir = Path(__file__).resolve().parent.parent / "templates"

     if output_dir is None:
-        output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
+        # Create a temporary directory that will be automatically cleaned up
+        output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")

     output_dir = Path(output_dir)
     if output_dir.exists():
@@ -201,28 +350,29 @@ def visualize_dataset_html(

     output_dir.mkdir(parents=True, exist_ok=True)

-    # Create a simlink from the dataset video folder containg mp4 files to the output directory
-    # so that the http server can get access to the mp4 files.
     static_dir = output_dir / "static"
     static_dir.mkdir(parents=True, exist_ok=True)
-    ln_videos_dir = static_dir / "videos"
-    if not ln_videos_dir.exists():
-        ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

-    template_dir = Path(__file__).resolve().parent.parent / "templates"
+    if dataset is None:
+        if serve:
+            run_server(
+                dataset=None,
+                episodes=None,
+                host=host,
+                port=port,
+                static_folder=static_dir,
+                template_folder=template_dir,
+            )
+    else:
+        # Create a symlink from the dataset video folder containing mp4 files to the output directory
+        # so that the http server can get access to the mp4 files.
+        if isinstance(dataset, LeRobotDataset):
+            ln_videos_dir = static_dir / "videos"
+            if not ln_videos_dir.exists():
+                ln_videos_dir.symlink_to((dataset.root / "videos").resolve())

-    if episodes is None:
-        episodes = list(range(dataset.num_episodes))
-
-    logging.info("Writing CSV files")
-    for episode_index in tqdm.tqdm(episodes):
-        # write states and actions in a csv (it can be slow for big datasets)
-        ep_csv_fname = get_ep_csv_fname(episode_index)
-        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
-
-    if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        if serve:
+            run_server(dataset, episodes, host, port, static_dir, template_dir)


 def main():
@@ -231,7 +381,7 @@ def main():
     parser.add_argument(
         "--repo-id",
         type=str,
-        required=True,
+        default=None,
         help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
@@ -246,6 +396,12 @@ def main():
         default=None,
         help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
     )
+    parser.add_argument(
+        "--load-from-hf-hub",
+        type=int,
+        default=0,
+        help="Load videos and parquet files from HF Hub rather than local system.",
+    )
     parser.add_argument(
         "--episodes",
         type=int,
@@ -287,11 +443,19 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
+    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
     root = kwargs.pop("root")
     local_files_only = kwargs.pop("local_files_only")

-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-    visualize_dataset_html(dataset, **kwargs)
+    dataset = None
+    if repo_id:
+        dataset = (
+            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+            if not load_from_hf_hub
+            else get_dataset_info(repo_id)
+        )
+
+    visualize_dataset_html(dataset, **vars(args))


 if __name__ == "__main__":

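With `--load-from-hf-hub 1`, this code path never loads the dataset locally: `get_dataset_info` reads `meta/info.json` from the Hub, and episode data and videos are then fetched on demand (`pd.read_parquet` on the episode's parquet URL, video URLs resolved per episode). An illustrative invocation, using the example repo id from the help text above:

    python lerobot/scripts/visualize_dataset_html.py --repo-id lerobot/pusht --load-from-hf-hub 1
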
lerobot/templates/visualize_dataset_homepage.html (new file, 68 lines)
@@ -0,0 +1,68 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Interactive Video Background Page</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
</head>
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
    inputValue: '',
    navigateToDataset() {
        const trimmedValue = this.inputValue.trim();
        if (trimmedValue) {
            window.location.href = `/${trimmedValue}`;
        }
    }
}">
    <div class="fixed inset-0 w-full h-full overflow-hidden">
        <video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
            <source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
            Your browser does not support HTML5 video.
        </video>
    </div>
    <div class="fixed inset-0 bg-black bg-opacity-80"></div>
    <div class="relative z-10 flex flex-col items-center justify-center h-screen">
        <div class="text-center mb-8">
            <h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>

            <a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>

            <p class="text-xl mb-4"></p>
            <div class="text-left inline-block">
                <h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
                <ul class="list-disc list-inside">
                    {% for dataset in featured_datasets %}
                    <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                    {% endfor %}
                </ul>
            </div>
        </div>
        <div class="flex w-full max-w-lg px-4 mb-4">
            <input
                type="text"
                x-model="inputValue"
                @keyup.enter="navigateToDataset"
                placeholder="enter dataset id (ex: lerobot/droid_100)"
                class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
            <button
                @click="navigateToDataset"
                class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
            >
                Go
            </button>
        </div>

        <details class="mt-4 max-w-full px-4">
            <summary>More example datasets</summary>
            <ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
                {% for dataset in lerobot_datasets %}
                <li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
                {% endfor %}
            </ul>
        </details>
    </div>
</body>
</html>
lerobot/templates/visualize_dataset_template.html
@@ -31,11 +31,16 @@
 }">
         <!-- Sidebar -->
         <div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
-            <h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
+            <a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
+                <img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
+            </a>
+            <a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
+                <h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
+            </a>

             <ul>
                 <li>
-                    Number of samples/frames: {{ dataset_info.num_frames }}
+                    Number of samples/frames: {{ dataset_info.num_samples }}
                 </li>
                 <li>
                     Number of episodes: {{ dataset_info.num_episodes }}
@@ -93,10 +98,35 @@
         </div>

         <!-- Videos -->
-        <div class="flex flex-wrap gap-1">
+        <div class="max-w-32 relative text-sm mb-4 select-none"
+             @click.outside="isVideosDropdownOpen = false">
+            <div
+                @click="isVideosDropdownOpen = !isVideosDropdownOpen"
+                class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
+            >
+                <span class="truncate">filter videos</span>
+                <div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
+            </div>
+
+            <div x-show="isVideosDropdownOpen"
+                 class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
+                <div>
+                    <template x-for="option in videosKeys" :key="option">
+                        <div
+                            @click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
+                            class="p-2 cursor-pointer bg-slate-900"
+                            :class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
+                            x-text="option"
+                        ></div>
+                    </template>
+                </div>
+            </div>
+        </div>
+
+        <div class="flex flex-wrap gap-x-2 gap-y-6">
         {% for video_info in videos_info %}
-            <div x-show="!videoCodecError" class="max-w-96">
-                <p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
+            <div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
+                <p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
                 <video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
                     if (video.duration) {
                         const time = video.currentTime;
@@ -182,12 +212,12 @@
             <thead>
                 <tr>
                     <th></th>
-                    <template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
+                    <template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
                         <th class="border border-slate-700">
                             <div class="flex gap-x-2 justify-between px-2">
                                 <input type="checkbox" :checked="isColumnChecked(colIndex)"
                                     @change="toggleColumn(colIndex)">
-                                <p x-text="`${columnNames[colIndex]}`"></p>
+                                <p x-text="`${columns[colIndex].key}`"></p>
                             </div>
                         </th>
                     </template>
@@ -197,10 +227,10 @@
             <template x-for="(row, rowIndex) in rows">
                 <tr class="odd:bg-gray-800 even:bg-gray-900">
                     <td class="border border-slate-700">
-                        <div class="flex gap-x-2 w-24 font-semibold px-1">
+                        <div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
                             <input type="checkbox" :checked="isRowChecked(rowIndex)"
                                 @change="toggleRow(rowIndex)">
-                            <p x-text="`Motor ${rowIndex}`"></p>
+                            <p x-text="`${rowLabels[rowIndex]}`"></p>
                         </div>
                     </td>
                     <template x-for="(cell, colIndex) in row">
@@ -222,16 +252,20 @@
         </div>
     </div>

+    <script>
+        const parentOrigin = "https://huggingface.co";
+        const searchParams = new URLSearchParams();
+        searchParams.set("dataset", "{{ dataset_info.repo_id }}");
+        searchParams.set("episode", "{{ episode_id }}");
+        window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
+    </script>
+
     <script>
         function createAlpineData() {
             return {
                 // state
                 dygraph: null,
                 currentFrameData: null,
-                columnNames: ["state", "action", "pred action"],
-                nColumns: 2,
                 nStates: 0,
                 nActions: 0,
                 checked: [],
                 dygraphTime: 0.0,
                 dygraphIndex: 0,
@@ -241,6 +275,11 @@
                 nVideos: {{ videos_info | length }},
                 nVideoReadyToPlay: 0,
                 videoCodecError: false,
+                isVideosDropdownOpen: false,
+                videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
+                videosKeysSelected: [],
+                columns: {{ columns | tojson }},
+                rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,

                 // alpine initialization
                 init() {
@@ -250,11 +289,19 @@
                     if(!canPlayVideos){
                         this.videoCodecError = true;
                     }
+                    this.videosKeysSelected = this.videosKeys.map(opt => opt)
+
+                    // process CSV data
+                    const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
+                    // Create a Blob with the CSV data
+                    const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
+                    // Create a URL for the Blob
+                    const csvUrl = URL.createObjectURL(blob);

                     // process CSV data
                     this.videos = document.querySelectorAll('video');
                     this.video = this.videos[0];
-                    this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
+                    this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
                         pixelsPerPoint: 0.01,
                         legend: 'always',
                         labelsDiv: document.getElementById('labels'),
@@ -275,21 +322,17 @@
                     this.colors = this.dygraph.getColors();
                     this.checked = Array(this.colors.length).fill(true);

-                    const seriesNames = this.dygraph.getLabels().slice(1);
-                    this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
-                    this.nActions = seriesNames.length - this.nStates;
                     const colors = [];
-                    const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
-                    // colors for "state" lines
-                    for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
-                        const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
-                        colors.push(color);
-                    }
-                    // colors for "action" lines
-                    for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
-                        const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
-                        colors.push(color);
+                    let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
+                    for(const column of this.columns){
+                        const nValues = column.value.length;
+                        for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
+                            const color = `hsl(${hue}, 100%, ${lightness}%)`;
+                            colors.push(color);
+                        }
+                        lightness += 35;
                     }

                     this.dygraph.updateOptions({ colors });
                     this.colors = colors;

@@ -316,17 +359,19 @@
                     return [];
                 }
                 const rows = [];
-                const nRows = Math.max(this.nStates, this.nActions);
+                const nRows = Math.max(...this.columns.map(column => column.value.length));
                 let rowIndex = 0;
                 while(rowIndex < nRows){
                     const row = [];
                     // number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
                     const nullCell = { isNull: true };
-                    const stateValueIdx = rowIndex;
-                    const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
-                    // row consists of [state value, action value]
-                    row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
-                    row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
+                    let idx = rowIndex;
+                    for(const column of this.columns){
+                        const nColumn = column.value.length;
+                        row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
+                        idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
+                    }
                     rowIndex += 1;
                     rows.push(row);
                 }
153
poetry.lock
generated
153
poetry.lock
generated
@@ -3139,6 +3139,27 @@ dev = ["changelist (==0.5)"]
|
||||
lint = ["pre-commit (==3.7.0)"]
|
||||
test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "lightning-utilities"
|
||||
version = "0.11.9"
|
||||
description = "Lightning toolbox for across the our ecosystem."
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "lightning_utilities-0.11.9-py3-none-any.whl", hash = "sha256:ac6d4e9e28faf3ff4be997876750fee10dc604753dbc429bf3848a95c5d7e0d2"},
|
||||
{file = "lightning_utilities-0.11.9.tar.gz", hash = "sha256:f5052b81344cc2684aa9afd74b7ce8819a8f49a858184ec04548a5a109dfd053"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = ">=17.1"
|
||||
setuptools = "*"
|
||||
typing-extensions = "*"
|
||||
|
||||
[package.extras]
|
||||
cli = ["fire"]
|
||||
docs = ["requests (>=2.0.0)"]
|
||||
typing = ["mypy (>=1.0.0)", "types-setuptools"]
|
||||
|
||||
[[package]]
|
||||
name = "llvmlite"
|
||||
version = "0.43.0"
|
||||
@@ -6798,6 +6819,38 @@ webencodings = ">=0.4"
|
||||
doc = ["sphinx", "sphinx_rtd_theme"]
|
||||
test = ["pytest", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tokenizers"
|
||||
version = "0.21.0"
|
||||
description = ""
|
||||
optional = true
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3c4c93eae637e7d2aaae3d376f06085164e1660f89304c0ab2b1d08a406636b2"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:f53ea537c925422a2e0e92a24cce96f6bc5046bbef24a1652a5edc8ba975f62e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b177fb54c4702ef611de0c069d9169f0004233890e0c4c5bd5508ae05abf193"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6b43779a269f4629bebb114e19c3fca0223296ae9fea8bb9a7a6c6fb0657ff8e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9aeb255802be90acfd363626753fda0064a8df06031012fe7d52fd9a905eb00e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d8b09dbeb7a8d73ee204a70f94fc06ea0f17dcf0844f16102b9f414f0b7463ba"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:400832c0904f77ce87c40f1a8a27493071282f785724ae62144324f171377273"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84ca973b3a96894d1707e189c14a774b701596d579ffc7e69debfc036a61a04"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:eb7202d231b273c34ec67767378cd04c767e967fda12d4a9e36208a34e2f137e"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:089d56db6782a73a27fd8abf3ba21779f5b85d4a9f35e3b493c7bbcbbf0d539b"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:c87ca3dc48b9b1222d984b6b7490355a6fdb411a2d810f6f05977258400ddb74"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:4145505a973116f91bc3ac45988a92e618a6f83eb458f49ea0790df94ee243ff"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win32.whl", hash = "sha256:eb1702c2f27d25d9dd5b389cc1f2f51813e99f8ca30d9e25348db6585a97e24a"},
|
||||
{file = "tokenizers-0.21.0-cp39-abi3-win_amd64.whl", hash = "sha256:87841da5a25a3a5f70c102de371db120f41873b854ba65e52bccd57df5a3780c"},
|
||||
{file = "tokenizers-0.21.0.tar.gz", hash = "sha256:ee0894bf311b75b0c03079f33859ae4b2334d675d4e93f5a4132e1eae2834fe4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
huggingface-hub = ">=0.16.4,<1.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["tokenizers[testing]"]
|
||||
docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
|
||||
testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.2"
|
||||
@@ -6863,6 +6916,34 @@ typing-extensions = ">=4.8.0"
|
||||
opt-einsum = ["opt-einsum (>=3.3)"]
|
||||
optree = ["optree (>=0.11.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "torchmetrics"
|
||||
version = "1.6.0"
|
||||
description = "PyTorch native Metrics"
|
||||
optional = true
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "torchmetrics-1.6.0-py3-none-any.whl", hash = "sha256:a508cdd87766cedaaf55a419812bf9f493aff8fffc02cc19df5a8e2e7ccb942a"},
|
||||
{file = "torchmetrics-1.6.0.tar.gz", hash = "sha256:aebba248708fb90def20cccba6f55bddd134a58de43fb22b0c5ca0f3a89fa984"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
lightning-utilities = ">=0.8.0"
|
||||
numpy = ">1.20.0"
|
||||
packaging = ">17.1"
|
||||
torch = ">=2.0.0"
|
||||
|
||||
[package.extras]
|
||||
all = ["SciencePlots (>=2.0.0)", "gammatone (>=1.0.0)", "ipadic (>=1.0.0)", "librosa (>=0.10.0)", "matplotlib (>=3.6.0)", "mecab-python3 (>=1.0.6)", "mypy (==1.13.0)", "nltk (>3.8.1)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "torch (==2.5.1)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
audio = ["gammatone (>=1.0.0)", "librosa (>=0.10.0)", "numpy (<2.0)", "onnxruntime (>=1.12.0)", "pesq (>=0.0.4)", "pystoi (>=0.4.0)", "requests (>=2.19.0)", "torchaudio (>=2.0.1)"]
|
||||
detection = ["pycocotools (>2.0.0)", "torchvision (>=0.15.1)"]
|
||||
dev = ["PyTDC (==0.4.1)", "SciencePlots (>=2.0.0)", "bert-score (==0.3.13)", "dython (==0.7.6)", "dython (>=0.7.8,<0.8.0)", "fairlearn", "fast-bss-eval (>=0.1.0)", "faster-coco-eval (>=1.6.3)", "gammatone (>=1.0.0)", "huggingface-hub (<0.27)", "ipadic (>=1.0.0)", "jiwer (>=2.3.0)", "kornia (>=0.6.7)", "librosa (>=0.10.0)", "lpips (<=0.1.4)", "matplotlib (>=3.6.0)", "mecab-ko (>=1.0.0,<1.1.0)", "mecab-ko-dic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "mir-eval (>=0.6)", "monai (==1.3.2)", "monai (==1.4.0)", "mypy (==1.13.0)", "netcal (>1.0.0)", "nltk (>3.8.1)", "numpy (<2.0)", "numpy (<2.2.0)", "onnxruntime (>=1.12.0)", "pandas (>1.4.0)", "permetrics (==2.0.0)", "pesq (>=0.0.4)", "piq (<=0.8.0)", "pycocotools (>2.0.0)", "pystoi (>=0.4.0)", "pytorch-msssim (==1.0.0)", "regex (>=2021.9.24)", "requests (>=2.19.0)", "rouge-score (>0.1.0)", "sacrebleu (>=2.3.0)", "scikit-image (>=0.19.0)", "scipy (>1.0.0)", "sentencepiece (>=0.2.0)", "sewar (>=0.4.4)", "statsmodels (>0.13.5)", "torch (==2.5.1)", "torch-complex (<0.5.0)", "torch-fidelity (<=0.4.0)", "torchaudio (>=2.0.1)", "torchvision (>=0.15.1)", "tqdm (<4.68.0)", "transformers (>4.4.0)", "transformers (>=4.42.3)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
image = ["scipy (>1.0.0)", "torch-fidelity (<=0.4.0)", "torchvision (>=0.15.1)"]
|
||||
multimodal = ["piq (<=0.8.0)", "transformers (>=4.42.3)"]
|
||||
text = ["ipadic (>=1.0.0)", "mecab-python3 (>=1.0.6)", "nltk (>3.8.1)", "regex (>=2021.9.24)", "sentencepiece (>=0.2.0)", "tqdm (<4.68.0)", "transformers (>4.4.0)"]
|
||||
typing = ["mypy (==1.13.0)", "torch (==2.5.1)", "types-PyYAML", "types-emoji", "types-protobuf", "types-requests", "types-setuptools", "types-six", "types-tabulate"]
|
||||
visual = ["SciencePlots (>=2.0.0)", "matplotlib (>=3.6.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "torchvision"
|
||||
version = "0.19.1"
|
||||
@@ -6956,6 +7037,75 @@ files = [
|
||||
docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"]
|
||||
test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<8.2)", "pytest-mock", "pytest-mypy-testing"]
|
||||
|
||||
[[package]]
|
||||
name = "transformers"
|
||||
version = "4.47.0"
|
||||
description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow"
|
||||
optional = true
|
||||
python-versions = ">=3.9.0"
|
||||
files = [
|
||||
{file = "transformers-4.47.0-py3-none-any.whl", hash = "sha256:a8e1bafdaae69abdda3cad638fe392e37c86d2ce0ecfcae11d60abb8f949ff4d"},
|
||||
{file = "transformers-4.47.0.tar.gz", hash = "sha256:f8ead7a5a4f6937bb507e66508e5e002dc5930f7b6122a9259c37b099d0f3b19"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
filelock = "*"
|
||||
huggingface-hub = ">=0.24.0,<1.0"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
pyyaml = ">=5.1"
|
||||
regex = "!=2019.12.17"
|
||||
requests = "*"
|
||||
safetensors = ">=0.4.1"
|
||||
tokenizers = ">=0.21,<0.22"
|
||||
tqdm = ">=4.27"
|
||||
|
||||
[package.extras]
|
||||
accelerate = ["accelerate (>=0.26.0)"]
|
||||
agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"]
|
||||
all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision"]
|
||||
audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
benchmark = ["optimum-benchmark (>=0.3.0)"]
|
||||
codecarbon = ["codecarbon (==1.2.0)"]
|
||||
deepspeed = ["accelerate (>=0.26.0)", "deepspeed (>=0.9.3)"]
|
||||
deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.26.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
|
||||
dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "scipy (<1.13.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1,<0.14.0)", "libcst", "librosa", "nltk (<=3.8.1)", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.21,<0.22)", "urllib3 (<2.0.0)"]
|
||||
dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.26.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "libcst", "librosa", "nltk (<=3.8.1)", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rich", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm (<=1.0.11)", "tokenizers (>=0.21,<0.22)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"]
|
||||
flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)", "scipy (<1.13.0)"]
|
||||
flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
|
||||
ftfy = ["ftfy"]
|
||||
integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"]
|
||||
ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"]
modelcreation = ["cookiecutter (==1.7.3)"]
natten = ["natten (>=0.14.6,<0.15.0)"]
onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"]
onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"]
optuna = ["optuna"]
quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "isort (>=5.5.4)", "libcst", "rich", "ruff (==0.5.1)", "urllib3 (<2.0.0)"]
ray = ["ray[tune] (>=2.7.0)"]
retrieval = ["datasets (!=2.5.0)", "faiss-cpu"]
ruff = ["ruff (==0.5.1)"]
sagemaker = ["sagemaker (>=2.31.0)"]
sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"]
serving = ["fastapi", "pydantic", "starlette", "uvicorn"]
sigopt = ["sigopt"]
sklearn = ["scikit-learn"]
speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "nltk (<=3.8.1)", "parameterized", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-rich", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.5.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"]
tf = ["keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow (>2.9,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-cpu = ["keras (>2.9,<2.16)", "keras-nlp (>=0.3.1,<0.14.0)", "onnxconverter-common", "tensorflow-cpu (>2.9,<2.16)", "tensorflow-probability (<0.24)", "tensorflow-text (<2.16)", "tf2onnx"]
tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"]
tiktoken = ["blobfile", "tiktoken"]
timm = ["timm (<=1.0.11)"]
tokenizers = ["tokenizers (>=0.21,<0.22)"]
torch = ["accelerate (>=0.26.0)", "torch"]
torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"]
torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"]
torchhub = ["filelock", "huggingface-hub (>=0.24.0,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.21,<0.22)", "torch", "tqdm (>=4.27)"]
video = ["av (==9.2.0)"]
vision = ["Pillow (>=10.0.1,<=15.0)"]
[[package]]
name = "transforms3d"
version = "0.4.2"

@@ -7558,6 +7708,7 @@ dev = ["debugpy", "pre-commit"]
dora = ["gym-dora"]
dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
hilserl = ["torchmetrics", "transformers"]
intelrealsense = ["pyrealsense2"]
pusht = ["gym-pusht"]
stretch = ["hello-robot-stretch-body", "pynput", "pyrealsense2", "pyrender"]
@@ -7569,4 +7720,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
content-hash = "44c74163e398e8ff16973957f69a47bb09b789e92ac4d8fb3ab268defab96427"

@@ -71,6 +71,8 @@ pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platfo
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
pyserial = {version = ">=3.5", optional = true}
jsonlines = ">=4.0.0"
transformers = {version = ">=4.47.0", optional = true}
torchmetrics = {version = ">=1.6.0", optional = true}

[tool.poetry.extras]
@@ -86,6 +88,7 @@ dynamixel = ["dynamixel-sdk", "pynput"]
feetech = ["feetech-servo-sdk", "pynput"]
intelrealsense = ["pyrealsense2"]
stretch = ["hello-robot-stretch-body", "pyrender", "pyrealsense2", "pynput"]
hilserl = ["transformers", "torchmetrics"]

[tool.ruff]
line-length = 110

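The new `hilserl` extra makes `transformers` and `torchmetrics` optional dependencies. A minimal sketch of the guarded-import pattern such optional extras typically imply; the actual lerobot code may handle missing extras differently:

```python
# Sketch only: fail with a helpful message when the optional "hilserl"
# extra is not installed. The error text here is illustrative.
try:
    import torchmetrics  # provided by the optional "hilserl" extra
    import transformers
except ImportError as e:
    raise ImportError(
        "The hilserl policy requires the optional 'hilserl' extra "
        "(transformers, torchmetrics)."
    ) from e
```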
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import traceback

import pytest
import torch
from serial import SerialException

from lerobot import available_cameras, available_motors, available_robots

@@ -124,3 +126,14 @@ def patch_builtins_input(monkeypatch):
        print(text)

    monkeypatch.setattr("builtins.input", print_text)


def pytest_addoption(parser):
    parser.addoption("--seed", action="store", default="42", help="Set random seed for reproducibility")


@pytest.fixture(autouse=True)
def set_random_seed(request):
    seed = int(request.config.getoption("--seed"))
    random.seed(seed)  # Python random
    torch.manual_seed(seed)  # PyTorch

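Note that the fixture above seeds Python's `random` module and PyTorch, but not NumPy. A minimal sketch of an extended fixture, assuming tests also draw from NumPy's global RNG (`numpy` is an assumption here, not something this diff adds):

```python
import random

import numpy as np
import pytest
import torch


@pytest.fixture(autouse=True)
def set_random_seed(request):
    seed = int(request.config.getoption("--seed"))
    random.seed(seed)  # Python random
    np.random.seed(seed)  # NumPy (assumption: tests use np.random)
    torch.manual_seed(seed)  # PyTorch
```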
tests/lerobot/common/datasets/test_utils.py (new file, 38 lines)
@@ -0,0 +1,38 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from huggingface_hub import DatasetCard

from lerobot.common.datasets.utils import create_lerobot_dataset_card


def test_default_parameters():
    card = create_lerobot_dataset_card()
    assert isinstance(card, DatasetCard)
    assert card.data.tags == ["LeRobot"]
    assert card.data.task_categories == ["robotics"]
    assert card.data.configs == [
        {
            "config_name": "default",
            "data_files": "data/*/*.parquet",
        }
    ]


def test_with_tags():
    tags = ["tag1", "tag2"]
    card = create_lerobot_dataset_card(tags=tags)
    assert card.data.tags == ["LeRobot", "tag1", "tag2"]
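For reference, a minimal usage sketch of the helper exercised by these tests. Only the `tags` parameter and the prepended "LeRobot" tag are confirmed by the tests above; anything beyond that would be an assumption:

```python
from lerobot.common.datasets.utils import create_lerobot_dataset_card

# Extra tags are appended after the default "LeRobot" tag.
card = create_lerobot_dataset_card(tags=["simulation"])
print(card.data.tags)  # ["LeRobot", "simulation"]
```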
@@ -0,0 +1,244 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchmetrics import AUROC, Accuracy, F1Score, Precision, Recall
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier, ClassifierConfig

BATCH_SIZE = 1000
LR = 0.1
EPOCH_NUM = 2

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")


def train_evaluate_multiclass_classifier():
    logging.info(
        f"Start multiclass classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )
    multiclass_config = ClassifierConfig(model_name="microsoft/resnet-18", device=DEVICE, num_classes=10)
    multiclass_classifier = Classifier(multiclass_config)

    trainset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    testset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
    testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

    multiclass_num_classes = 10
    epoch = 1

    criterion = CrossEntropyLoss()
    optimizer = Adam(multiclass_classifier.parameters(), lr=LR)

    multiclass_classifier.train()

    logging.info("Start multiclass classifier training")

    # Training loop
    while epoch < EPOCH_NUM:  # loop over the dataset multiple times
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = multiclass_classifier(inputs)

            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            if i % 10 == 0:  # log every 10 mini-batches
                logging.info(f"[Epoch {epoch}, Batch {i}] loss: {loss.item():.3f}")

        epoch += 1

    logging.info("Multiclass classifier training finished")

    multiclass_classifier.eval()

    test_loss = 0.0
    test_labels = []
    test_predictions = []
    test_probs = []

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = multiclass_classifier(images)
            loss = criterion(outputs.logits, labels)
            test_loss += loss.item() * BATCH_SIZE

            _, predicted = torch.max(outputs.logits, 1)
            test_labels.extend(labels.cpu())
            test_predictions.extend(predicted.cpu())
            test_probs.extend(outputs.probabilities.cpu())

    test_loss = test_loss / len(testset)

    logging.info(f"Multiclass classifier test loss {test_loss:.3f}")

    test_labels = torch.stack(test_labels)
    test_predictions = torch.stack(test_predictions)
    test_probs = torch.stack(test_probs)

    accuracy = Accuracy(task="multiclass", num_classes=multiclass_num_classes)
    precision = Precision(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    recall = Recall(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    f1 = F1Score(task="multiclass", average="weighted", num_classes=multiclass_num_classes)
    auroc = AUROC(task="multiclass", num_classes=multiclass_num_classes, average="weighted")

    # Calculate metrics
    acc = accuracy(test_predictions, test_labels)
    prec = precision(test_predictions, test_labels)
    rec = recall(test_predictions, test_labels)
    f1_score = f1(test_predictions, test_labels)
    auroc_score = auroc(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")


def train_evaluate_binary_classifier():
    logging.info(
        f"Start binary classifier train eval with {DEVICE} device, batch size {BATCH_SIZE}, learning rate {LR}"
    )

    target_binary_class = 3

    def one_vs_rest(dataset, target_class):
        new_targets = []
        for _, label in dataset:
            new_label = 1.0 if label == target_class else 0.0
            new_targets.append(new_label)

        dataset.targets = new_targets  # Replace the original labels with the binary ones
        return dataset

    binary_train_dataset = CIFAR10(root="data", train=True, download=True, transform=ToTensor())
    binary_test_dataset = CIFAR10(root="data", train=False, download=True, transform=ToTensor())

    # Apply one-vs-rest labeling
    binary_train_dataset = one_vs_rest(binary_train_dataset, target_binary_class)
    binary_test_dataset = one_vs_rest(binary_test_dataset, target_binary_class)

    binary_trainloader = DataLoader(binary_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    binary_testloader = DataLoader(binary_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    binary_epoch = 1

    binary_config = ClassifierConfig(model_name="microsoft/resnet-50", device=DEVICE)
    binary_classifier = Classifier(binary_config)

    class_counts = np.bincount(binary_train_dataset.targets)
    n = len(binary_train_dataset)
    w0 = n / (2.0 * class_counts[0])
    w1 = n / (2.0 * class_counts[1])

    binary_criterion = BCEWithLogitsLoss(pos_weight=torch.tensor(w1 / w0))
    binary_optimizer = Adam(binary_classifier.parameters(), lr=LR)

    binary_classifier.train()

    logging.info("Start binary classifier training")

    # Training loop
    while binary_epoch < EPOCH_NUM:  # loop over the dataset multiple times
        for i, data in enumerate(binary_trainloader):
            inputs, labels = data
            inputs, labels = inputs.to(DEVICE), labels.to(torch.float32).to(DEVICE)

            # Zero the parameter gradients
            binary_optimizer.zero_grad()

            # Forward pass
            outputs = binary_classifier(inputs)
            loss = binary_criterion(outputs.logits, labels)
            loss.backward()
            binary_optimizer.step()

            if i % 10 == 0:  # log every 10 mini-batches
                logging.info(f"[Epoch {binary_epoch}, Batch {i}] loss: {loss.item():.3f}")
        binary_epoch += 1

    logging.info("Binary classifier training finished")
    logging.info("Start binary classifier evaluation")

    binary_classifier.eval()

    test_loss = 0.0
    test_labels = []
    test_predictions = []
    test_probs = []

    with torch.no_grad():
        for data in binary_testloader:
            images, labels = data
            images, labels = images.to(DEVICE), labels.to(torch.float32).to(DEVICE)
            outputs = binary_classifier(images)
            loss = binary_criterion(outputs.logits, labels)
            test_loss += loss.item() * BATCH_SIZE

            test_labels.extend(labels.cpu())
            test_predictions.extend(outputs.logits.cpu())
            test_probs.extend(outputs.probabilities.cpu())

    test_loss = test_loss / len(binary_test_dataset)

    logging.info(f"Binary classifier test loss {test_loss:.3f}")

    test_labels = torch.stack(test_labels)
    test_predictions = torch.stack(test_predictions)
    test_probs = torch.stack(test_probs)

    # Calculate metrics
    acc = Accuracy(task="binary")(test_predictions, test_labels)
    prec = Precision(task="binary", average="weighted")(test_predictions, test_labels)
    rec = Recall(task="binary", average="weighted")(test_predictions, test_labels)
    f1_score = F1Score(task="binary", average="weighted")(test_predictions, test_labels)
    auroc_score = AUROC(task="binary", average="weighted")(test_probs, test_labels)

    logging.info(f"Accuracy: {acc:.2f}")
    logging.info(f"Precision: {prec:.2f}")
    logging.info(f"Recall: {rec:.2f}")
    logging.info(f"F1 Score: {f1_score:.2f}")
    logging.info(f"AUROC Score: {auroc_score:.2f}")


if __name__ == "__main__":
    train_evaluate_multiclass_classifier()
    train_evaluate_binary_classifier()
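The `pos_weight` passed to `BCEWithLogitsLoss` above reduces to the ratio of negative to positive samples. A worked example with CIFAR-10, which has 5,000 training images per class, so the one-vs-rest split yields 5,000 positives and 45,000 negatives:

```python
n, neg, pos = 50_000, 45_000, 5_000  # one-vs-rest split of the CIFAR-10 train set
w0 = n / (2.0 * neg)  # ~0.556, balanced weight for the negative class
w1 = n / (2.0 * pos)  # 5.0, balanced weight for the positive class
print(w1 / w0)        # 9.0 == neg / pos, the value passed as pos_weight
```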
@@ -0,0 +1,85 @@
import torch

from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
    ClassifierConfig,
    ClassifierOutput,
)
from tests.utils import require_package


def test_classifier_output():
    output = ClassifierOutput(
        logits=torch.tensor([1, 2, 3]), probabilities=torch.tensor([0.1, 0.2, 0.3]), hidden_states=None
    )

    assert (
        f"{output}"
        == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)"
    )


@require_package("transformers")
def test_binary_classifier_with_default_params():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig()
    classifier = Classifier(config)

    batch_size = 10

    input = torch.rand(batch_size, 3, 224, 224)
    output = classifier(input)

    assert output is not None
    assert output.logits.shape == torch.Size([batch_size])
    assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
    assert output.probabilities.shape == torch.Size([batch_size])
    assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
    assert output.hidden_states.shape == torch.Size([batch_size, 2048])
    assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"


@require_package("transformers")
def test_multiclass_classifier():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    num_classes = 5
    config = ClassifierConfig(num_classes=num_classes)
    classifier = Classifier(config)

    batch_size = 10

    input = torch.rand(batch_size, 3, 224, 224)
    output = classifier(input)

    assert output is not None
    assert output.logits.shape == torch.Size([batch_size, num_classes])
    assert not torch.isnan(output.logits).any(), "Tensor contains NaN values"
    assert output.probabilities.shape == torch.Size([batch_size, num_classes])
    assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values"
    assert output.hidden_states.shape == torch.Size([batch_size, 2048])
    assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values"


@require_package("transformers")
def test_default_device():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig()
    assert config.device == "cpu"

    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("cpu")


@require_package("transformers")
def test_explicit_device_setup():
    from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier

    config = ClassifierConfig(device="meta")
    assert config.device == "meta"

    classifier = Classifier(config)
    for p in classifier.parameters():
        assert p.device == torch.device("meta")
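`require_package` comes from the repository's `tests.utils`. A minimal sketch of the skip-if-missing pattern it presumably implements, so the `transformers`-dependent tests above skip rather than fail when the extra is not installed (the actual helper in the repo may differ):

```python
import importlib.util
from functools import wraps

import pytest


def require_package(package_name):
    """Skip the decorated test when `package_name` is not importable."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if importlib.util.find_spec(package_name) is None:
                pytest.skip(f"{package_name} not installed")
            return func(*args, **kwargs)
        return wrapper
    return decorator
```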
tests/test_train_hilserl_classifier.py (new file, 304 lines)
@@ -0,0 +1,304 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest
import torch
from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset

from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
    create_balanced_sampler,
    train,
    train_epoch,
    validate,
)


class MockDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.meta = MagicMock()
        self.meta.stats = {}

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)


def make_dummy_model():
    model_config = ClassifierConfig(
        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=1
    )
    model = Classifier(config=model_config)
    return model


def test_create_balanced_sampler():
    # Mock dataset with imbalanced classes
    data = [
        {"label": 0},
        {"label": 0},
        {"label": 1},
        {"label": 0},
        {"label": 1},
        {"label": 1},
        {"label": 1},
        {"label": 1},
    ]
    dataset = MockDataset(data)
    cfg = MagicMock()
    cfg.training.label_key = "label"

    sampler = create_balanced_sampler(dataset, cfg)

    # Get weights from the sampler
    weights = sampler.weights.float()

    # Check that samples have appropriate weights
    labels = [item["label"] for item in data]
    class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
    class_weights = 1.0 / class_counts
    expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)

    # Test that the weights are correct
    assert torch.allclose(weights, expected_weights)


def test_train_epoch():
    model = make_dummy_model()
    # Mock components
    model.train = MagicMock()

    train_loader = [
        {
            "image": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]

    criterion = nn.BCEWithLogitsLoss()
    optimizer = MagicMock()
    grad_scaler = MagicMock()
    device = torch.device("cpu")
    logger = MagicMock()
    step = 0
    cfg = MagicMock()
    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call the function under test
    train_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        grad_scaler,
        device,
        logger,
        step,
        cfg,
    )

    # Check that model.train() was called
    model.train.assert_called_once()

    # Check that optimizer.zero_grad() was called
    optimizer.zero_grad.assert_called()

    # Check that logger.log_dict was called
    logger.log_dict.assert_called()


def test_validate():
    model = make_dummy_model()

    # Mock components
    model.eval = MagicMock()
    val_loader = [
        {
            "image": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]
    criterion = nn.BCEWithLogitsLoss()
    device = torch.device("cpu")
    logger = MagicMock()
    cfg = MagicMock()
    cfg.training.image_keys = ["image"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call validate
    accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)

    # Check that model.eval() was called
    model.eval.assert_called_once()

    # Check accuracy/eval_info are calculated and of the correct type
    assert isinstance(accuracy, float)
    assert isinstance(eval_info, dict)


def test_train_epoch_multiple_cameras():
    model_config = ClassifierConfig(
        num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel", num_cameras=2
    )
    model = Classifier(config=model_config)

    # Mock components
    model.train = MagicMock()

    train_loader = [
        {
            "image_1": torch.rand(2, 3, 224, 224),
            "image_2": torch.rand(2, 3, 224, 224),
            "label": torch.tensor([0.0, 1.0]),
        }
    ]

    criterion = nn.BCEWithLogitsLoss()
    optimizer = MagicMock()
    grad_scaler = MagicMock()
    device = torch.device("cpu")
    logger = MagicMock()
    step = 0
    cfg = MagicMock()
    cfg.training.image_keys = ["image_1", "image_2"]
    cfg.training.label_key = "label"
    cfg.training.use_amp = False

    # Call the function under test
    train_epoch(
        model,
        train_loader,
        criterion,
        optimizer,
        grad_scaler,
        device,
        logger,
        step,
        cfg,
    )

    # Check that model.train() was called
    model.train.assert_called_once()

    # Check that optimizer.zero_grad() was called
    optimizer.zero_grad.assert_called()

    # Check that logger.log_dict was called
    logger.log_dict.assert_called()


@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.get_model")
def test_resume_function(
    mock_get_model,
    mock_dataset,
    mock_logger,
    mock_get_last_pretrained_model_dir,
    mock_get_last_checkpoint_dir,
    mock_init_hydra_config,
    resume,
):
    # Initialize Hydra
    test_file_dir = os.path.dirname(os.path.abspath(__file__))
    config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
    assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"

    with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
        cfg = compose(
            config_name="hilserl_classifier",
            overrides=[
                "device=cpu",
                "seed=42",
                f"output_dir={tempfile.mkdtemp()}",
                "wandb.enable=False",
                f"resume={resume}",
                "dataset_repo_id=dataset_repo_id",
                "train_split_proportion=0.8",
                "training.num_workers=0",
                "training.batch_size=2",
                "training.image_keys=[image]",
                "training.label_key=label",
                "training.use_amp=False",
                "training.num_epochs=1",
                "eval.batch_size=2",
            ],
        )

    # Mock the init_hydra_config function to return cfg
    mock_init_hydra_config.return_value = cfg

    # Mock dataset
    dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
    mock_dataset.return_value = dataset

    # Mock checkpoint handling
    mock_checkpoint_dir = MagicMock(spec=Path)
    mock_checkpoint_dir.exists.return_value = resume  # Only exists if resuming
    mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
    mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())

    # Mock logger
    logger = MagicMock()
    resumed_step = 1000
    if resume:
        logger.load_last_training_state.return_value = resumed_step
    else:
        logger.load_last_training_state.return_value = 0
    mock_logger.return_value = logger

    # Instantiate the model and have the mocked get_model return it
    model = make_dummy_model()
    mock_get_model.return_value = model

    # Call train
    train(cfg)

    # Check that checkpoint handling methods were called
    if resume:
        mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
        mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
        mock_checkpoint_dir.exists.assert_called_once()
        logger.load_last_training_state.assert_called_once()
    else:
        mock_get_last_checkpoint_dir.assert_not_called()
        mock_get_last_pretrained_model_dir.assert_not_called()
        mock_checkpoint_dir.exists.assert_not_called()
        logger.load_last_training_state.assert_not_called()

    # Collect the steps from logger.log_dict calls
    train_log_calls = logger.log_dict.call_args_list

    # Extract the steps used in the train logging
    steps = []
    for call in train_log_calls:
        mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
        if mode == "train":
            step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
            steps.append(step)

    expected_start_step = resumed_step if resume else 0

    # Calculate expected_steps
    train_size = int(cfg.train_split_proportion * len(dataset))
    batch_size = cfg.training.batch_size
    num_batches = (train_size + batch_size - 1) // batch_size

    expected_steps = [expected_start_step + i for i in range(num_batches)]

    assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
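A worked example of the expected-steps arithmetic above, using the values from the Hydra overrides in this test (10 mock samples, `train_split_proportion=0.8`, `training.batch_size=2`, one epoch):

```python
train_size = int(0.8 * 10)       # 8 samples land in the train split
num_batches = (8 + 2 - 1) // 2   # 4 batches per epoch (ceiling division)
assert [0 + i for i in range(num_batches)] == [0, 1, 2, 3]  # fresh run
assert [1000 + i for i in range(num_batches)] == [1000, 1001, 1002, 1003]  # resumed
```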