Compare commits
13 Commits
user/rcade
...
tdmpc23
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
14490148f3 | ||
|
|
16edbbdeee | ||
|
|
15090c2544 | ||
|
|
166c1fc776 | ||
|
|
31984645da | ||
|
|
c41ec08ec1 | ||
|
|
a146544765 | ||
|
|
963738d983 | ||
|
|
e0df56de62 | ||
|
|
538455a965 | ||
|
|
172809a502 | ||
|
|
55e4ff6742 | ||
|
|
07e8716315 |
18
README.md
@@ -23,15 +23,15 @@
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">New robot in town: SO-100</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
<p>We just added a new tutorial on how to build a more affordable robot, at the price of $110 per arm!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
|
||||
<p>Follow the link to the <a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">full tutorial for SO-100</a>.</p>
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
@@ -55,9 +55,9 @@
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
<td><img src="media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
@@ -144,7 +144,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
```bash
|
||||
@@ -280,7 +280,7 @@ To use wandb for logging training and evaluation curves, make sure you've run `w
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explaination of some commonly used metrics in logs.
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's install LeRobot. We will next provide a tutorial for assembly.
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
@@ -45,12 +45,46 @@ conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Run this script two times to find the ports (e.g. "/dev/tty.usbmodem58760432961") of your motor buses:
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Then plug your first motor, corresponding to "shoulder_pan" and run this script to set its ID to 1 and set its present position and offset to ~2048 (useful for calibration).
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
@@ -60,7 +94,9 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your motor and plug the second motor, corresponding to "shoulder lift", and set its ID to 2.
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
@@ -70,23 +106,57 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until the gripper with ID 6. Do the same for the motors of the leader arm, starting for ID 1 up to 6.
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
TODO
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras'
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
Without displaying the cameras:
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
@@ -94,7 +164,9 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
With displaying the cameras:
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
@@ -102,7 +174,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with so100.
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -203,6 +275,6 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#so100-arm`.
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
|
||||
280
examples/11_use_moss.md
Normal file
@@ -0,0 +1,280 @@
|
||||
This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-robot-arms) with LeRobot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). For Moss v1, you need to align the holes on the motor horn to the motor spline to be approximately 3, 6, 9 and 12 o'clock.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Moss v1 robot to work on another.
|
||||
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Moss v1.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--tags moss tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/moss_test \
|
||||
policy=act_moss_real \
|
||||
env=moss_real \
|
||||
hydra.run.dir=outputs/train/act_moss_test \
|
||||
hydra.job.name=act_moss_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy=act_moss_real`. This loads configurations from [`lerobot/configs/policy/act_moss_real.yaml`](../lerobot/configs/policy/act_moss_real.yaml). Importantly, this policy uses 2 cameras as input `laptop`, `phone`.
|
||||
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_moss_test \
|
||||
--tags moss tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_moss_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_moss_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_moss_test`).
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#moss-arm`](https://discord.com/channels/1216765309076115607/1275374638985252925).
|
||||
@@ -181,8 +181,8 @@ available_real_world_datasets = [
|
||||
"lerobot/usc_cloth_sim",
|
||||
]
|
||||
|
||||
available_datasets = sorted(
|
||||
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
|
||||
available_datasets = list(
|
||||
itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/common/policies`
|
||||
|
||||
@@ -91,9 +91,9 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
)
|
||||
|
||||
if isinstance(cfg.dataset_repo_id, str):
|
||||
# TODO (aliberts): add 'episodes' arg from config after removing hydra
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
@@ -101,6 +101,7 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
else:
|
||||
dataset = MultiLeRobotDataset(
|
||||
cfg.dataset_repo_id,
|
||||
split=split,
|
||||
delta_timestamps=cfg.training.get("delta_timestamps"),
|
||||
image_transforms=image_transforms,
|
||||
video_backend=cfg.video_backend,
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
dataset = kwargs.get("dataset", None)
|
||||
image_writer = getattr(dataset, "image_writer", None) if dataset else None
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
image_writer.stop()
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ImageWriter:
|
||||
"""This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it creates a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it creates processes pool of size `num_processes`, where each subprocess starts
|
||||
their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
self.dir = write_dir
|
||||
self.dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = self.num_threads_per_process = num_threads
|
||||
|
||||
if self.num_processes <= 0:
|
||||
self.type = "threads"
|
||||
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
|
||||
self.futures = []
|
||||
else:
|
||||
self.type = "processes"
|
||||
self.num_threads_per_process = self.num_threads
|
||||
self.image_queue = multiprocessing.Queue()
|
||||
self.processes: list[multiprocessing.Process] = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
|
||||
def _loop_to_save_images_in_threads(self) -> None:
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
frame_data = self.image_queue.get()
|
||||
if frame_data is None:
|
||||
break
|
||||
|
||||
image, file_path = frame_data
|
||||
futures.append(executor.submit(self._save_image, image, file_path))
|
||||
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
"""Save an image asynchronously using threads or processes."""
|
||||
if self.type == "threads":
|
||||
self.futures.append(self.threads.submit(self._save_image, image, file_path))
|
||||
else:
|
||||
self.image_queue.put((image, file_path))
|
||||
|
||||
def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
img = Image.fromarray(image.numpy())
|
||||
img.save(str(file_path), quality=100)
|
||||
|
||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = self.image_path.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.dir / fpath
|
||||
|
||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self.get_image_file_path(
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def stop(self, timeout=20) -> None:
|
||||
"""Stop the image writer, waiting for all processes or threads to finish."""
|
||||
if self.type == "threads":
|
||||
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
|
||||
wait(self.futures, timeout=timeout)
|
||||
progress_bar.update(len(self.futures))
|
||||
else:
|
||||
self._stop_processes(timeout)
|
||||
|
||||
def _stop_processes(self, timeout) -> None:
|
||||
for _ in self.processes:
|
||||
self.image_queue.put(None)
|
||||
|
||||
for process in self.processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
self.image_queue.close()
|
||||
self.image_queue.join_thread()
|
||||
@@ -13,312 +13,63 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import datasets
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
import torch.utils
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import snapshot_download, upload_folder
|
||||
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_stats
|
||||
from lerobot.common.datasets.image_writer import ImageWriter
|
||||
from lerobot.common.datasets.compute_stats import aggregate_stats
|
||||
from lerobot.common.datasets.utils import (
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
_get_info_from_robot,
|
||||
append_jsonl,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
check_version_compatibility,
|
||||
create_branch,
|
||||
create_empty_dataset_info,
|
||||
get_delta_indices,
|
||||
get_episode_data_index,
|
||||
get_hub_safe_version,
|
||||
hf_transform_to_torch,
|
||||
load_episode_dicts,
|
||||
calculate_episode_data_index,
|
||||
load_episode_data_index,
|
||||
load_hf_dataset,
|
||||
load_info,
|
||||
load_previous_and_future_frames,
|
||||
load_stats,
|
||||
load_tasks,
|
||||
write_json,
|
||||
write_stats,
|
||||
load_videos,
|
||||
reset_episode_index,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame,
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, load_from_videos
|
||||
|
||||
# For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md
|
||||
CODEBASE_VERSION = "v2.0"
|
||||
LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
CODEBASE_VERSION = "v1.6"
|
||||
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
image_writer: ImageWriter | None = None,
|
||||
):
|
||||
"""LeRobotDataset encapsulates 3 main things:
|
||||
- metadata:
|
||||
- info contains various information about the dataset like shapes, keys, fps etc.
|
||||
- stats stores the dataset statistics of the different modalities for normalization
|
||||
- tasks contains the prompts for each task of the dataset, which can be used for
|
||||
task-conditionned training.
|
||||
- hf_dataset (from datasets.Dataset), which will read any values from parquet files.
|
||||
- (optional) videos from which frames are loaded to be synchronous with data from parquet files.
|
||||
|
||||
3 modes are available for this class, depending on 3 different use cases:
|
||||
|
||||
1. Your dataset already exists on the Hugging Face Hub at the address
|
||||
https://huggingface.co/datasets/{repo_id} and is not on your local disk in the 'root' folder:
|
||||
Instantiating this class with this 'repo_id' will download the dataset from that address and load
|
||||
it, pending your dataset is compliant with codebase_version v2.0. If your dataset has been created
|
||||
before this new format, you will be prompted to convert it using our conversion script from v1.6
|
||||
to v2.0, which you can find at lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py.
|
||||
|
||||
2. Your dataset already exists on your local disk in the 'root' folder:
|
||||
This is typically the case when you recorded your dataset locally and you may or may not have
|
||||
pushed it to the hub yet. Instantiating this class with 'root' will load your dataset directly
|
||||
from disk. This can happen while you're offline (no internet connection).
|
||||
|
||||
3. Your dataset doesn't already exists (either on local disk or on the Hub):
|
||||
[TODO(aliberts): add classmethod for this case?]
|
||||
|
||||
|
||||
In terms of files, a typical LeRobotDataset looks like this from its root path:
|
||||
.
|
||||
├── data
|
||||
│ ├── chunk-000
|
||||
│ │ ├── episode_000000.parquet
|
||||
│ │ ├── episode_000001.parquet
|
||||
│ │ ├── episode_000002.parquet
|
||||
│ │ └── ...
|
||||
│ ├── chunk-001
|
||||
│ │ ├── episode_001000.parquet
|
||||
│ │ ├── episode_001001.parquet
|
||||
│ │ ├── episode_001002.parquet
|
||||
│ │ └── ...
|
||||
│ └── ...
|
||||
├── meta
|
||||
│ ├── episodes.jsonl
|
||||
│ ├── info.json
|
||||
│ ├── stats.json
|
||||
│ └── tasks.jsonl
|
||||
└── videos (optional)
|
||||
├── chunk-000
|
||||
│ ├── observation.images.laptop
|
||||
│ │ ├── episode_000000.mp4
|
||||
│ │ ├── episode_000001.mp4
|
||||
│ │ ├── episode_000002.mp4
|
||||
│ │ └── ...
|
||||
│ ├── observation.images.phone
|
||||
│ │ ├── episode_000000.mp4
|
||||
│ │ ├── episode_000001.mp4
|
||||
│ │ ├── episode_000002.mp4
|
||||
│ │ └── ...
|
||||
├── chunk-001
|
||||
└── ...
|
||||
|
||||
Note that this file-based structure is designed to be as versatile as possible. The files are split by
|
||||
episodes which allows a more granular control over which episodes one wants to use and download. The
|
||||
structure of the dataset is entirely described in the info.json file, which can be easily downloaded
|
||||
or viewed directly on the hub before downloading any actual data. The type of files used are very
|
||||
simple and do not need complex tools to be read, it only uses .parquet, .json and .mp4 files (and .md
|
||||
for the README).
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
torchvision.transforms.v2 here which will be applied to visual modalities (whether they come
|
||||
from videos or images). Defaults to None.
|
||||
delta_timestamps (dict[list[float]] | None, optional): _description_. Defaults to None.
|
||||
tolerance_s (float, optional): Tolerance in seconds used to ensure data timestamps are actually in
|
||||
sync with the fps value. It is used at the init of the dataset to make sure that each
|
||||
timestamps is separated to the next by 1/fps +/- tolerance_s. This also applies to frames
|
||||
decoded from video files. It is also used to check that `delta_timestamps` (when provided) are
|
||||
multiples of 1/fps. Defaults to 1e-4.
|
||||
download_videos (bool, optional): Flag to download the videos. Note that when set to True but the
|
||||
video files are already present on local disk, they won't be downloaded again. Defaults to
|
||||
True.
|
||||
video_backend (str | None, optional): Video backend to use for decoding videos. There is currently
|
||||
a single option which is the pyav decoder used by Torchvision. Defaults to pyav.
|
||||
"""
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
self.delta_indices = None
|
||||
self.local_files_only = local_files_only
|
||||
self.consolidated = True
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
self.episode_buffer = {}
|
||||
|
||||
# Load metadata
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self.pull_from_repo(allow_patterns="meta/")
|
||||
self.info = load_info(self.root)
|
||||
self.stats = load_stats(self.root)
|
||||
self.tasks = load_tasks(self.root)
|
||||
self.episode_dicts = load_episode_dicts(self.root)
|
||||
|
||||
# Check version
|
||||
check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION)
|
||||
|
||||
# Load actual data
|
||||
self.download_episodes(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
|
||||
# Check timestamps
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [X] Move delta_timestamp logic outside __get_item__
|
||||
# - [X] Update __get_item__
|
||||
# - [/] Add doc
|
||||
# - [ ] Add self.add_frame()
|
||||
# - [ ] Add self.consolidate() for:
|
||||
# - [X] Check timestamps sync
|
||||
# - [ ] Sanity checks (episodes num, shapes, files, etc.)
|
||||
# - [ ] Update episode_index (arg update=True)
|
||||
# - [ ] Update info.json (arg update=True)
|
||||
|
||||
@cached_property
|
||||
def _hub_version(self) -> str | None:
|
||||
return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION)
|
||||
|
||||
@property
|
||||
def _version(self) -> str:
|
||||
"""Codebase version used to create this dataset."""
|
||||
return self.info["codebase_version"]
|
||||
|
||||
def push_to_hub(self, push_videos: bool = True) -> None:
|
||||
if not self.consolidated:
|
||||
raise RuntimeError(
|
||||
"You are trying to upload to the hub a LeRobotDataset that has not been consolidated yet."
|
||||
"Please call the dataset 'consolidate()' method first."
|
||||
)
|
||||
ignore_patterns = ["images/"]
|
||||
if not push_videos:
|
||||
ignore_patterns.append("videos/")
|
||||
|
||||
upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.root,
|
||||
repo_type="dataset",
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset")
|
||||
|
||||
def pull_from_repo(
|
||||
self,
|
||||
allow_patterns: list[str] | str | None = None,
|
||||
ignore_patterns: list[str] | str | None = None,
|
||||
) -> None:
|
||||
snapshot_download(
|
||||
self.repo_id,
|
||||
repo_type="dataset",
|
||||
revision=self._hub_version,
|
||||
local_dir=self.root,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
local_files_only=self.local_files_only,
|
||||
)
|
||||
|
||||
def download_episodes(self, download_videos: bool = True) -> None:
|
||||
"""Downloads the dataset from the given 'repo_id' at the provided version. If 'episodes' is given, this
|
||||
will only download those episodes (selected by their episode_index). If 'episodes' is None, the whole
|
||||
dataset will be downloaded. Thanks to the behavior of snapshot_download, if the files are already present
|
||||
in 'local_dir', they won't be downloaded again.
|
||||
"""
|
||||
# load data from hub or locally when root is provided
|
||||
# TODO(rcadene, aliberts): implement faster transfer
|
||||
# https://huggingface.co/docs/huggingface_hub/en/guides/download#faster-downloads
|
||||
files = None
|
||||
ignore_patterns = None if download_videos else "videos/"
|
||||
if self.episodes is not None:
|
||||
files = [str(self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
if len(self.video_keys) > 0 and download_videos:
|
||||
video_files = [
|
||||
str(self.get_video_file_path(ep_idx, vid_key))
|
||||
for vid_key in self.video_keys
|
||||
for ep_idx in self.episodes
|
||||
]
|
||||
files += video_files
|
||||
|
||||
self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)
|
||||
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if self.episodes is None:
|
||||
path = str(self.root / "data")
|
||||
hf_dataset = load_dataset("parquet", data_dir=path, split="train")
|
||||
self.hf_dataset = load_hf_dataset(repo_id, CODEBASE_VERSION, root, split)
|
||||
if split == "train":
|
||||
self.episode_data_index = load_episode_data_index(repo_id, CODEBASE_VERSION, root)
|
||||
else:
|
||||
files = [str(self.root / self.get_data_file_path(ep_idx)) for ep_idx in self.episodes]
|
||||
hf_dataset = load_dataset("parquet", data_files=files, split="train")
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
def get_data_file_path(self, ep_index: int) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.data_path.format(episode_chunk=ep_chunk, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_video_file_path(self, ep_index: int, vid_key: str) -> Path:
|
||||
ep_chunk = self.get_episode_chunk(ep_index)
|
||||
fpath = self.videos_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index)
|
||||
return Path(fpath)
|
||||
|
||||
def get_episode_chunk(self, ep_index: int) -> int:
|
||||
return ep_index // self.chunks_size
|
||||
|
||||
@property
|
||||
def data_path(self) -> str:
|
||||
"""Formattable string for the parquet files."""
|
||||
return self.info["data_path"]
|
||||
|
||||
@property
|
||||
def videos_path(self) -> str | None:
|
||||
"""Formattable string for the video files."""
|
||||
return self.info["videos"]["videos_path"] if len(self.video_keys) > 0 else None
|
||||
self.episode_data_index = calculate_episode_data_index(self.hf_dataset)
|
||||
self.hf_dataset = reset_episode_index(self.hf_dataset)
|
||||
self.stats = load_stats(repo_id, CODEBASE_VERSION, root)
|
||||
self.info = load_info(repo_id, CODEBASE_VERSION, root)
|
||||
if self.video:
|
||||
self.videos_dir = load_videos(repo_id, CODEBASE_VERSION, root)
|
||||
self.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
|
||||
@property
|
||||
def fps(self) -> int:
|
||||
@@ -326,495 +77,140 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return self.info["fps"]
|
||||
|
||||
@property
|
||||
def keys(self) -> list[str]:
|
||||
"""Keys to access non-image data (state, actions etc.)."""
|
||||
return self.info["keys"]
|
||||
def video(self) -> bool:
|
||||
"""Returns True if this dataset loads video frames from mp4 files.
|
||||
Returns False if it only loads images from png files.
|
||||
"""
|
||||
return self.info.get("video", False)
|
||||
|
||||
@property
|
||||
def image_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as images."""
|
||||
return self.info["image_keys"]
|
||||
|
||||
@property
|
||||
def video_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return self.info["video_keys"]
|
||||
def features(self) -> datasets.Features:
|
||||
return self.hf_dataset.features
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
return self.image_keys + self.video_keys
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.hf_dataset.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
@property
|
||||
def names(self) -> dict[list[str]]:
|
||||
"""Names of the various dimensions of vector modalities."""
|
||||
return self.info["names"]
|
||||
def video_frame_keys(self) -> list[str]:
|
||||
"""Keys to access video frames that requires to be decoded into images.
|
||||
|
||||
Note: It is empty if the dataset contains images only,
|
||||
or equal to `self.cameras` if the dataset contains videos only,
|
||||
or can even be a subset of `self.cameras` in a case of a mixed image/video dataset.
|
||||
"""
|
||||
video_frame_keys = []
|
||||
for key, feats in self.hf_dataset.features.items():
|
||||
if isinstance(feats, VideoFrame):
|
||||
video_frame_keys.append(key)
|
||||
return video_frame_keys
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
"""Number of samples/frames in selected episodes."""
|
||||
"""Number of samples/frames."""
|
||||
return len(self.hf_dataset)
|
||||
|
||||
@property
|
||||
def num_episodes(self) -> int:
|
||||
"""Number of episodes selected."""
|
||||
return len(self.episodes) if self.episodes is not None else self.total_episodes
|
||||
"""Number of episodes."""
|
||||
return len(self.hf_dataset.unique("episode_index"))
|
||||
|
||||
@property
|
||||
def total_episodes(self) -> int:
|
||||
"""Total number of episodes available."""
|
||||
return self.info["total_episodes"]
|
||||
|
||||
@property
|
||||
def total_frames(self) -> int:
|
||||
"""Total number of frames saved in this dataset."""
|
||||
return self.info["total_frames"]
|
||||
|
||||
@property
|
||||
def total_tasks(self) -> int:
|
||||
"""Total number of different tasks performed in this dataset."""
|
||||
return self.info["total_tasks"]
|
||||
|
||||
@property
|
||||
def total_chunks(self) -> int:
|
||||
"""Total number of chunks (groups of episodes)."""
|
||||
return self.info["total_chunks"]
|
||||
|
||||
@property
|
||||
def chunks_size(self) -> int:
|
||||
"""Max number of episodes per chunk."""
|
||||
return self.info["chunks_size"]
|
||||
|
||||
@property
|
||||
def shapes(self) -> dict:
|
||||
"""Shapes for the different features."""
|
||||
return self.info["shapes"]
|
||||
|
||||
@property
|
||||
def features(self) -> datasets.Features:
|
||||
"""Features of the hf_dataset."""
|
||||
if self.hf_dataset is not None:
|
||||
return self.hf_dataset.features
|
||||
elif self.episode_buffer is None:
|
||||
raise NotImplementedError(
|
||||
"Dataset features must be infered from an existing hf_dataset or episode_buffer."
|
||||
)
|
||||
|
||||
features = {}
|
||||
for key in self.episode_buffer:
|
||||
if key in ["episode_index", "frame_index", "index", "task_index"]:
|
||||
features[key] = datasets.Value(dtype="int64")
|
||||
elif key in ["next.done", "next.success"]:
|
||||
features[key] = datasets.Value(dtype="bool")
|
||||
elif key in ["timestamp", "next.reward"]:
|
||||
features[key] = datasets.Value(dtype="float32")
|
||||
elif key in self.image_keys:
|
||||
features[key] = datasets.Image()
|
||||
elif key in self.keys:
|
||||
features[key] = datasets.Sequence(
|
||||
length=self.shapes[key], feature=datasets.Value(dtype="float32")
|
||||
)
|
||||
|
||||
return datasets.Features(features)
|
||||
|
||||
@property
|
||||
def task_to_task_index(self) -> dict:
|
||||
return {task: task_idx for task_idx, task in self.tasks.items()}
|
||||
|
||||
def get_task_index(self, task: str) -> int:
|
||||
def tolerance_s(self) -> float:
|
||||
"""Tolerance in seconds used to discard loaded frames when their timestamps
|
||||
are not close enough from the requested frames. It is only used when `delta_timestamps`
|
||||
is provided or when loading video frames from mp4 files.
|
||||
"""
|
||||
Given a task in natural language, returns its task_index if the task already exists in the dataset,
|
||||
otherwise creates a new task_index.
|
||||
"""
|
||||
task_index = self.task_to_task_index.get(task, None)
|
||||
return task_index if task_index is not None else self.total_tasks
|
||||
|
||||
def current_episode_index(self, idx: int) -> int:
|
||||
episode_index = self.hf_dataset["episode_index"][idx]
|
||||
if self.episodes is not None:
|
||||
# get episode_index from selected episodes
|
||||
episode_index = self.episodes.index(episode_index)
|
||||
|
||||
return episode_index
|
||||
|
||||
def episode_length(self, episode_index) -> int:
|
||||
"""Number of samples/frames for given episode."""
|
||||
return self.info["episodes"][episode_index]["length"]
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
ep_start = self.episode_data_index["from"][ep_idx]
|
||||
ep_end = self.episode_data_index["to"][ep_idx]
|
||||
query_indices = {
|
||||
key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
return query_indices, padding
|
||||
|
||||
def _get_query_timestamps(
|
||||
self,
|
||||
current_ts: float,
|
||||
query_indices: dict[str, list[int]] | None = None,
|
||||
) -> dict[str, list[float]]:
|
||||
query_timestamps = {}
|
||||
for key in self.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.video_keys
|
||||
}
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
||||
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
||||
the main process and a subprocess fails to access it.
|
||||
"""
|
||||
item = {}
|
||||
for vid_key, query_ts in query_timestamps.items():
|
||||
video_path = self.root / self.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames_torchvision(
|
||||
video_path, query_ts, self.tolerance_s, self.video_backend
|
||||
)
|
||||
item[vid_key] = frames.squeeze(0)
|
||||
|
||||
return item
|
||||
|
||||
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
|
||||
for key, val in padding.items():
|
||||
item[key] = torch.BoolTensor(val)
|
||||
return item
|
||||
# 1e-4 to account for possible numerical error
|
||||
return 1 / self.fps - 1e-4
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, idx) -> dict:
|
||||
def __getitem__(self, idx):
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx
|
||||
query_indices, padding = self._get_query_indices(idx, current_ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
item[key] = val
|
||||
if self.delta_timestamps is not None:
|
||||
item = load_previous_and_future_frames(
|
||||
item,
|
||||
self.hf_dataset,
|
||||
self.episode_data_index,
|
||||
self.delta_timestamps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
current_ts = item["timestamp"].item()
|
||||
query_timestamps = self._get_query_timestamps(current_ts, query_indices)
|
||||
video_frames = self._query_videos(query_timestamps, ep_idx)
|
||||
item = {**video_frames, **item}
|
||||
if self.video:
|
||||
item = load_from_videos(
|
||||
item,
|
||||
self.video_frame_keys,
|
||||
self.videos_dir,
|
||||
self.tolerance_s,
|
||||
self.video_backend,
|
||||
)
|
||||
|
||||
if self.image_transforms is not None:
|
||||
image_keys = self.camera_keys
|
||||
for cam in image_keys:
|
||||
for cam in self.camera_keys:
|
||||
item[cam] = self.image_transforms(item[cam])
|
||||
|
||||
return item
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{self.__class__.__name__}\n"
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository ID: '{self.repo_id}',\n"
|
||||
f" Selected episodes: {self.episodes},\n"
|
||||
f" Number of selected episodes: {self.num_episodes},\n"
|
||||
f" Number of selected samples: {self.num_samples},\n"
|
||||
f"\n{json.dumps(self.info, indent=4)}\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
f" Recorded Frames per Second: {self.fps},\n"
|
||||
f" Camera Keys: {self.camera_keys},\n"
|
||||
f" Video Frame Keys: {self.video_frame_keys if self.video else 'N/A'},\n"
|
||||
f" Transformations: {self.image_transforms},\n"
|
||||
f" Codebase Version: {self.info.get('codebase_version', '< v1.6')},\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
def _create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
# TODO(aliberts): Handle resume
|
||||
return {
|
||||
"size": 0,
|
||||
"episode_index": self.total_episodes if episode_index is None else episode_index,
|
||||
"task_index": None,
|
||||
"frame_index": [],
|
||||
"timestamp": [],
|
||||
"next.done": [],
|
||||
**{key: [] for key in self.keys},
|
||||
**{key: [] for key in self.image_keys},
|
||||
}
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
|
||||
temporary directory — nothing is written to disk. To save those frames, the 'add_episode()' method
|
||||
then needs to be called.
|
||||
"""
|
||||
frame_index = self.episode_buffer["size"]
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(frame_index / self.fps)
|
||||
self.episode_buffer["next.done"].append(False)
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in self.keys:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
if self.image_writer is None:
|
||||
return
|
||||
|
||||
# Save images
|
||||
for cam_key in self.camera_keys:
|
||||
img_path = self.image_writer.get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=cam_key, frame_index=frame_index
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.image_writer.async_save_image(
|
||||
image=frame[cam_key],
|
||||
file_path=img_path,
|
||||
)
|
||||
if cam_key in self.image_keys:
|
||||
self.episode_buffer[cam_key].append(str(img_path))
|
||||
|
||||
def add_episode(self, task: str, encode_videos: bool = False) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer. Note that since it affects files on
|
||||
disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to
|
||||
the hub.
|
||||
|
||||
Use 'encode_videos' if you want to encode videos during the saving of each episode. Otherwise,
|
||||
you can do it later with dataset.consolidate(). This is to give more flexibility on when to spend
|
||||
time for video encoding.
|
||||
"""
|
||||
episode_length = self.episode_buffer.pop("size")
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if episode_index != self.total_episodes:
|
||||
# TODO(aliberts): Add option to use existing episode_index
|
||||
raise NotImplementedError()
|
||||
|
||||
task_index = self.get_task_index(task)
|
||||
self.episode_buffer["next.done"][-1] = True
|
||||
|
||||
for key in self.episode_buffer:
|
||||
if key in self.image_keys:
|
||||
continue
|
||||
if key in self.keys:
|
||||
self.episode_buffer[key] = torch.stack(self.episode_buffer[key])
|
||||
elif key == "episode_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), episode_index)
|
||||
elif key == "task_index":
|
||||
self.episode_buffer[key] = torch.full((episode_length,), task_index)
|
||||
else:
|
||||
self.episode_buffer[key] = torch.tensor(self.episode_buffer[key])
|
||||
|
||||
self.episode_buffer["index"] = torch.arange(self.total_frames, self.total_frames + episode_length)
|
||||
self._save_episode_to_metadata(episode_index, episode_length, task, task_index)
|
||||
self._save_episode_table(episode_index)
|
||||
|
||||
if encode_videos and len(self.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
self.consolidated = False
|
||||
|
||||
def _save_episode_table(self, episode_index: int) -> None:
|
||||
ep_dataset = datasets.Dataset.from_dict(self.episode_buffer, features=self.features, split="train")
|
||||
ep_table = ep_dataset._data.table
|
||||
ep_data_path = self.root / self.get_data_file_path(ep_index=episode_index)
|
||||
ep_data_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pq.write_table(ep_table, ep_data_path)
|
||||
|
||||
def _save_episode_to_metadata(
|
||||
self, episode_index: int, episode_length: int, task: str, task_index: int
|
||||
) -> None:
|
||||
self.info["total_episodes"] += 1
|
||||
self.info["total_frames"] += episode_length
|
||||
|
||||
if task_index not in self.tasks:
|
||||
self.info["total_tasks"] += 1
|
||||
self.tasks[task_index] = task
|
||||
task_dict = {
|
||||
"task_index": task_index,
|
||||
"task": task,
|
||||
}
|
||||
append_jsonl(task_dict, self.root / TASKS_PATH)
|
||||
|
||||
chunk = self.get_episode_chunk(episode_index)
|
||||
if chunk >= self.total_chunks:
|
||||
self.info["total_chunks"] += 1
|
||||
|
||||
self.info["splits"] = {"train": f"0:{self.info['total_episodes']}"}
|
||||
self.info["total_videos"] += len(self.video_keys)
|
||||
write_json(self.info, self.root / INFO_PATH)
|
||||
|
||||
episode_dict = {
|
||||
"episode_index": episode_index,
|
||||
"tasks": [task],
|
||||
"length": episode_length,
|
||||
}
|
||||
self.episode_dicts.append(episode_dict)
|
||||
append_jsonl(episode_dict, self.root / EPISODES_PATH)
|
||||
|
||||
def clear_episode_buffer(self) -> None:
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if self.image_writer is not None:
|
||||
for cam_key in self.camera_keys:
|
||||
img_dir = self.image_writer.get_episode_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
# Reset the buffer
|
||||
self.episode_buffer = self._create_episode_buffer()
|
||||
|
||||
def start_image_writter(self, num_processes: int = 0, num_threads: int = 1) -> None:
|
||||
if isinstance(self.image_writer, ImageWriter):
|
||||
logging.warning(
|
||||
"You are starting a new ImageWriter that is replacing an already exising one in the dataset."
|
||||
)
|
||||
|
||||
self.image_writer = ImageWriter(
|
||||
write_dir=self.root / "images",
|
||||
num_processes=num_processes,
|
||||
num_threads=num_threads,
|
||||
)
|
||||
|
||||
def stop_image_writter(self) -> None:
|
||||
"""
|
||||
Whenever wrapping this dataset inside a parallelized DataLoader, this needs to be called first to
|
||||
remove the image_write in order for the LeRobotDataset object to be pickleable and parallelized.
|
||||
"""
|
||||
if self.image_writer is not None:
|
||||
self.image_writer.stop()
|
||||
self.image_writer = None
|
||||
|
||||
def encode_videos(self) -> None:
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in range(self.num_episodes):
|
||||
for key in self.video_keys:
|
||||
# TODO: create video_buffer to store the state of encoded/unencoded videos and remove the need
|
||||
# to call self.image_writer here
|
||||
tmp_imgs_dir = self.image_writer.get_episode_dir(episode_index, key)
|
||||
video_path = self.root / self.get_video_file_path(episode_index, key)
|
||||
if video_path.is_file():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, self.fps, overwrite=True)
|
||||
|
||||
def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None:
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
self.episode_data_index = get_episode_data_index(self.episodes, self.episode_dicts)
|
||||
check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s)
|
||||
|
||||
if len(self.video_keys) > 0:
|
||||
self.encode_videos()
|
||||
|
||||
if not keep_image_files and self.image_writer is not None:
|
||||
shutil.rmtree(self.image_writer.dir)
|
||||
|
||||
if run_compute_stats:
|
||||
self.stop_image_writter()
|
||||
self.stats = compute_stats(self)
|
||||
write_stats(self.stats, self.root / STATS_PATH)
|
||||
self.consolidated = True
|
||||
else:
|
||||
logging.warning(
|
||||
"Skipping computation of the dataset statistics, dataset is not fully consolidated."
|
||||
)
|
||||
|
||||
# TODO(aliberts)
|
||||
# - [ ] add video info in info.json
|
||||
# Sanity checks:
|
||||
# - [ ] shapes
|
||||
# - [ ] ep_lenghts
|
||||
# - [ ] number of files
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
def from_preloaded(
|
||||
cls,
|
||||
repo_id: str,
|
||||
fps: int,
|
||||
repo_id: str = "from_preloaded",
|
||||
root: Path | None = None,
|
||||
robot: Robot | None = None,
|
||||
robot_type: str | None = None,
|
||||
keys: list[str] | None = None,
|
||||
image_keys: list[str] | None = None,
|
||||
video_keys: list[str] = None,
|
||||
shapes: dict | None = None,
|
||||
names: dict | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads_per_camera: int = 0,
|
||||
use_videos: bool = True,
|
||||
video_backend: str | None = None,
|
||||
split: str = "train",
|
||||
transform: callable = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
# additional preloaded attributes
|
||||
hf_dataset=None,
|
||||
episode_data_index=None,
|
||||
stats=None,
|
||||
info=None,
|
||||
videos_dir=None,
|
||||
video_backend=None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
"""Create a LeRobot Dataset from existing data and attributes instead of loading from the filesystem.
|
||||
|
||||
It is especially useful when converting raw data into LeRobotDataset before saving the dataset
|
||||
on the filesystem or uploading to the hub.
|
||||
|
||||
Note: Meta-data attributes like `repo_id`, `version`, `root`, etc are optional and potentially
|
||||
meaningless depending on the downstream usage of the return dataset.
|
||||
"""
|
||||
# create an empty object of type LeRobotDataset
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj.root = root if root is not None else LEROBOT_HOME / repo_id
|
||||
obj.tolerance_s = tolerance_s
|
||||
obj.image_writer = None
|
||||
|
||||
if robot is not None:
|
||||
robot_type, keys, image_keys, video_keys, shapes, names = _get_info_from_robot(robot, use_videos)
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
f"Some cameras in your {robot.robot_type} robot don't have an fps matching the fps of your dataset."
|
||||
"In this case, frames from lower fps cameras will be repeated to fill in the blanks"
|
||||
)
|
||||
if len(robot.cameras) > 0 and (image_writer_processes or image_writer_threads_per_camera):
|
||||
obj.start_image_writter(
|
||||
image_writer_processes, image_writer_threads_per_camera * robot.num_cameras
|
||||
)
|
||||
elif (
|
||||
robot_type is None
|
||||
or keys is None
|
||||
or image_keys is None
|
||||
or video_keys is None
|
||||
or shapes is None
|
||||
or names is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Dataset info (robot_type, keys, shapes...) must either come from a Robot or explicitly passed upon creation."
|
||||
)
|
||||
|
||||
if len(video_keys) > 0 and not use_videos:
|
||||
raise ValueError
|
||||
|
||||
obj.tasks, obj.stats, obj.episode_dicts = {}, {}, []
|
||||
obj.info = create_empty_dataset_info(
|
||||
CODEBASE_VERSION, fps, robot_type, keys, image_keys, video_keys, shapes, names
|
||||
)
|
||||
write_json(obj.info, obj.root / INFO_PATH)
|
||||
|
||||
# TODO(aliberts, rcadene, alexander-soare): Merge this with OnlineBuffer/DataBuffer
|
||||
obj.episode_buffer = obj._create_episode_buffer()
|
||||
|
||||
# This bool indicates that the current LeRobotDataset instance is in sync with the files on disk. It
|
||||
# is used to know when certain operations are need (for instance, computing dataset statistics). In
|
||||
# order to be able to push the dataset to the hub, it needs to be consolidated first by calling
|
||||
# self.consolidate().
|
||||
obj.consolidated = True
|
||||
|
||||
obj.episodes = None
|
||||
obj.hf_dataset = None
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj.local_files_only = True
|
||||
obj.episode_data_index = None
|
||||
obj.root = root
|
||||
obj.split = split
|
||||
obj.image_transforms = transform
|
||||
obj.delta_timestamps = delta_timestamps
|
||||
obj.hf_dataset = hf_dataset
|
||||
obj.episode_data_index = episode_data_index
|
||||
obj.stats = stats
|
||||
obj.info = info if info is not None else {}
|
||||
obj.videos_dir = videos_dir
|
||||
obj.video_backend = video_backend if video_backend is not None else "pyav"
|
||||
return obj
|
||||
|
||||
@@ -829,8 +225,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_ids: list[str],
|
||||
root: Path | None = None,
|
||||
episodes: dict | None = None,
|
||||
root: Path | None = DATA_DIR,
|
||||
split: str = "train",
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
video_backend: str | None = None,
|
||||
@@ -842,8 +238,8 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self._datasets = [
|
||||
LeRobotDataset(
|
||||
repo_id,
|
||||
root=root / repo_id if root is not None else None,
|
||||
episodes=episodes[repo_id] if episodes is not None else None,
|
||||
root=root,
|
||||
split=split,
|
||||
delta_timestamps=delta_timestamps,
|
||||
image_transforms=image_transforms,
|
||||
video_backend=video_backend,
|
||||
@@ -879,6 +275,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
self.disabled_data_keys.update(extra_keys)
|
||||
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.stats = aggregate_stats(self._datasets)
|
||||
@@ -992,6 +389,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
return (
|
||||
f"{self.__class__.__name__}(\n"
|
||||
f" Repository IDs: '{self.repo_ids}',\n"
|
||||
f" Split: '{self.split}',\n"
|
||||
f" Number of Samples: {self.num_samples},\n"
|
||||
f" Number of Episodes: {self.num_episodes},\n"
|
||||
f" Type: {'video (.mp4)' if self.video else 'image (.png)'},\n"
|
||||
|
||||
468
lerobot/common/datasets/populate_dataset.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""Functions to create an empty dataset, and populate it with frames."""
|
||||
# TODO(rcadene, aliberts): to adapt as class methods of next version of LeRobotDataset
|
||||
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
import multiprocessing
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.compute_stats import compute_stats
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import to_hf_dataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes, get_default_encoding
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, create_branch
|
||||
from lerobot.common.datasets.video_utils import encode_video_frames
|
||||
from lerobot.common.utils.utils import log_say
|
||||
from lerobot.scripts.push_dataset_to_hub import (
|
||||
push_dataset_card_to_hub,
|
||||
push_meta_data_to_hub,
|
||||
push_videos_to_hub,
|
||||
save_meta_data,
|
||||
)
|
||||
|
||||
########################################################################################
|
||||
# Asynchrounous saving of images on disk
|
||||
########################################################################################
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
# TODO(aliberts): Allow to pass custom exceptions
|
||||
# (e.g. ThreadServiceExit, KeyboardInterrupt, SystemExit, UnpluggedError, DynamixelCommError)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
image_writer = kwargs.get("dataset", {}).get("image_writer")
|
||||
if image_writer is not None:
|
||||
print("Waiting for image writer to terminate...")
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def save_image(img_tensor, key, frame_index, episode_index, videos_dir: str):
|
||||
img = Image.fromarray(img_tensor.numpy())
|
||||
path = Path(videos_dir) / f"{key}_episode_{episode_index:06d}" / f"frame_{frame_index:06d}.png"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
img.save(str(path), quality=100)
|
||||
|
||||
|
||||
def loop_to_save_images_in_threads(image_queue, num_threads):
|
||||
if num_threads < 1:
|
||||
raise NotImplementedError(f"Only `num_threads>=1` is supported for now, but {num_threads=} given.")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
# Blocks until a frame is available
|
||||
frame_data = image_queue.get()
|
||||
|
||||
# As usually done, exit loop when receiving None to stop the worker
|
||||
if frame_data is None:
|
||||
break
|
||||
|
||||
image, key, frame_index, episode_index, videos_dir = frame_data
|
||||
futures.append(executor.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
|
||||
|
||||
def start_image_writer_processes(image_queue, num_processes, num_threads_per_process):
|
||||
if num_processes < 1:
|
||||
raise ValueError(f"Only `num_processes>=1` is supported, but {num_processes=} given.")
|
||||
|
||||
if num_threads_per_process < 1:
|
||||
raise NotImplementedError(
|
||||
"Only `num_threads_per_process>=1` is supported for now, but {num_threads_per_process=} given."
|
||||
)
|
||||
|
||||
processes = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(
|
||||
target=loop_to_save_images_in_threads,
|
||||
args=(image_queue, num_threads_per_process),
|
||||
)
|
||||
process.start()
|
||||
processes.append(process)
|
||||
return processes
|
||||
|
||||
|
||||
def stop_processes(processes, queue, timeout):
|
||||
# Send None to each process to signal them to stop
|
||||
for _ in processes:
|
||||
queue.put(None)
|
||||
|
||||
# Wait maximum 20 seconds for all processes to terminate
|
||||
for process in processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
# If not terminated after 20 seconds, force termination
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
# Close the queue, no more items can be put in the queue
|
||||
queue.close()
|
||||
|
||||
# Ensure all background queue threads have finished
|
||||
queue.join_thread()
|
||||
|
||||
|
||||
def start_image_writer(num_processes, num_threads):
|
||||
"""This function abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
When `num_processes=0`, it returns a dictionary containing a threads pool of size `num_threads`.
|
||||
When `num_processes>0`, it returns a dictionary containing a processes pool of size `num_processes`,
|
||||
where each subprocess starts their own threads pool of size `num_threads`.
|
||||
|
||||
The optimal number of processes and threads depends on your computer capabilities.
|
||||
We advise to use 4 threads per camera with 0 processes. If the fps is not stable, try to increase or lower
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
image_writer = {}
|
||||
|
||||
if num_processes == 0:
|
||||
futures = []
|
||||
threads_pool = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads)
|
||||
image_writer["threads_pool"], image_writer["futures"] = threads_pool, futures
|
||||
else:
|
||||
# TODO(rcadene): When using num_processes>1, `multiprocessing.Manager().Queue()`
|
||||
# might be better than `multiprocessing.Queue()`. Source: https://www.geeksforgeeks.org/python-multiprocessing-queue-vs-multiprocessing-manager-queue
|
||||
image_queue = multiprocessing.Queue()
|
||||
processes_pool = start_image_writer_processes(
|
||||
image_queue, num_processes=num_processes, num_threads_per_process=num_threads
|
||||
)
|
||||
image_writer["processes_pool"], image_writer["image_queue"] = processes_pool, image_queue
|
||||
|
||||
return image_writer
|
||||
|
||||
|
||||
def async_save_image(image_writer, image, key, frame_index, episode_index, videos_dir):
|
||||
"""This function abstract away the saving of an image on disk asynchrounously. It uses a dictionary
|
||||
called image writer which contains either a pool of processes or a pool of threads.
|
||||
"""
|
||||
if "threads_pool" in image_writer:
|
||||
threads_pool, futures = image_writer["threads_pool"], image_writer["futures"]
|
||||
futures.append(threads_pool.submit(save_image, image, key, frame_index, episode_index, videos_dir))
|
||||
else:
|
||||
image_queue = image_writer["image_queue"]
|
||||
image_queue.put((image, key, frame_index, episode_index, videos_dir))
|
||||
|
||||
|
||||
def stop_image_writer(image_writer, timeout):
|
||||
if "threads_pool" in image_writer:
|
||||
futures = image_writer["futures"]
|
||||
# Before exiting function, wait for all threads to complete
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
concurrent.futures.wait(futures, timeout=timeout)
|
||||
progress_bar.update(len(futures))
|
||||
else:
|
||||
processes_pool, image_queue = image_writer["processes_pool"], image_writer["image_queue"]
|
||||
stop_processes(processes_pool, image_queue, timeout=timeout)
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Functions to initialize, resume and populate a dataset
|
||||
########################################################################################
|
||||
|
||||
|
||||
def init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
video,
|
||||
write_images,
|
||||
num_image_writer_processes,
|
||||
num_image_writer_threads,
|
||||
):
|
||||
local_dir = Path(root) / repo_id
|
||||
if local_dir.exists() and force_override:
|
||||
shutil.rmtree(local_dir)
|
||||
|
||||
episodes_dir = local_dir / "episodes"
|
||||
episodes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
videos_dir = local_dir / "videos"
|
||||
videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Logic to resume data recording
|
||||
rec_info_path = episodes_dir / "data_recording_info.json"
|
||||
if rec_info_path.exists():
|
||||
with open(rec_info_path) as f:
|
||||
rec_info = json.load(f)
|
||||
num_episodes = rec_info["last_episode_index"] + 1
|
||||
else:
|
||||
num_episodes = 0
|
||||
|
||||
dataset = {
|
||||
"repo_id": repo_id,
|
||||
"local_dir": local_dir,
|
||||
"videos_dir": videos_dir,
|
||||
"episodes_dir": episodes_dir,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
"rec_info_path": rec_info_path,
|
||||
"num_episodes": num_episodes,
|
||||
}
|
||||
|
||||
if write_images:
|
||||
# Initialize processes or/and threads dedicated to save images on disk asynchronously,
|
||||
# which is critical to control a robot and record data at a high frame rate.
|
||||
image_writer = start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads,
|
||||
)
|
||||
dataset["image_writer"] = image_writer
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def add_frame(dataset, observation, action):
|
||||
if "current_episode" not in dataset:
|
||||
# initialize episode dictionary
|
||||
ep_dict = {}
|
||||
for key in observation:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
for key in action:
|
||||
if key not in ep_dict:
|
||||
ep_dict[key] = []
|
||||
|
||||
ep_dict["episode_index"] = []
|
||||
ep_dict["frame_index"] = []
|
||||
ep_dict["timestamp"] = []
|
||||
ep_dict["next.done"] = []
|
||||
|
||||
dataset["current_episode"] = ep_dict
|
||||
dataset["current_frame_index"] = 0
|
||||
|
||||
ep_dict = dataset["current_episode"]
|
||||
episode_index = dataset["num_episodes"]
|
||||
frame_index = dataset["current_frame_index"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
ep_dict["episode_index"].append(episode_index)
|
||||
ep_dict["frame_index"].append(frame_index)
|
||||
ep_dict["timestamp"].append(frame_index / fps)
|
||||
ep_dict["next.done"].append(False)
|
||||
|
||||
img_keys = [key for key in observation if "image" in key]
|
||||
non_img_keys = [key for key in observation if "image" not in key]
|
||||
|
||||
# Save all observed modalities except images
|
||||
for key in non_img_keys:
|
||||
ep_dict[key].append(observation[key])
|
||||
|
||||
# Save actions
|
||||
for key in action:
|
||||
ep_dict[key].append(action[key])
|
||||
|
||||
if "image_writer" not in dataset:
|
||||
dataset["current_frame_index"] += 1
|
||||
return
|
||||
|
||||
# Save images
|
||||
image_writer = dataset["image_writer"]
|
||||
for key in img_keys:
|
||||
imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
async_save_image(
|
||||
image_writer,
|
||||
image=observation[key],
|
||||
key=key,
|
||||
frame_index=frame_index,
|
||||
episode_index=episode_index,
|
||||
videos_dir=str(videos_dir),
|
||||
)
|
||||
|
||||
if video:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
frame_info = {"path": f"videos/{fname}", "timestamp": frame_index / fps}
|
||||
else:
|
||||
frame_info = str(imgs_dir / f"frame_{frame_index:06d}.png")
|
||||
|
||||
ep_dict[key].append(frame_info)
|
||||
|
||||
dataset["current_frame_index"] += 1
|
||||
|
||||
|
||||
def delete_current_episode(dataset):
|
||||
del dataset["current_episode"]
|
||||
del dataset["current_frame_index"]
|
||||
|
||||
# delete temporary images
|
||||
episode_index = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
for tmp_imgs_dir in videos_dir.glob(f"*_episode_{episode_index:06d}"):
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def save_current_episode(dataset):
|
||||
episode_index = dataset["num_episodes"]
|
||||
ep_dict = dataset["current_episode"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
rec_info_path = dataset["rec_info_path"]
|
||||
|
||||
ep_dict["next.done"][-1] = True
|
||||
|
||||
for key in ep_dict:
|
||||
if "observation" in key and "image" not in key:
|
||||
ep_dict[key] = torch.stack(ep_dict[key])
|
||||
|
||||
ep_dict["action"] = torch.stack(ep_dict["action"])
|
||||
ep_dict["episode_index"] = torch.tensor(ep_dict["episode_index"])
|
||||
ep_dict["frame_index"] = torch.tensor(ep_dict["frame_index"])
|
||||
ep_dict["timestamp"] = torch.tensor(ep_dict["timestamp"])
|
||||
ep_dict["next.done"] = torch.tensor(ep_dict["next.done"])
|
||||
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
torch.save(ep_dict, ep_path)
|
||||
|
||||
rec_info = {
|
||||
"last_episode_index": episode_index,
|
||||
}
|
||||
with open(rec_info_path, "w") as f:
|
||||
json.dump(rec_info, f)
|
||||
|
||||
# force re-initialization of episode dictionnary during add_frame
|
||||
del dataset["current_episode"]
|
||||
|
||||
dataset["num_episodes"] += 1
|
||||
|
||||
|
||||
def encode_videos(dataset, image_keys, play_sounds):
|
||||
log_say("Encoding videos", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
local_dir = dataset["local_dir"]
|
||||
fps = dataset["fps"]
|
||||
|
||||
# Use ffmpeg to convert frames stored as png into mp4 videos
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
for key in image_keys:
|
||||
# key = f"observation.images.{name}"
|
||||
tmp_imgs_dir = videos_dir / f"{key}_episode_{episode_index:06d}"
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
video_path = local_dir / "videos" / fname
|
||||
if video_path.exists():
|
||||
# Skip if video is already encoded. Could be the case when resuming data recording.
|
||||
continue
|
||||
# note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
# since video encoding with ffmpeg is already using multithreading.
|
||||
encode_video_frames(tmp_imgs_dir, video_path, fps, overwrite=True)
|
||||
shutil.rmtree(tmp_imgs_dir)
|
||||
|
||||
|
||||
def from_dataset_to_lerobot_dataset(dataset, play_sounds):
|
||||
log_say("Consolidate episodes", play_sounds)
|
||||
|
||||
num_episodes = dataset["num_episodes"]
|
||||
episodes_dir = dataset["episodes_dir"]
|
||||
videos_dir = dataset["videos_dir"]
|
||||
video = dataset["video"]
|
||||
fps = dataset["fps"]
|
||||
repo_id = dataset["repo_id"]
|
||||
|
||||
ep_dicts = []
|
||||
for episode_index in tqdm.tqdm(range(num_episodes)):
|
||||
ep_path = episodes_dir / f"episode_{episode_index}.pth"
|
||||
ep_dict = torch.load(ep_path)
|
||||
ep_dicts.append(ep_dict)
|
||||
data_dict = concatenate_episodes(ep_dicts)
|
||||
|
||||
if video:
|
||||
image_keys = [key for key in data_dict if "image" in key]
|
||||
encode_videos(dataset, image_keys, play_sounds)
|
||||
|
||||
hf_dataset = to_hf_dataset(data_dict, video)
|
||||
episode_data_index = calculate_episode_data_index(hf_dataset)
|
||||
|
||||
info = {
|
||||
"codebase_version": CODEBASE_VERSION,
|
||||
"fps": fps,
|
||||
"video": video,
|
||||
}
|
||||
if video:
|
||||
info["encoding"] = get_default_encoding()
|
||||
|
||||
lerobot_dataset = LeRobotDataset.from_preloaded(
|
||||
repo_id=repo_id,
|
||||
hf_dataset=hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
info=info,
|
||||
videos_dir=videos_dir,
|
||||
)
|
||||
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
def save_lerobot_dataset_on_disk(lerobot_dataset):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
info = lerobot_dataset.info
|
||||
stats = lerobot_dataset.stats
|
||||
episode_data_index = lerobot_dataset.episode_data_index
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
hf_dataset = hf_dataset.with_format(None) # to remove transforms that cant be saved
|
||||
hf_dataset.save_to_disk(str(local_dir / "train"))
|
||||
|
||||
save_meta_data(info, stats, episode_data_index, meta_data_dir)
|
||||
|
||||
|
||||
def push_lerobot_dataset_to_hub(lerobot_dataset, tags):
|
||||
hf_dataset = lerobot_dataset.hf_dataset
|
||||
local_dir = lerobot_dataset.videos_dir.parent
|
||||
videos_dir = lerobot_dataset.videos_dir
|
||||
repo_id = lerobot_dataset.repo_id
|
||||
video = lerobot_dataset.video
|
||||
meta_data_dir = local_dir / "meta_data"
|
||||
|
||||
if not (local_dir / "train").exists():
|
||||
raise ValueError(
|
||||
"You need to run `save_lerobot_dataset_on_disk(lerobot_dataset)` before pushing to the hub."
|
||||
)
|
||||
|
||||
hf_dataset.push_to_hub(repo_id, revision="main")
|
||||
push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
|
||||
push_dataset_card_to_hub(repo_id, revision="main", tags=tags)
|
||||
if video:
|
||||
push_videos_to_hub(repo_id, videos_dir, revision="main")
|
||||
create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)
|
||||
|
||||
|
||||
def create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds):
|
||||
if "image_writer" in dataset:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
image_writer = dataset["image_writer"]
|
||||
stop_image_writer(image_writer, timeout=20)
|
||||
|
||||
lerobot_dataset = from_dataset_to_lerobot_dataset(dataset, play_sounds)
|
||||
|
||||
if run_compute_stats:
|
||||
log_say("Computing dataset statistics", play_sounds)
|
||||
lerobot_dataset.stats = compute_stats(lerobot_dataset)
|
||||
else:
|
||||
logging.info("Skipping computation of the dataset statistics")
|
||||
lerobot_dataset.stats = {}
|
||||
|
||||
save_lerobot_dataset_on_disk(lerobot_dataset)
|
||||
|
||||
if push_to_hub:
|
||||
push_lerobot_dataset_to_hub(lerobot_dataset, tags)
|
||||
|
||||
return lerobot_dataset
|
||||
@@ -14,31 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
import warnings
|
||||
from itertools import accumulate
|
||||
from functools import cache
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import torch
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
from datasets import load_dataset, load_from_disk
|
||||
from huggingface_hub import DatasetCard, HfApi, hf_hub_download, snapshot_download
|
||||
from PIL import Image as PILImage
|
||||
from safetensors.torch import load_file
|
||||
from torchvision import transforms
|
||||
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of episodes per chunk
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
EPISODES_PATH = "meta/episodes.jsonl"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
# Metadata will go there
|
||||
@@ -48,7 +37,7 @@ This dataset was created using [LeRobot](https://github.com/huggingface/lerobot)
|
||||
"""
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
def flatten_dict(d, parent_key="", sep="/"):
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
|
||||
For example:
|
||||
@@ -67,7 +56,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
def unflatten_dict(d, sep="/"):
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
@@ -80,24 +69,6 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_jsonl(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
|
||||
serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
write_json(serialized_stats, fpath)
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
"""Get a transform function that convert items from Hugging Face dataset (pyarrow)
|
||||
to torch tensors. Importantly, images are converted from PIL, which corresponds to
|
||||
@@ -109,6 +80,14 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
if isinstance(first_item, PILImage.Image):
|
||||
to_tensor = transforms.ToTensor()
|
||||
items_dict[key] = [to_tensor(img) for img in items_dict[key]]
|
||||
elif isinstance(first_item, str):
|
||||
# TODO (michel-aractingi): add str2embedding via language tokenizer
|
||||
# For now we leave this part up to the user to choose how to address
|
||||
# language conditioned tasks
|
||||
pass
|
||||
elif isinstance(first_item, dict) and "path" in first_item and "timestamp" in first_item:
|
||||
# video frame will be processed downstream
|
||||
pass
|
||||
elif first_item is None:
|
||||
pass
|
||||
else:
|
||||
@@ -116,40 +95,8 @@ def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
return items_dict
|
||||
|
||||
|
||||
def _get_major_minor(version: str) -> tuple[int]:
|
||||
split = version.strip("v").split(".")
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
warnings.warn(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||
num_version = float(version.strip("v"))
|
||||
if num_version < 2 and enforce_v2:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
@cache
|
||||
def get_hf_dataset_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
@@ -169,185 +116,106 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||
return version
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
with open(local_dir / INFO_PATH) as f:
|
||||
return json.load(f)
|
||||
def load_hf_dataset(repo_id: str, version: str, root: Path, split: str) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
if root is not None:
|
||||
hf_dataset = load_from_disk(str(Path(root) / repo_id / "train"))
|
||||
# TODO(rcadene): clean this which enables getting a subset of dataset
|
||||
if split != "train":
|
||||
if "%" in split:
|
||||
raise NotImplementedError(f"We dont support splitting based on percentage for now ({split}).")
|
||||
match_from = re.search(r"train\[(\d+):\]", split)
|
||||
match_to = re.search(r"train\[:(\d+)\]", split)
|
||||
if match_from:
|
||||
from_frame_index = int(match_from.group(1))
|
||||
hf_dataset = hf_dataset.select(range(from_frame_index, len(hf_dataset)))
|
||||
elif match_to:
|
||||
to_frame_index = int(match_to.group(1))
|
||||
hf_dataset = hf_dataset.select(range(to_frame_index))
|
||||
else:
|
||||
raise ValueError(
|
||||
f'`split` ({split}) should either be "train", "train[INT:]", or "train[:INT]"'
|
||||
)
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
hf_dataset = load_dataset(repo_id, revision=safe_version, split=split)
|
||||
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
with open(local_dir / STATS_PATH) as f:
|
||||
stats = json.load(f)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
def load_episode_data_index(repo_id, version, root) -> dict[str, torch.Tensor]:
|
||||
"""episode_data_index contains the range of indices for each episode
|
||||
|
||||
Example:
|
||||
```python
|
||||
from_id = episode_data_index["from"][episode_id].item()
|
||||
to_id = episode_data_index["to"][episode_id].item()
|
||||
episode_frames = [dataset[i] for i in range(from_id, to_id)]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/episode_data_index.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> dict[str, dict[str, torch.Tensor]]:
|
||||
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
|
||||
|
||||
Example:
|
||||
```python
|
||||
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
|
||||
```
|
||||
"""
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(
|
||||
repo_id, "meta_data/stats.safetensors", repo_type="dataset", revision=safe_version
|
||||
)
|
||||
|
||||
stats = load_file(path)
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
|
||||
tasks = list(reader)
|
||||
def load_info(repo_id, version, root) -> dict:
|
||||
"""info contains useful information regarding the dataset that are not stored elsewhere
|
||||
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> dict:
|
||||
with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
||||
shapes = {key: len(names) for key, names in robot.names.items()}
|
||||
camera_shapes = {}
|
||||
for key, cam in robot.cameras.items():
|
||||
video_key = f"observation.images.{key}"
|
||||
camera_shapes[video_key] = {
|
||||
"width": cam.width,
|
||||
"height": cam.height,
|
||||
"channels": cam.channels,
|
||||
}
|
||||
keys = list(robot.names)
|
||||
image_keys = [] if use_videos else list(camera_shapes)
|
||||
video_keys = list(camera_shapes) if use_videos else []
|
||||
shapes = {**shapes, **camera_shapes}
|
||||
names = robot.names
|
||||
robot_type = robot.robot_type
|
||||
|
||||
return robot_type, keys, image_keys, video_keys, shapes, names
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
robot_type: str,
|
||||
keys: list[str],
|
||||
image_keys: list[str],
|
||||
video_keys: list[str],
|
||||
shapes: dict,
|
||||
names: dict,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
"total_tasks": 0,
|
||||
"total_videos": 0,
|
||||
"total_chunks": 0,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"keys": keys,
|
||||
"video_keys": video_keys,
|
||||
"image_keys": image_keys,
|
||||
"shapes": shapes,
|
||||
"names": names,
|
||||
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
|
||||
}
|
||||
|
||||
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths.values()))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
fps: int,
|
||||
tolerance_s: float,
|
||||
raise_value_error: bool = True,
|
||||
) -> bool:
|
||||
Example:
|
||||
```python
|
||||
print("frame per second used to collect the video", info["fps"])
|
||||
```
|
||||
"""
|
||||
This check is to make sure that each timestamps is separated to the next by 1/fps +/- tolerance to
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
# timestamps[2] += tolerance_s # TODO delete
|
||||
# timestamps[-2] += tolerance_s/2 # TODO delete
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "meta_data" / "info.json"
|
||||
else:
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
path = hf_hub_download(repo_id, "meta_data/info.json", repo_type="dataset", revision=safe_version)
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
mask[ignored_diffs] = False
|
||||
filtered_within_tolerance = within_tolerance[mask]
|
||||
|
||||
if not torch.all(filtered_within_tolerance):
|
||||
# Track original indices before masking
|
||||
original_indices = torch.arange(len(diffs))
|
||||
filtered_indices = original_indices[mask]
|
||||
outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze()
|
||||
outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices]
|
||||
episode_indices = torch.stack(hf_dataset["episode_index"])
|
||||
|
||||
outside_tolerances = []
|
||||
for idx in outside_tolerance_indices:
|
||||
entry = {
|
||||
"timestamps": [timestamps[idx], timestamps[idx + 1]],
|
||||
"diff": diffs[idx],
|
||||
"episode_index": episode_indices[idx].item(),
|
||||
}
|
||||
outside_tolerances.append(entry)
|
||||
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""One or several timestamps unexpectedly violate the tolerance inside episode range.
|
||||
This might be due to synchronization issues with timestamps during data collection.
|
||||
\n{pformat(outside_tolerances)}"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
with open(path) as f:
|
||||
info = json.load(f)
|
||||
return info
|
||||
|
||||
|
||||
def check_delta_timestamps(
|
||||
delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True
|
||||
) -> bool:
|
||||
"""This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance.
|
||||
This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be
|
||||
actual timestamps from the dataset.
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
]
|
||||
def load_videos(repo_id, version, root) -> Path:
|
||||
if root is not None:
|
||||
path = Path(root) / repo_id / "videos"
|
||||
else:
|
||||
# TODO(rcadene): we download the whole repo here. see if we can avoid this
|
||||
safe_version = get_hf_dataset_safe_version(repo_id, version)
|
||||
repo_dir = snapshot_download(repo_id, repo_type="dataset", revision=safe_version)
|
||||
path = Path(repo_dir) / "videos"
|
||||
|
||||
if len(outside_tolerance) > 0:
|
||||
if raise_value_error:
|
||||
raise ValueError(
|
||||
f"""
|
||||
The following delta_timestamps are found outside of tolerance range.
|
||||
Please make sure they are multiples of 1/{fps} +/- tolerance and adjust
|
||||
their values accordingly.
|
||||
\n{pformat(outside_tolerance)}
|
||||
"""
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
return path
|
||||
|
||||
|
||||
def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]:
|
||||
delta_indices = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist()
|
||||
|
||||
return delta_indices
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
@@ -441,7 +309,6 @@ def load_previous_and_future_frames(
|
||||
return item
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
@@ -496,7 +363,6 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torc
|
||||
return episode_data_index
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
|
||||
@@ -534,7 +400,7 @@ def cycle(iterable):
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None):
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
@@ -549,17 +415,12 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None, text: str | None = None, info: dict | None = None
|
||||
) -> DatasetCard:
|
||||
def create_lerobot_dataset_card(tags: list | None = None, text: str | None = None) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
card.text += f"{text}\n"
|
||||
if info is not None:
|
||||
card.text += "[meta/info.json](meta/info.json)\n"
|
||||
card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
|
||||
card.text += text
|
||||
return card
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
ALOHA_SINGLE_TASKS_REAL = {
|
||||
"aloha_mobile_cabinet": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
||||
"aloha_mobile_chair": "Push the chairs in front of the desk to place them against it.",
|
||||
"aloha_mobile_elevator": "Take the elevator to the 1st floor.",
|
||||
"aloha_mobile_shrimp": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
||||
"aloha_mobile_wash_pan": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
||||
"aloha_mobile_wipe_wine": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
||||
"aloha_static_battery": "Place the battery into the slot of the remote controller.",
|
||||
"aloha_static_candy": "Pick up the candy and unwrap it.",
|
||||
"aloha_static_coffee": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
"aloha_static_coffee_new": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
||||
"aloha_static_cups_open": "Pick up the plastic cup and open its lid.",
|
||||
"aloha_static_fork_pick_up": "Pick up the fork and place it on the plate.",
|
||||
"aloha_static_pingpong_test": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
||||
"aloha_static_pro_pencil": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
||||
"aloha_static_screw_driver": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
||||
"aloha_static_tape": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
||||
"aloha_static_thread_velcro": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
||||
"aloha_static_towel": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
||||
"aloha_static_vinh_cup": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
|
||||
"aloha_static_vinh_cup_left": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
|
||||
"aloha_static_ziploc_slide": "Slide open the ziploc bag.",
|
||||
}
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
print("---------------------------------------------------------")
|
||||
name = repo_id.split("/")[1]
|
||||
single_task, tasks_col, robot_config = None, None, None
|
||||
|
||||
if "aloha" in name:
|
||||
robot_config = parse_robot_config(ALOHA_CONFIG)
|
||||
if "sim_insertion" in name:
|
||||
single_task = "Insert the peg into the socket."
|
||||
elif "sim_transfer" in name:
|
||||
single_task = "Pick up the cube with the right arm and transfer it to the left arm."
|
||||
else:
|
||||
single_task = ALOHA_SINGLE_TASKS_REAL[name]
|
||||
elif "unitreeh1" in name:
|
||||
if "fold_clothes" in name:
|
||||
single_task = "Fold the sweatshirt."
|
||||
elif "rearrange_objects" in name or "rearrange_objects" in name:
|
||||
single_task = "Put the object into the bin."
|
||||
elif "two_robot_greeting" in name:
|
||||
single_task = "Greet the other robot with a high five."
|
||||
elif "warehouse" in name:
|
||||
single_task = (
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog."
|
||||
)
|
||||
elif name != "columbia_cairlab_pusht_real" and "pusht" in name:
|
||||
single_task = "Push the T-shaped block onto the T-shaped target."
|
||||
elif "xarm_lift" in name or "xarm_push" in name:
|
||||
single_task = "Pick up the cube and lift it."
|
||||
elif name == "umi_cup_in_the_wild":
|
||||
single_task = "Put the cup on the plate."
|
||||
else:
|
||||
tasks_col = "language_instruction"
|
||||
|
||||
try:
|
||||
convert_dataset(
|
||||
repo_id=repo_id,
|
||||
local_dir=LOCAL_DIR,
|
||||
single_task=single_task,
|
||||
tasks_col=tasks_col,
|
||||
robot_config=robot_config,
|
||||
)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
except Exception:
|
||||
status = f"{repo_id}: failed\n {traceback.format_exc()}"
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
continue
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
batch_convert()
|
||||
@@ -1,814 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script will help you convert any LeRobot dataset already pushed to the hub from codebase version 1.6 to
|
||||
2.0. You will be required to provide the 'tasks', which is a short but accurate description in plain English
|
||||
for each of the task performed in the dataset. This will allow to easily train models with task-conditionning.
|
||||
|
||||
We support 3 different scenarios for these tasks (see instructions below):
|
||||
1. Single task dataset: all episodes of your dataset have the same single task.
|
||||
2. Single task episodes: the episodes of your dataset each contain a single task but they can differ from
|
||||
one episode to the next.
|
||||
3. Multi task episodes: episodes of your dataset may each contain several different tasks.
|
||||
|
||||
|
||||
Can you can also provide a robot config .yaml file (not mandatory) to this script via the option
|
||||
'--robot-config' so that it writes information about the robot (robot type, motors names) this dataset was
|
||||
recorded with. For now, only Aloha/Koch type robots are supported with this option.
|
||||
|
||||
|
||||
# 1. Single task dataset
|
||||
If your dataset contains a single task, you can simply provide it directly via the CLI with the
|
||||
'--single-task' option.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/aloha_sim_insertion_human_image \
|
||||
--single-task "Insert the peg into the socket." \
|
||||
--robot-config lerobot/configs/robot/aloha.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id aliberts/koch_tutorial \
|
||||
--single-task "Pick the Lego block and drop it in the box on the right." \
|
||||
--robot-config lerobot/configs/robot/koch.yaml \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
|
||||
# 2. Single task episodes
|
||||
If your dataset is a multi-task dataset, you have two options to provide the tasks to this script:
|
||||
|
||||
- If your dataset already contains a language instruction column in its parquet file, you can simply provide
|
||||
this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
|
||||
- If your dataset doesn't contain a language instruction, you should provide the path to a .json file with the
|
||||
'--tasks-path' arg. This file should have the following structure where keys correspond to each
|
||||
episode_index in the dataset, and values are the language instruction for that episode.
|
||||
|
||||
Example:
|
||||
|
||||
```json
|
||||
{
|
||||
"0": "Do something",
|
||||
"1": "Do something else",
|
||||
"2": "Do something",
|
||||
"3": "Go there",
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
# 3. Multi task episodes
|
||||
If you have multiple tasks per episodes, your dataset should contain a language instruction column in its
|
||||
parquet file, and you must provide this column's name with the '--tasks-col' arg.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \
|
||||
--repo-id lerobot/stanford_kuka_multimodal_dataset \
|
||||
--tasks-col "language_instruction" \
|
||||
--local-dir data
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import filecmp
|
||||
import json
|
||||
import math
|
||||
import shutil
|
||||
import subprocess
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
EPISODES_PATH,
|
||||
INFO_PATH,
|
||||
STATS_PATH,
|
||||
TASKS_PATH,
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
V16 = "v1.6"
|
||||
V20 = "v2.0"
|
||||
|
||||
GITATTRIBUTES_REF = "aliberts/gitattributes_reference"
|
||||
V1_VIDEO_FILE = "{video_key}_episode_{episode_index:06d}.mp4"
|
||||
V1_INFO_PATH = "meta_data/info.json"
|
||||
V1_STATS_PATH = "meta_data/stats.safetensors"
|
||||
|
||||
|
||||
def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]:
|
||||
robot_cfg = init_hydra_config(config_path, config_overrides)
|
||||
if robot_cfg["robot_type"] in ["aloha", "koch"]:
|
||||
state_names = [
|
||||
f"{arm}_{motor}" if len(robot_cfg["follower_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["follower_arms"]
|
||||
for motor in robot_cfg["follower_arms"][arm]["motors"]
|
||||
]
|
||||
action_names = [
|
||||
# f"{arm}_{motor}" for arm in ["left", "right"] for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
f"{arm}_{motor}" if len(robot_cfg["leader_arms"]) > 1 else motor
|
||||
for arm in robot_cfg["leader_arms"]
|
||||
for motor in robot_cfg["leader_arms"][arm]["motors"]
|
||||
]
|
||||
# elif robot_cfg["robot_type"] == "stretch3": TODO
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Please provide robot_config={'robot_type': ..., 'names': ...} directly to convert_dataset()."
|
||||
)
|
||||
|
||||
return {
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"action": action_names,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> dict:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
safetensor_path = v1_dir / V1_STATS_PATH
|
||||
stats = load_file(safetensor_path)
|
||||
serialized_stats = {key: value.tolist() for key, value in stats.items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
|
||||
json_path = v2_dir / STATS_PATH
|
||||
json_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(json_path, "w") as f:
|
||||
json.dump(serialized_stats, f, indent=4)
|
||||
|
||||
# Sanity check
|
||||
with open(json_path) as f:
|
||||
stats_json = json.load(f)
|
||||
|
||||
stats_json = flatten_dict(stats_json)
|
||||
stats_json = {key: torch.tensor(value) for key, value in stats_json.items()}
|
||||
for key in stats:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_keys(dataset: Dataset) -> dict[str, list]:
|
||||
sequence_keys, image_keys, video_keys = [], [], []
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Sequence):
|
||||
sequence_keys.append(key)
|
||||
elif isinstance(ft, datasets.Image):
|
||||
image_keys.append(key)
|
||||
elif ft._type == "VideoFrame":
|
||||
video_keys.append(key)
|
||||
|
||||
return {
|
||||
"sequence": sequence_keys,
|
||||
"image": image_keys,
|
||||
"video": video_keys,
|
||||
}
|
||||
|
||||
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
tasks = list(set(tasks_by_episodes.values()))
|
||||
tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)}
|
||||
episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int)
|
||||
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
return dataset, tasks
|
||||
|
||||
|
||||
def add_task_index_from_tasks_col(
|
||||
dataset: Dataset, tasks_col: str
|
||||
) -> tuple[Dataset, dict[str, list[str]], list[str]]:
|
||||
df = dataset.to_pandas()
|
||||
|
||||
# HACK: This is to clean some of the instructions in our version of Open X datasets
|
||||
prefix_to_clean = "tf.Tensor(b'"
|
||||
suffix_to_clean = "', shape=(), dtype=string)"
|
||||
df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean)
|
||||
|
||||
# Create task_index col
|
||||
tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict()
|
||||
tasks = df[tasks_col].unique().tolist()
|
||||
tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)}
|
||||
df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int)
|
||||
|
||||
# Build the dataset back from df
|
||||
features = dataset.features
|
||||
features["task_index"] = datasets.Value(dtype="int64")
|
||||
dataset = Dataset.from_pandas(df, features=features, split="train")
|
||||
dataset = dataset.remove_columns(tasks_col)
|
||||
|
||||
return dataset, tasks, tasks_by_episode
|
||||
|
||||
|
||||
def split_parquet_by_episodes(
|
||||
dataset: Dataset,
|
||||
keys: dict[str, list],
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
output_dir: Path,
|
||||
) -> list:
|
||||
table = dataset.remove_columns(keys["video"])._data.table
|
||||
episode_lengths = []
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
output_file = output_dir / DEFAULT_PARQUET_PATH.format(
|
||||
episode_chunk=ep_chunk, episode_index=ep_idx
|
||||
)
|
||||
pq.write_table(ep_table, output_file)
|
||||
|
||||
return episode_lengths
|
||||
|
||||
|
||||
def move_videos(
|
||||
repo_id: str,
|
||||
video_keys: list[str],
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
work_dir: Path,
|
||||
clean_gittatributes: Path,
|
||||
branch: str = "main",
|
||||
) -> None:
|
||||
"""
|
||||
HACK: Since HfApi() doesn't provide a way to move files directly in a repo, this function will run git
|
||||
commands to fetch git lfs video files references to move them into subdirectories without having to
|
||||
actually download them.
|
||||
"""
|
||||
_lfs_clone(repo_id, work_dir, branch)
|
||||
|
||||
videos_moved = False
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")]
|
||||
if len(video_files) == 0:
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
|
||||
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
|
||||
|
||||
current_gittatributes = work_dir / ".gitattributes"
|
||||
if not filecmp.cmp(current_gittatributes, clean_gittatributes, shallow=False):
|
||||
fix_gitattributes(work_dir, current_gittatributes, clean_gittatributes)
|
||||
|
||||
if lfs_untracked_videos:
|
||||
fix_lfs_video_files_tracking(work_dir, video_files)
|
||||
|
||||
if videos_moved:
|
||||
return
|
||||
|
||||
video_dirs = sorted(work_dir.glob("videos*/"))
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
for vid_key in video_keys:
|
||||
chunk_dir = "/".join(DEFAULT_VIDEO_PATH.split("/")[:-1]).format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key
|
||||
)
|
||||
(work_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
target_path = DEFAULT_VIDEO_PATH.format(
|
||||
episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx
|
||||
)
|
||||
video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx)
|
||||
if len(video_dirs) == 1:
|
||||
video_path = video_dirs[0] / video_file
|
||||
else:
|
||||
for dir in video_dirs:
|
||||
if (dir / video_file).is_file():
|
||||
video_path = dir / video_file
|
||||
break
|
||||
|
||||
video_path.rename(work_dir / target_path)
|
||||
|
||||
commit_message = "Move video files into chunk subdirectories"
|
||||
subprocess.run(["git", "add", "."], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None:
|
||||
"""
|
||||
HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case,
|
||||
there's no other option than to download the actual files and reupload them with lfs tracking.
|
||||
"""
|
||||
for i in range(0, len(lfs_untracked_videos), 100):
|
||||
files = lfs_untracked_videos[i : i + 100]
|
||||
try:
|
||||
subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("git rm --cached ERROR:")
|
||||
print(e.stderr)
|
||||
subprocess.run(["git", "add", *files], cwd=work_dir, check=True)
|
||||
|
||||
commit_message = "Track video files with git lfs"
|
||||
subprocess.run(["git", "commit", "-m", commit_message], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None:
|
||||
shutil.copyfile(clean_gittatributes, current_gittatributes)
|
||||
subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True)
|
||||
subprocess.run(["git", "push"], cwd=work_dir, check=True)
|
||||
|
||||
|
||||
def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None:
|
||||
subprocess.run(["git", "lfs", "install"], cwd=work_dir, check=True)
|
||||
repo_url = f"https://huggingface.co/datasets/{repo_id}"
|
||||
env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files
|
||||
subprocess.run(
|
||||
["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)],
|
||||
check=True,
|
||||
env=env,
|
||||
)
|
||||
|
||||
|
||||
def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]:
|
||||
lfs_tracked_files = subprocess.run(
|
||||
["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True
|
||||
)
|
||||
lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines())
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def _get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def _get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**_get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
hub_api = HfApi()
|
||||
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
|
||||
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
for vid_key in video_keys
|
||||
]
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
|
||||
video_shapes = {}
|
||||
for img_key in video_keys:
|
||||
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
|
||||
video_shapes[img_key] = {
|
||||
"width": videos_info[img_key]["video.width"],
|
||||
"height": videos_info[img_key]["video.height"],
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return video_shapes
|
||||
|
||||
|
||||
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
|
||||
image_shapes = {}
|
||||
for img_key in image_keys:
|
||||
image = dataset[0][img_key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
image_shapes[img_key] = {
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return image_shapes
|
||||
|
||||
|
||||
def get_generic_motor_names(sequence_shapes: dict) -> dict:
|
||||
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
local_dir: Path,
|
||||
single_task: str | None = None,
|
||||
tasks_path: Path | None = None,
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
test_branch: str | None = None,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
v20_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/"
|
||||
)
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
branch = test_branch
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
keys = get_keys(dataset)
|
||||
|
||||
if single_task and "language_instruction" in dataset.column_names:
|
||||
warnings.warn(
|
||||
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
|
||||
stacklevel=1,
|
||||
)
|
||||
single_task = None
|
||||
tasks_col = "language_instruction"
|
||||
|
||||
# Episodes & chunks
|
||||
episode_indices = sorted(dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
assert episode_indices == list(range(total_episodes))
|
||||
total_videos = total_episodes * len(keys["video"])
|
||||
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
|
||||
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
|
||||
total_chunks += 1
|
||||
|
||||
# Tasks
|
||||
if single_task:
|
||||
tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices}
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
# tasks = list(set(tasks_by_episodes.values()))
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_col:
|
||||
dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
|
||||
# Shapes
|
||||
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
|
||||
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
|
||||
|
||||
# Videos
|
||||
if len(keys["video"]) > 0:
|
||||
assert metadata_v1.get("video", False)
|
||||
tmp_video_dir = local_dir / "videos" / V20 / repo_id
|
||||
tmp_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
)
|
||||
).absolute()
|
||||
move_videos(
|
||||
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
|
||||
video_shapes = get_video_shapes(videos_info, keys["video"])
|
||||
for img_key in keys["video"]:
|
||||
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
if "encoding" in metadata_v1:
|
||||
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
video_shapes = {}
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
# Names
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
names = robot_config["names"]
|
||||
if "observation.effort" in keys["sequence"]:
|
||||
names["observation.effort"] = names["observation.state"]
|
||||
if "observation.velocity" in keys["sequence"]:
|
||||
names["observation.velocity"] = names["observation.state"]
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
names = get_generic_motor_names(sequence_shapes)
|
||||
repo_tags = None
|
||||
|
||||
assert set(names) == set(keys["sequence"])
|
||||
for key in sequence_shapes:
|
||||
assert len(names[key]) == sequence_shapes[key]
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
for ep_idx in episode_indices
|
||||
]
|
||||
write_jsonlines(episodes, v20_dir / EPISODES_PATH)
|
||||
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
"codebase_version": V20,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": len(dataset),
|
||||
"total_tasks": len(tasks),
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"keys": keys["sequence"],
|
||||
"video_keys": keys["video"],
|
||||
"image_keys": keys["image"],
|
||||
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
|
||||
"names": names,
|
||||
"videos": videos_info,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="data",
|
||||
folder_path=v20_dir / "data",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
hub_api.upload_folder(
|
||||
repo_id=repo_id,
|
||||
path_in_repo="meta",
|
||||
folder_path=v20_dir / "meta",
|
||||
repo_type="dataset",
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if not test_branch:
|
||||
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
||||
|
||||
# TODO:
|
||||
# - [X] Add shapes
|
||||
# - [X] Add keys
|
||||
# - [X] Add paths
|
||||
# - [X] convert stats.json
|
||||
# - [X] Add task.json
|
||||
# - [X] Add names
|
||||
# - [X] Add robot_type
|
||||
# - [X] Add splits
|
||||
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
|
||||
# - [X] Handle multitask datasets
|
||||
# - [X] Handle hf hub repo limits (add chunks logic)
|
||||
# - [X] Add test-branch
|
||||
# - [X] Use jsonlines for episodes
|
||||
# - [X] Add sanity checks (encoding, shapes)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
task_args = parser.add_mutually_exclusive_group(required=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the single task performed in the dataset.",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-col",
|
||||
type=str,
|
||||
help="The name of the column containing language instructions",
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--tasks-path",
|
||||
type=Path,
|
||||
help="The path to a .json file containing one language instruction for each episode_index",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-config",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Path to the robot's config yaml the dataset during conversion.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override the robot config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-branch",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repo branch to test your conversion first (e.g. 'v2.0.test')",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.local_dir:
|
||||
args.local_dir = Path("/tmp/lerobot_dataset_v2")
|
||||
|
||||
robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None
|
||||
del args.robot_config, args.robot_overrides
|
||||
|
||||
convert_dataset(**vars(args), robot_config=robot_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from time import sleep
|
||||
|
||||
sleep(1)
|
||||
main()
|
||||
@@ -27,8 +27,45 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
|
||||
|
||||
def load_from_videos(
|
||||
item: dict[str, torch.Tensor],
|
||||
video_frame_keys: list[str],
|
||||
videos_dir: Path,
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
):
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault.
|
||||
This probably happens because a memory reference to the video loader is created in the main process and a
|
||||
subprocess fails to access it.
|
||||
"""
|
||||
# since video path already contains "videos" (e.g. videos_dir="data/videos", path="videos/episode_0.mp4")
|
||||
data_dir = videos_dir.parent
|
||||
|
||||
for key in video_frame_keys:
|
||||
if isinstance(item[key], list):
|
||||
# load multiple frames at once (expected when delta_timestamps is not None)
|
||||
timestamps = [frame["timestamp"] for frame in item[key]]
|
||||
paths = [frame["path"] for frame in item[key]]
|
||||
if len(set(paths)) > 1:
|
||||
raise NotImplementedError("All video paths are expected to be the same for now.")
|
||||
video_path = data_dir / paths[0]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames
|
||||
else:
|
||||
# load one frame
|
||||
timestamps = [item[key]["timestamp"]]
|
||||
video_path = data_dir / item[key]["path"]
|
||||
|
||||
frames = decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
item[key] = frames[0]
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
video_path: Path | str,
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str = "pyav",
|
||||
@@ -126,8 +163,8 @@ def decode_video_frames_torchvision(
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
|
||||
@@ -67,6 +67,7 @@ class DiffusionConfig:
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
@@ -130,6 +131,7 @@ class DiffusionConfig:
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
@@ -182,8 +182,13 @@ class DiffusionModel(nn.Module):
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
global_cond_dim += encoders[0].feature_dim * num_images
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
@@ -239,16 +244,32 @@ class DiffusionModel(nn.Module):
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||
# Extract image features.
|
||||
if self._use_images:
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
|
||||
@@ -51,6 +51,13 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
|
||||
return TDMPCPolicy, TDMPCConfig
|
||||
|
||||
elif name == "tdmpc2":
|
||||
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
|
||||
from lerobot.common.policies.tdmpc2.modeling_tdmpc2 import TDMPC2Policy
|
||||
|
||||
return TDMPC2Policy, TDMPC2Config
|
||||
|
||||
elif name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
193
lerobot/common/policies/tdmpc2/configuration_tdmpc2.py
Normal file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su,
|
||||
# and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class TDMPC2Config:
|
||||
"""Configuration class for TDMPC2Policy.
|
||||
|
||||
Defaults are configured for training with xarm_lift_medium_replay providing proprioceptive and single
|
||||
camera observations.
|
||||
|
||||
The parameters you will most likely need to change are the ones which depend on the environment / sensors.
|
||||
Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`.
|
||||
|
||||
Args:
|
||||
n_action_repeats: The number of times to repeat the action returned by the planning. (hint: Google
|
||||
action repeats in Q-learning or ask your favorite chatbot)
|
||||
horizon: Horizon for model predictive control.
|
||||
n_action_steps: Number of action steps to take from the plan given by model predictive control. This
|
||||
is an alternative to using action repeats. If this is set to more than 1, then we require
|
||||
`n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this
|
||||
approach of using multiple steps from the plan is not in the original implementation.
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents
|
||||
the input data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96],
|
||||
indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't
|
||||
include batch dimension or temporal dimension.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents
|
||||
the output data name, and the value is a list indicating the dimensions of the corresponding data.
|
||||
For example, "action" refers to an output shape of [14], indicating 14-dimensional actions.
|
||||
Importantly, `output_shapes` doesn't include batch dimension or temporal dimension.
|
||||
input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"),
|
||||
and the value specifies the normalization mode to apply. The two available modes are "mean_std"
|
||||
which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a
|
||||
[-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to
|
||||
match the original implementation.
|
||||
output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the
|
||||
original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping
|
||||
to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max"
|
||||
normalization mode here.
|
||||
image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding.
|
||||
state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding.
|
||||
latent_dim: Observation's latent embedding dimension.
|
||||
q_ensemble_size: Number of Q function estimators to use in an ensemble for uncertainty estimation.
|
||||
mlp_dim: Hidden dimension of MLPs used for modelling the dynamics encoder, reward function, policy
|
||||
(π), Q ensemble, and V.
|
||||
discount: Discount factor (γ) to use for the reinforcement learning formalism.
|
||||
use_mpc: Whether to use model predictive control. The alternative is to just sample the policy model
|
||||
(π) for each step.
|
||||
cem_iterations: Number of iterations for the MPPI/CEM loop in MPC.
|
||||
max_std: Maximum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
min_std: Minimum standard deviation for noise applied to actions sampled from the policy model (π).
|
||||
Doubles up as the minimum standard deviation for actions sampled from the gaussian PDF in CEM.
|
||||
n_gaussian_samples: Number of samples to draw from the gaussian distribution every CEM iteration. Must
|
||||
be non-zero.
|
||||
n_pi_samples: Number of samples to draw from the policy / world model rollout every CEM iteration. Can
|
||||
be zero.
|
||||
n_elites: The number of elite samples to use for updating the gaussian parameters every CEM iteration.
|
||||
elite_weighting_temperature: The temperature to use for softmax weighting (by trajectory value) of the
|
||||
elites, when updating the gaussian parameters for CEM.
|
||||
max_random_shift_ratio: Maximum random shift (as a proportion of the image size) to apply to the
|
||||
image(s) (in units of pixels) for training-time augmentation. If set to 0, no such augmentation
|
||||
is applied. Note that the input images are assumed to be square for this augmentation.
|
||||
reward_coeff: Loss weighting coefficient for the reward regression loss.
|
||||
value_coeff: Loss weighting coefficient for both the state-action value (Q) TD loss, and the state
|
||||
value (V) expectile regression loss.
|
||||
consistency_coeff: Loss weighting coefficient for the consistency loss.
|
||||
temporal_decay_coeff: Exponential decay coefficient for decaying the loss coefficient for future time-
|
||||
steps. Hint: each loss computation involves `horizon` steps worth of actions starting from the
|
||||
current time step.
|
||||
target_model_momentum: Momentum (α) used for EMA updates of the target models. Updates are calculated
|
||||
as ϕ ← αϕ + (1-α)θ where ϕ are the parameters of the target model and θ are the parameters of the
|
||||
model being trained.
|
||||
"""
|
||||
|
||||
# Input / output structure.
|
||||
n_action_repeats: int = 1
|
||||
horizon: int = 3
|
||||
n_action_steps: int = 1
|
||||
|
||||
input_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"observation.image": [3, 84, 84],
|
||||
"observation.state": [4],
|
||||
}
|
||||
)
|
||||
output_shapes: dict[str, list[int]] = field(
|
||||
default_factory=lambda: {
|
||||
"action": [4],
|
||||
}
|
||||
)
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes: dict[str, str] | None = None
|
||||
output_normalization_modes: dict[str, str] = field(
|
||||
default_factory=lambda: {"action": "min_max"},
|
||||
)
|
||||
|
||||
# Architecture / modeling.
|
||||
# Neural networks.
|
||||
image_encoder_hidden_dim: int = 32
|
||||
state_encoder_hidden_dim: int = 256
|
||||
latent_dim: int = 512
|
||||
q_ensemble_size: int = 5
|
||||
num_enc_layers: int = 2
|
||||
mlp_dim: int = 512
|
||||
# Reinforcement learning.
|
||||
discount: float = 0.9
|
||||
simnorm_dim: int = 8
|
||||
dropout: float = 0.01
|
||||
|
||||
# actor
|
||||
log_std_min: float = -10
|
||||
log_std_max: float = 2
|
||||
|
||||
# critic
|
||||
num_bins: int = 101
|
||||
vmin: int = -10
|
||||
vmax: int = +10
|
||||
|
||||
# Inference.
|
||||
use_mpc: bool = True
|
||||
cem_iterations: int = 6
|
||||
max_std: float = 2.0
|
||||
min_std: float = 0.05
|
||||
n_gaussian_samples: int = 512
|
||||
n_pi_samples: int = 24
|
||||
n_elites: int = 64
|
||||
elite_weighting_temperature: float = 0.5
|
||||
|
||||
# Training and loss computation.
|
||||
max_random_shift_ratio: float = 0.0476
|
||||
# Loss coefficients.
|
||||
reward_coeff: float = 0.1
|
||||
value_coeff: float = 0.1
|
||||
consistency_coeff: float = 20.0
|
||||
entropy_coef: float = 1e-4
|
||||
temporal_decay_coeff: float = 0.5
|
||||
# Target model. NOTE (michel_aractingi) this is equivelant to
|
||||
# 1 - target_model_momentum of our TD-MPC1 implementation because
|
||||
# of the use of `torch.lerp`
|
||||
target_model_momentum: float = 0.01
|
||||
|
||||
def __post_init__(self):
|
||||
"""Input validation (not exhaustive)."""
|
||||
# There should only be one image key.
|
||||
image_keys = {k for k in self.input_shapes if k.startswith("observation.image")}
|
||||
if len(image_keys) > 1:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__} handles at most one image for now. Got image keys {image_keys}."
|
||||
)
|
||||
if len(image_keys) > 0:
|
||||
image_key = next(iter(image_keys))
|
||||
if self.input_shapes[image_key][-2] != self.input_shapes[image_key][-1]:
|
||||
# TODO(alexander-soare): This limitation is solely because of code in the random shift
|
||||
# augmentation. It should be able to be removed.
|
||||
raise ValueError(
|
||||
f"Only square images are handled now. Got image shape {self.input_shapes[image_key]}."
|
||||
)
|
||||
if self.n_gaussian_samples <= 0:
|
||||
raise ValueError(
|
||||
f"The number of guassian samples for CEM should be non-zero. Got `{self.n_gaussian_samples=}`"
|
||||
)
|
||||
if self.output_normalization_modes != {"action": "min_max"}:
|
||||
raise ValueError(
|
||||
"TD-MPC assumes the action space dimensions to all be in [-1, 1]. Therefore it is strongly "
|
||||
f"advised that you stick with the default. See {self.__class__.__name__} docstring for more "
|
||||
"information."
|
||||
)
|
||||
if self.n_action_steps > 1:
|
||||
if self.n_action_repeats != 1:
|
||||
raise ValueError(
|
||||
"If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1."
|
||||
)
|
||||
if not self.use_mpc:
|
||||
raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.")
|
||||
if self.n_action_steps > self.horizon:
|
||||
raise ValueError("`n_action_steps` must be less than or equal to `horizon`.")
|
||||
834
lerobot/common/policies/tdmpc2/modeling_tdmpc2.py
Normal file
@@ -0,0 +1,834 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 Nicklas Hansen and The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Implementation of TD-MPC2: Scalable, Robust World Models for Continuous Control
|
||||
|
||||
We refer to the main paper and codebase:
|
||||
TD-MPC2 paper: (https://arxiv.org/abs/2310.16828)
|
||||
TD-MPC2 code: (https://github.com/nicklashansen/tdmpc2)
|
||||
"""
|
||||
|
||||
# ruff: noqa: N806
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.tdmpc2.configuration_tdmpc2 import TDMPC2Config
|
||||
from lerobot.common.policies.tdmpc2.tdmpc2_utils import (
|
||||
NormedLinear,
|
||||
SimNorm,
|
||||
gaussian_logprob,
|
||||
soft_cross_entropy,
|
||||
squash,
|
||||
two_hot_inv,
|
||||
)
|
||||
from lerobot.common.policies.utils import get_device_from_parameters, populate_queues
|
||||
|
||||
|
||||
class TDMPC2Policy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "tdmpc2"],
|
||||
):
|
||||
"""Implementation of TD-MPC2 learning + inference."""
|
||||
|
||||
name = "tdmpc2"
|
||||
|
||||
def __init__(
|
||||
self, config: TDMPC2Config | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: Policy configuration class instance or None, in which case the default instantiation of
|
||||
the configuration class is used.
|
||||
dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected
|
||||
that they will be passed with a call to `load_state_dict` before the policy is used.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = TDMPC2Config()
|
||||
self.config = config
|
||||
self.model = TDMPC2WorldModel(config)
|
||||
# TODO (michel-aractingi) temp fix for gpu
|
||||
self.model = self.model.to("cuda:0")
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
image_keys = [k for k in config.input_shapes if k.startswith("observation.image")]
|
||||
# Note: This check is covered in the post-init of the config but have a sanity check just in case.
|
||||
self._use_image = False
|
||||
self._use_env_state = False
|
||||
if len(image_keys) > 0:
|
||||
assert len(image_keys) == 1
|
||||
self._use_image = True
|
||||
self.input_image_key = image_keys[0]
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
|
||||
self.scale = RunningScale(self.config.target_model_momentum)
|
||||
self.discount = (
|
||||
self.config.discount
|
||||
) # TODO (michel-aractingi) downscale discount according to episode length
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be
|
||||
called on `env.reset()`
|
||||
"""
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
# Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start
|
||||
# CEM for the next step.
|
||||
self._prev_mean: torch.Tensor | None = None
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
|
||||
self._queues = populate_queues(self._queues, batch)
|
||||
|
||||
# When the action queue is depleted, populate it again by querying the policy.
|
||||
if len(self._queues["action"]) == 0:
|
||||
batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch}
|
||||
|
||||
# Remove the time dimensions as it is not handled yet.
|
||||
for key in batch:
|
||||
assert batch[key].shape[1] == 1
|
||||
batch[key] = batch[key][:, 0]
|
||||
|
||||
# NOTE: Order of observations matters here.
|
||||
encode_keys = []
|
||||
if self._use_image:
|
||||
encode_keys.append("observation.image")
|
||||
if self._use_env_state:
|
||||
encode_keys.append("observation.environment_state")
|
||||
encode_keys.append("observation.state")
|
||||
z = self.model.encode({k: batch[k] for k in encode_keys})
|
||||
if self.config.use_mpc: # noqa: SIM108
|
||||
actions = self.plan(z) # (horizon, batch, action_dim)
|
||||
else:
|
||||
# Plan with the policy (π) alone. This always returns one action so unsqueeze to get a
|
||||
# sequence dimension like in the MPC branch.
|
||||
actions = self.model.pi(z)[0].unsqueeze(0)
|
||||
|
||||
actions = torch.clamp(actions, -1, +1)
|
||||
|
||||
actions = self.unnormalize_outputs({"action": actions})["action"]
|
||||
|
||||
if self.config.n_action_repeats > 1:
|
||||
for _ in range(self.config.n_action_repeats):
|
||||
self._queues["action"].append(actions[0])
|
||||
else:
|
||||
# Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action.
|
||||
self._queues["action"].extend(actions[: self.config.n_action_steps])
|
||||
|
||||
action = self._queues["action"].popleft()
|
||||
return action
|
||||
|
||||
@torch.no_grad()
|
||||
def plan(self, z: Tensor) -> Tensor:
|
||||
"""Plan sequence of actions using TD-MPC inference.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim,) tensor for the initial state.
|
||||
Returns:
|
||||
(horizon, batch, action_dim,) tensor for the planned trajectory of actions.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch_size = z.shape[0]
|
||||
|
||||
# Sample Nπ trajectories from the policy.
|
||||
pi_actions = torch.empty(
|
||||
self.config.horizon,
|
||||
self.config.n_pi_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=device,
|
||||
)
|
||||
if self.config.n_pi_samples > 0:
|
||||
_z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples)
|
||||
for t in range(self.config.horizon):
|
||||
# Note: Adding a small amount of noise here doesn't hurt during inference and may even be
|
||||
# helpful for CEM.
|
||||
pi_actions[t] = self.model.pi(_z)[0]
|
||||
_z = self.model.latent_dynamics(_z, pi_actions[t])
|
||||
|
||||
# In the CEM loop we will need this for a call to estimate_value with the gaussian sampled
|
||||
# trajectories.
|
||||
z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples)
|
||||
|
||||
# Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization
|
||||
# algorithm.
|
||||
# The initial mean and standard deviation for the cross-entropy method (CEM).
|
||||
mean = torch.zeros(
|
||||
self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device
|
||||
)
|
||||
# Maybe warm start CEM with the mean from the previous step.
|
||||
if self._prev_mean is not None:
|
||||
mean[:-1] = self._prev_mean[1:]
|
||||
std = self.config.max_std * torch.ones_like(mean)
|
||||
|
||||
for _ in range(self.config.cem_iterations):
|
||||
# Randomly sample action trajectories for the gaussian distribution.
|
||||
std_normal_noise = torch.randn(
|
||||
self.config.horizon,
|
||||
self.config.n_gaussian_samples,
|
||||
batch_size,
|
||||
self.config.output_shapes["action"][0],
|
||||
device=std.device,
|
||||
)
|
||||
gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1)
|
||||
|
||||
# Compute elite actions.
|
||||
actions = torch.cat([gaussian_actions, pi_actions], dim=1)
|
||||
value = self.estimate_value(z, actions).nan_to_num_(0).squeeze()
|
||||
elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch)
|
||||
elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch)
|
||||
# (horizon, n_elites, batch, action_dim)
|
||||
elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1)
|
||||
|
||||
# Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites.
|
||||
max_value = elite_value.max(0, keepdim=True)[0] # (1, batch)
|
||||
# The weighting is a softmax over trajectory values. Note that this is not the same as the usage
|
||||
# of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This
|
||||
# makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²).
|
||||
score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value))
|
||||
score /= score.sum(axis=0, keepdim=True)
|
||||
# (horizon, batch, action_dim)
|
||||
mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) / (
|
||||
einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9
|
||||
)
|
||||
std = torch.sqrt(
|
||||
torch.sum(
|
||||
einops.rearrange(score, "n b -> n b 1")
|
||||
* (elite_actions - einops.rearrange(mean, "h b d -> h 1 b d")) ** 2,
|
||||
dim=1,
|
||||
)
|
||||
/ (einops.rearrange(score.sum(0), "b -> 1 b 1") + 1e-9)
|
||||
).clamp_(self.config.min_std, self.config.max_std)
|
||||
|
||||
# Keep track of the mean for warm-starting subsequent steps.
|
||||
self._prev_mean = mean
|
||||
|
||||
# Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax
|
||||
# scores from the last iteration.
|
||||
actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)]
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def estimate_value(self, z: Tensor, actions: Tensor):
|
||||
"""Estimates the value of a trajectory as per eqn 4 of the FOWM paper.
|
||||
|
||||
Args:
|
||||
z: (batch, latent_dim) tensor of initial latent states.
|
||||
actions: (horizon, batch, action_dim) tensor of action trajectories.
|
||||
Returns:
|
||||
(batch,) tensor of values.
|
||||
"""
|
||||
# Initialize return and running discount factor.
|
||||
G, running_discount = 0, 1
|
||||
# Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics
|
||||
# model. Keep track of return.
|
||||
for t in range(actions.shape[0]):
|
||||
# Estimate the next state (latent) and reward.
|
||||
z, reward = self.model.latent_dynamics_and_reward(z, actions[t], discretize_reward=True)
|
||||
# Update the return and running discount.
|
||||
G += running_discount * reward
|
||||
running_discount *= self.config.discount
|
||||
|
||||
# next_action = self.model.pi(z)[0] # (batch, action_dim)
|
||||
# terminal_values = self.model.Qs(z, next_action, return_type="avg") # (ensemble, batch)
|
||||
|
||||
return G + running_discount * self.model.Qs(z, self.model.pi(z)[0], return_type="avg")
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
device = get_device_from_parameters(self)
|
||||
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self._use_image:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.image"] = batch[self.input_image_key]
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
info = {}
|
||||
|
||||
# (b, t) -> (t, b)
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"] # (t, b, action_dim)
|
||||
reward = batch["next.reward"] # (t, b)
|
||||
observations = {k: v for k, v in batch.items() if k.startswith("observation.")}
|
||||
|
||||
# Apply random image augmentations.
|
||||
if self._use_image and self.config.max_random_shift_ratio > 0:
|
||||
observations["observation.image"] = flatten_forward_unflatten(
|
||||
partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio),
|
||||
observations["observation.image"],
|
||||
)
|
||||
|
||||
# Get the current observation for predicting trajectories, and all future observations for use in
|
||||
# the latent consistency loss and TD loss.
|
||||
current_observation, next_observations = {}, {}
|
||||
for k in observations:
|
||||
current_observation[k] = observations[k][0]
|
||||
next_observations[k] = observations[k][1:]
|
||||
horizon, batch_size = next_observations[
|
||||
"observation.image" if self._use_image else "observation.environment_state"
|
||||
].shape[:2]
|
||||
|
||||
# Run latent rollout using the latent dynamics model and policy model.
|
||||
# Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action
|
||||
# gives us a next `z`.
|
||||
batch_size = batch["index"].shape[0]
|
||||
z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device)
|
||||
z_preds[0] = self.model.encode(current_observation)
|
||||
reward_preds = torch.empty(horizon, batch_size, self.config.num_bins, device=device)
|
||||
for t in range(horizon):
|
||||
z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t])
|
||||
|
||||
# Compute Q value predictions based on the latent rollout.
|
||||
q_preds_ensemble = self.model.Qs(
|
||||
z_preds[:-1], action, return_type="all"
|
||||
) # (ensemble, horizon, batch)
|
||||
info.update({"Q": q_preds_ensemble.mean().item()})
|
||||
|
||||
# Compute various targets with stopgrad.
|
||||
with torch.no_grad():
|
||||
# Latent state consistency targets for consistency loss.
|
||||
z_targets = self.model.encode(next_observations)
|
||||
|
||||
# Compute the TD-target from a reward and the next observation
|
||||
pi = self.model.pi(z_targets)[0]
|
||||
td_targets = (
|
||||
reward
|
||||
+ self.config.discount
|
||||
* self.model.Qs(z_targets, pi, return_type="min", target=True).squeeze()
|
||||
)
|
||||
|
||||
# Compute losses.
|
||||
# Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the
|
||||
# future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch).
|
||||
temporal_loss_coeffs = torch.pow(
|
||||
self.config.temporal_decay_coeff, torch.arange(horizon, device=device)
|
||||
).unsqueeze(-1)
|
||||
|
||||
# Compute consistency loss as MSE loss between latents predicted from the rollout and latents
|
||||
# predicted from the (target model's) observation encoder.
|
||||
consistency_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1)
|
||||
# `z_preds` depends on the current observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# `z_targets` depends on the next observation.
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
# Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset
|
||||
# rewards.
|
||||
reward_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* soft_cross_entropy(reward_preds, reward, self.config).mean(1)
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
ce_value_loss = 0.0
|
||||
for i in range(self.config.q_ensemble_size):
|
||||
ce_value_loss += soft_cross_entropy(q_preds_ensemble[i], td_targets, self.config).mean(1)
|
||||
|
||||
q_value_loss = (
|
||||
(
|
||||
temporal_loss_coeffs
|
||||
* ce_value_loss
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
)
|
||||
.sum(0)
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1.
|
||||
# We won't need these gradients again so detach.
|
||||
z_preds = z_preds.detach()
|
||||
action_preds, _, log_pis, _ = self.model.pi(z_preds[:-1])
|
||||
|
||||
with torch.no_grad():
|
||||
# avoid unnessecary computation of the gradients during policy optimization
|
||||
# TODO (michel-aractingi): the same logic should be extended when adding task embeddings
|
||||
qs = self.model.Qs(z_preds[:-1], action_preds, return_type="avg")
|
||||
self.scale.update(qs[0])
|
||||
qs = self.scale(qs)
|
||||
|
||||
pi_loss = (
|
||||
(self.config.entropy_coef * log_pis - qs).mean(dim=2)
|
||||
* temporal_loss_coeffs
|
||||
# `action_preds` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
loss = (
|
||||
self.config.consistency_coeff * consistency_loss
|
||||
+ self.config.reward_coeff * reward_loss
|
||||
+ self.config.value_coeff * q_value_loss
|
||||
+ pi_loss
|
||||
)
|
||||
|
||||
info.update(
|
||||
{
|
||||
"consistency_loss": consistency_loss.item(),
|
||||
"reward_loss": reward_loss.item(),
|
||||
"Q_value_loss": q_value_loss.item(),
|
||||
"pi_loss": pi_loss.item(),
|
||||
"loss": loss,
|
||||
"sum_loss": loss.item() * self.config.horizon,
|
||||
"pi_scale": float(self.scale.value),
|
||||
}
|
||||
)
|
||||
|
||||
# Undo (b, t) -> (t, b).
|
||||
for key in batch:
|
||||
if batch[key].ndim > 1:
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
return info
|
||||
|
||||
def update(self):
|
||||
"""Update the target model's using polyak averaging."""
|
||||
self.model.update_target_Q()
|
||||
|
||||
|
||||
class TDMPC2WorldModel(nn.Module):
|
||||
"""Latent dynamics model used in TD-MPC2."""
|
||||
|
||||
def __init__(self, config: TDMPC2Config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self._encoder = TDMPC2ObservationEncoder(config)
|
||||
|
||||
# Define latent dynamics head
|
||||
self._dynamics = nn.Sequential(
|
||||
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)),
|
||||
)
|
||||
|
||||
# Define reward head
|
||||
self._reward = nn.Sequential(
|
||||
NormedLinear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
|
||||
)
|
||||
|
||||
# Define policy head
|
||||
self._pi = nn.Sequential(
|
||||
NormedLinear(config.latent_dim, config.mlp_dim),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, 2 * config.output_shapes["action"][0]),
|
||||
)
|
||||
|
||||
# Define ensemble of Q functions
|
||||
self._Qs = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
NormedLinear(
|
||||
config.latent_dim + config.output_shapes["action"][0],
|
||||
config.mlp_dim,
|
||||
dropout=config.dropout,
|
||||
),
|
||||
NormedLinear(config.mlp_dim, config.mlp_dim),
|
||||
nn.Linear(config.mlp_dim, max(config.num_bins, 1)),
|
||||
)
|
||||
for _ in range(config.q_ensemble_size)
|
||||
]
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
self._target_Qs = deepcopy(self._Qs).requires_grad_(False)
|
||||
|
||||
self.log_std_min = torch.tensor(config.log_std_min)
|
||||
self.log_std_dif = torch.tensor(config.log_std_max) - self.log_std_min
|
||||
|
||||
self.bins = torch.linspace(config.vmin, config.vmax, config.num_bins)
|
||||
self.config.bin_size = (config.vmax - config.vmin) / (config.num_bins - 1)
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize model weights.
|
||||
Custom weight initializations proposed in TD-MPC2.
|
||||
|
||||
"""
|
||||
|
||||
def _apply_fn(m):
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.ParameterList):
|
||||
for i, p in enumerate(m):
|
||||
if p.dim() == 3: # Linear
|
||||
nn.init.trunc_normal_(p, std=0.02) # Weight
|
||||
nn.init.constant_(m[i + 1], 0) # Bias
|
||||
|
||||
self.apply(_apply_fn)
|
||||
|
||||
# initialize parameters of the
|
||||
for m in [self._reward, *self._Qs]:
|
||||
assert isinstance(
|
||||
m[-1], nn.Linear
|
||||
), "Sanity check. The last linear layer needs 0 initialization on weights."
|
||||
nn.init.zeros_(m[-1].weight)
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
"""
|
||||
Overriding `to` method to also move additional tensors to device.
|
||||
"""
|
||||
super().to(*args, **kwargs)
|
||||
self.log_std_min = self.log_std_min.to(*args, **kwargs)
|
||||
self.log_std_dif = self.log_std_dif.to(*args, **kwargs)
|
||||
self.bins = self.bins.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
def train(self, mode):
|
||||
super().train(mode)
|
||||
self._target_Qs.train(False)
|
||||
return self
|
||||
|
||||
def encode(self, obs: dict[str, Tensor]) -> Tensor:
|
||||
"""Encodes an observation into its latent representation."""
|
||||
return self._encoder(obs)
|
||||
|
||||
def latent_dynamics_and_reward(
|
||||
self, z: Tensor, a: Tensor, discretize_reward: bool = False
|
||||
) -> tuple[Tensor, Tensor, bool]:
|
||||
"""Predict the next state's latent representation and the reward given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- (*, latent_dim) tensor for the next state's latent representation.
|
||||
- (*,) tensor for the estimated reward.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
reward = self._reward(x).squeeze(-1)
|
||||
if discretize_reward:
|
||||
reward = two_hot_inv(reward, self.bins)
|
||||
return self._dynamics(x), reward
|
||||
|
||||
def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor:
|
||||
"""Predict the next state's latent representation given a current latent and action.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
Returns:
|
||||
(*, latent_dim) tensor for the next state's latent representation.
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
return self._dynamics(x)
|
||||
|
||||
def pi(self, z: Tensor) -> Tensor:
|
||||
"""Samples an action from the learned policy.
|
||||
|
||||
The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when
|
||||
generating rollouts for online training.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
std: The standard deviation of the injected noise.
|
||||
Returns:
|
||||
(*, action_dim) tensor for the sampled action.
|
||||
"""
|
||||
mu, log_std = self._pi(z).chunk(2, dim=-1)
|
||||
log_std = self.log_std_min + 0.5 * self.log_std_dif * (torch.tanh(log_std) + 1)
|
||||
eps = torch.randn_like(mu)
|
||||
|
||||
log_pi = gaussian_logprob(eps, log_std)
|
||||
pi = mu + eps * log_std.exp()
|
||||
mu, pi, log_pi = squash(mu, pi, log_pi)
|
||||
|
||||
return pi, mu, log_pi, log_std
|
||||
|
||||
def Qs(self, z: Tensor, a: Tensor, return_type: str = "min", target=False) -> Tensor: # noqa: N802
|
||||
"""Predict state-action value for all of the learned Q functions.
|
||||
|
||||
Args:
|
||||
z: (*, latent_dim) tensor for the current state's latent representation.
|
||||
a: (*, action_dim) tensor for the action to be applied.
|
||||
return_type: either 'min' or 'all' otherwise the average is returned
|
||||
Returns:
|
||||
(q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble or the average or min
|
||||
"""
|
||||
x = torch.cat([z, a], dim=-1)
|
||||
|
||||
if target:
|
||||
out = torch.stack([q(x).squeeze(-1) for q in self._target_Qs], dim=0)
|
||||
else:
|
||||
out = torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0)
|
||||
|
||||
if return_type == "all":
|
||||
return out
|
||||
|
||||
Q1, Q2 = out[np.random.choice(len(self._Qs), size=2, replace=False)]
|
||||
Q1, Q2 = two_hot_inv(Q1, self.bins), two_hot_inv(Q2, self.bins)
|
||||
return torch.min(Q1, Q2) if return_type == "min" else (Q1 + Q2) / 2
|
||||
|
||||
def update_target_Q(self):
|
||||
"""
|
||||
Soft-update target Q-networks using Polyak averaging.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
for p, p_target in zip(self._Qs.parameters(), self._target_Qs.parameters(), strict=False):
|
||||
p_target.data.lerp_(p.data, self.config.target_model_momentum)
|
||||
|
||||
|
||||
class TDMPC2ObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations."""
|
||||
|
||||
def __init__(self, config: TDMPC2Config):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
TODO(alexander-soare): The original work allows for multiple images by concatenating them along the
|
||||
channel dimension. Re-implement this capability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Define the observation encoder whether its pixels or states
|
||||
encoder_dict = {}
|
||||
for obs_key in config.input_shapes:
|
||||
if "observation.image" in config.input_shapes:
|
||||
encoder_module = nn.Sequential(
|
||||
nn.Conv2d(config.input_shapes[obs_key][0], config.image_encoder_hidden_dim, 7, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=1),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes[obs_key])
|
||||
with torch.inference_mode():
|
||||
out_shape = encoder_module(dummy_batch).shape[1:]
|
||||
encoder_module.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
NormedLinear(np.prod(out_shape), config.latent_dim, act=SimNorm(config.simnorm_dim)),
|
||||
)
|
||||
)
|
||||
|
||||
elif (
|
||||
"observation.state" in config.input_shapes
|
||||
or "observation.environment_state" in config.input_shapes
|
||||
):
|
||||
encoder_module = nn.ModuleList()
|
||||
encoder_module.append(
|
||||
NormedLinear(config.input_shapes[obs_key][0], config.state_encoder_hidden_dim)
|
||||
)
|
||||
assert config.num_enc_layers > 0
|
||||
for _ in range(config.num_enc_layers - 1):
|
||||
encoder_module.append(
|
||||
NormedLinear(config.state_encoder_hidden_dim, config.state_encoder_hidden_dim)
|
||||
)
|
||||
encoder_module.append(
|
||||
NormedLinear(
|
||||
config.state_encoder_hidden_dim, config.latent_dim, act=SimNorm(config.simnorm_dim)
|
||||
)
|
||||
)
|
||||
encoder_module = nn.Sequential(*encoder_module)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"No corresponding encoder module for key {obs_key}.")
|
||||
|
||||
encoder_dict[obs_key.replace(".", "")] = encoder_module
|
||||
|
||||
self.encoder = nn.ModuleDict(encoder_dict)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
for obs_key in self.config.input_shapes:
|
||||
if "observation.image" in obs_key:
|
||||
feat.append(
|
||||
flatten_forward_unflatten(self.encoder[obs_key.replace(".", "")], obs_dict[obs_key])
|
||||
)
|
||||
else:
|
||||
feat.append(self.encoder[obs_key.replace(".", "")](obs_dict[obs_key]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor:
|
||||
"""Randomly shifts images horizontally and vertically.
|
||||
|
||||
Adapted from https://github.com/facebookresearch/drqv2
|
||||
"""
|
||||
b, _, h, w = x.size()
|
||||
assert h == w, "non-square images not handled yet"
|
||||
pad = int(round(max_random_shift_ratio * h))
|
||||
x = F.pad(x, tuple([pad] * 4), "replicate")
|
||||
eps = 1.0 / (h + 2 * pad)
|
||||
arange = torch.linspace(
|
||||
-1.0 + eps,
|
||||
1.0 - eps,
|
||||
h + 2 * pad,
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)[:h]
|
||||
arange = einops.repeat(arange, "w -> h w 1", h=h)
|
||||
base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2)
|
||||
base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b)
|
||||
# A random shift in units of pixels and within the boundaries of the padding.
|
||||
shift = torch.randint(
|
||||
0,
|
||||
2 * pad + 1,
|
||||
size=(b, 1, 1, 2),
|
||||
device=x.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
shift *= 2.0 / (h + 2 * pad)
|
||||
grid = base_grid + shift
|
||||
return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False)
|
||||
|
||||
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally
|
||||
different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
|
||||
class RunningScale:
|
||||
"""Running trimmed scale estimator."""
|
||||
|
||||
def __init__(self, tau):
|
||||
self.tau = tau
|
||||
self._value = torch.ones(1, dtype=torch.float32, device=torch.device("cuda"))
|
||||
self._percentiles = torch.tensor([5, 95], dtype=torch.float32, device=torch.device("cuda"))
|
||||
|
||||
def state_dict(self):
|
||||
return dict(value=self._value, percentiles=self._percentiles)
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self._value.data.copy_(state_dict["value"])
|
||||
self._percentiles.data.copy_(state_dict["percentiles"])
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self._value.cpu().item()
|
||||
|
||||
def _percentile(self, x):
|
||||
x_dtype, x_shape = x.dtype, x.shape
|
||||
x = x.view(x.shape[0], -1)
|
||||
in_sorted, _ = torch.sort(x, dim=0)
|
||||
positions = self._percentiles * (x.shape[0] - 1) / 100
|
||||
floored = torch.floor(positions)
|
||||
ceiled = floored + 1
|
||||
ceiled[ceiled > x.shape[0] - 1] = x.shape[0] - 1
|
||||
weight_ceiled = positions - floored
|
||||
weight_floored = 1.0 - weight_ceiled
|
||||
d0 = in_sorted[floored.long(), :] * weight_floored[:, None]
|
||||
d1 = in_sorted[ceiled.long(), :] * weight_ceiled[:, None]
|
||||
return (d0 + d1).view(-1, *x_shape[1:]).type(x_dtype)
|
||||
|
||||
def update(self, x):
|
||||
percentiles = self._percentile(x.detach())
|
||||
value = torch.clamp(percentiles[1] - percentiles[0], min=1.0)
|
||||
self._value.data.lerp_(value, self.tau)
|
||||
|
||||
def __call__(self, x, update=False):
|
||||
if update:
|
||||
self.update(x)
|
||||
return x * (1 / self.value)
|
||||
|
||||
def __repr__(self):
|
||||
return f"RunningScale(S: {self.value})"
|
||||
164
lerobot/common/policies/tdmpc2/tdmpc2_utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from functorch import combine_state_for_ensemble
|
||||
|
||||
|
||||
class Ensemble(nn.Module):
|
||||
"""
|
||||
Vectorized ensemble of modules.
|
||||
"""
|
||||
|
||||
def __init__(self, modules, **kwargs):
|
||||
super().__init__()
|
||||
modules = nn.ModuleList(modules)
|
||||
fn, params, _ = combine_state_for_ensemble(modules)
|
||||
self.vmap = torch.vmap(fn, in_dims=(0, 0, None), randomness="different", **kwargs)
|
||||
self.params = nn.ParameterList([nn.Parameter(p) for p in params])
|
||||
self._repr = str(modules)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.vmap([p for p in self.params], (), *args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "Vectorized " + self._repr
|
||||
|
||||
|
||||
class SimNorm(nn.Module):
|
||||
"""
|
||||
Simplicial normalization.
|
||||
Adapted from https://arxiv.org/abs/2204.00616.
|
||||
"""
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
shp = x.shape
|
||||
x = x.view(*shp[:-1], -1, self.dim)
|
||||
x = F.softmax(x, dim=-1)
|
||||
return x.view(*shp)
|
||||
|
||||
def __repr__(self):
|
||||
return f"SimNorm(dim={self.dim})"
|
||||
|
||||
|
||||
class NormedLinear(nn.Linear):
|
||||
"""
|
||||
Linear layer with LayerNorm, activation, and optionally dropout.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, dropout=0.0, act=nn.Mish(inplace=True), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ln = nn.LayerNorm(self.out_features)
|
||||
self.act = act
|
||||
self.dropout = nn.Dropout(dropout, inplace=True) if dropout else None
|
||||
|
||||
def forward(self, x):
|
||||
x = super().forward(x)
|
||||
if self.dropout:
|
||||
x = self.dropout(x)
|
||||
return self.act(self.ln(x))
|
||||
|
||||
def __repr__(self):
|
||||
repr_dropout = f", dropout={self.dropout.p}" if self.dropout else ""
|
||||
return (
|
||||
f"NormedLinear(in_features={self.in_features}, "
|
||||
f"out_features={self.out_features}, "
|
||||
f"bias={self.bias is not None}{repr_dropout}, "
|
||||
f"act={self.act.__class__.__name__})"
|
||||
)
|
||||
|
||||
|
||||
def soft_cross_entropy(pred, target, cfg):
|
||||
"""Computes the cross entropy loss between predictions and soft targets."""
|
||||
pred = F.log_softmax(pred, dim=-1)
|
||||
target = two_hot(target, cfg)
|
||||
return -(target * pred).sum(-1, keepdim=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def log_std(x, low, dif):
|
||||
return low + 0.5 * dif * (torch.tanh(x) + 1)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_residual(eps, log_std):
|
||||
return -0.5 * eps.pow(2) - log_std
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _gaussian_logprob(residual):
|
||||
return residual - 0.5 * torch.log(2 * torch.pi)
|
||||
|
||||
|
||||
def gaussian_logprob(eps, log_std, size=None):
|
||||
"""Compute Gaussian log probability."""
|
||||
residual = _gaussian_residual(eps, log_std).sum(-1, keepdim=True)
|
||||
if size is None:
|
||||
size = eps.size(-1)
|
||||
return _gaussian_logprob(residual) * size
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _squash(pi):
|
||||
return torch.log(F.relu(1 - pi.pow(2)) + 1e-6)
|
||||
|
||||
|
||||
def squash(mu, pi, log_pi):
|
||||
"""Apply squashing function."""
|
||||
mu = torch.tanh(mu)
|
||||
pi = torch.tanh(pi)
|
||||
log_pi -= _squash(pi).sum(-1, keepdim=True)
|
||||
return mu, pi, log_pi
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def symlog(x):
|
||||
"""
|
||||
Symmetric logarithmic function.
|
||||
Adapted from https://github.com/danijar/dreamerv3.
|
||||
"""
|
||||
return torch.sign(x) * torch.log(1 + torch.abs(x))
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def symexp(x):
|
||||
"""
|
||||
Symmetric exponential function.
|
||||
Adapted from https://github.com/danijar/dreamerv3.
|
||||
"""
|
||||
return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
|
||||
|
||||
|
||||
def two_hot(x, cfg):
|
||||
"""Converts a batch of scalars to soft two-hot encoded targets for discrete regression."""
|
||||
|
||||
# x shape [horizon, num_features]
|
||||
if cfg.num_bins == 0:
|
||||
return x
|
||||
elif cfg.num_bins == 1:
|
||||
return symlog(x)
|
||||
x = torch.clamp(symlog(x), cfg.vmin, cfg.vmax)
|
||||
bin_idx = torch.floor((x - cfg.vmin) / cfg.bin_size).long() # shape [num_features]
|
||||
bin_offset = ((x - cfg.vmin) / cfg.bin_size - bin_idx.float()).unsqueeze(-1) # shape [num_features , 1]
|
||||
soft_two_hot = torch.zeros(
|
||||
*x.shape, cfg.num_bins, device=x.device
|
||||
) # shape [horizon, num_features, num_bins]
|
||||
soft_two_hot.scatter_(2, bin_idx.unsqueeze(-1), 1 - bin_offset)
|
||||
soft_two_hot.scatter_(2, (bin_idx.unsqueeze(-1) + 1) % cfg.num_bins, bin_offset)
|
||||
return soft_two_hot
|
||||
|
||||
|
||||
def two_hot_inv(x, bins):
|
||||
"""Converts a batch of soft two-hot encoded vectors to scalars."""
|
||||
num_bins = bins.shape[0]
|
||||
if num_bins == 0:
|
||||
return x
|
||||
elif num_bins == 1:
|
||||
return symexp(x)
|
||||
|
||||
x = F.softmax(x, dim=-1)
|
||||
x = torch.sum(x * bins, dim=-1, keepdim=True)
|
||||
return symexp(x)
|
||||
@@ -168,7 +168,6 @@ class IntelRealSenseCameraConfig:
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
use_depth: bool = False
|
||||
force_hardware_reset: bool = True
|
||||
rotation: int | None = None
|
||||
@@ -180,8 +179,6 @@ class IntelRealSenseCameraConfig:
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None
|
||||
at_least_one_is_none = self.fps is None or self.width is None or self.height is None
|
||||
if at_least_one_is_not_none and at_least_one_is_none:
|
||||
@@ -203,8 +200,7 @@ class IntelRealSenseCamera:
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py \
|
||||
--images-dir outputs/images_from_intelrealsense_cameras
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
|
||||
```
|
||||
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
@@ -258,7 +254,6 @@ class IntelRealSenseCamera:
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.use_depth = config.use_depth
|
||||
self.force_hardware_reset = config.force_hardware_reset
|
||||
|
||||
@@ -192,7 +192,6 @@ class OpenCVCameraConfig:
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
color_mode: str = "rgb"
|
||||
channels: int | None = None
|
||||
rotation: int | None = None
|
||||
mock: bool = False
|
||||
|
||||
@@ -202,8 +201,6 @@ class OpenCVCameraConfig:
|
||||
f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided."
|
||||
)
|
||||
|
||||
self.channels = 3
|
||||
|
||||
if self.rotation not in [-90, None, 90, 180]:
|
||||
raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})")
|
||||
|
||||
@@ -219,8 +216,7 @@ class OpenCVCamera:
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py \
|
||||
--images-dir outputs/images_from_opencv_cameras
|
||||
python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
@@ -272,7 +268,6 @@ class OpenCVCamera:
|
||||
self.fps = config.fps
|
||||
self.width = config.width
|
||||
self.height = config.height
|
||||
self.channels = config.channels
|
||||
self.color_mode = config.color_mode
|
||||
self.mock = config.mock
|
||||
|
||||
@@ -328,7 +323,7 @@ class OpenCVCamera:
|
||||
if self.camera_index not in available_cam_ids:
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
|
||||
"To find the camera index you should use, run `python lerobot/lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
|
||||
|
||||
@@ -15,8 +15,7 @@ import torch
|
||||
import tqdm
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, safe_stop_image_writer
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
@@ -228,7 +227,7 @@ def control_loop(
|
||||
control_time_s=None,
|
||||
teleoperate=False,
|
||||
display_cameras=False,
|
||||
dataset: LeRobotDataset | None = None,
|
||||
dataset=None,
|
||||
events=None,
|
||||
policy=None,
|
||||
device=None,
|
||||
@@ -248,7 +247,7 @@ def control_loop(
|
||||
if teleoperate and policy is not None:
|
||||
raise ValueError("When `teleoperate` is True, `policy` should be None.")
|
||||
|
||||
if dataset is not None and fps is not None and dataset.fps != fps:
|
||||
if dataset is not None and fps is not None and dataset["fps"] != fps:
|
||||
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
|
||||
|
||||
timestamp = 0
|
||||
@@ -269,8 +268,7 @@ def control_loop(
|
||||
action = {"action": action}
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
dataset.add_frame(frame)
|
||||
add_frame(dataset, observation, action)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
image_keys = [key for key in observation if "image" in key]
|
||||
|
||||
@@ -298,16 +298,6 @@ class FeetechMotorsBus:
|
||||
self.logs = {}
|
||||
|
||||
self.track_positions = {}
|
||||
self.present_pos = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
"below_zero": [None] * len(self.motor_names),
|
||||
"above_max": [None] * len(self.motor_names),
|
||||
}
|
||||
self.goal_pos = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
"below_zero": [None] * len(self.motor_names),
|
||||
"above_max": [None] * len(self.motor_names),
|
||||
}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
|
||||
@@ -64,7 +64,7 @@ def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=N
|
||||
# print(f"{present_voltage=}")
|
||||
# print(f"{present_temperature=}")
|
||||
|
||||
if present_speed == 0 and present_current > 50:
|
||||
if present_speed == 0 and present_current > 40:
|
||||
count += 1
|
||||
if count > 100 or present_current > 300:
|
||||
return present_pos
|
||||
@@ -306,16 +306,16 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
calib = {}
|
||||
|
||||
print("Calibrate shoulder_pan")
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan", load_threshold=350, count_threshold=200)
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate gripper")
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True, count_threshold=200)
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_flex")
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True, count_threshold=200)
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)
|
||||
|
||||
wr_pos = arm.read("Present_Position", "wrist_roll")
|
||||
@@ -329,7 +329,7 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True, count_threshold=200)
|
||||
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
|
||||
calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
|
||||
@@ -348,7 +348,6 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
arm,
|
||||
"elbow_flex",
|
||||
invert_drive_mode=True,
|
||||
count_threshold=200,
|
||||
in_between_move_hook=in_between_move_elbow_flex_hook,
|
||||
)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
|
||||
@@ -226,13 +226,6 @@ class ManipulatorRobot:
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
|
||||
state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
|
||||
self.names = {
|
||||
"action": action_names,
|
||||
"observation.state": state_names,
|
||||
}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
@@ -271,7 +264,7 @@ class ManipulatorRobot:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
if self.robot_type in ["koch", "aloha"]:
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
@@ -286,7 +279,7 @@ class ManipulatorRobot:
|
||||
self.activate_calibration()
|
||||
|
||||
# Set robot preset (e.g. torque in leader gripper for Koch v1.1)
|
||||
if self.robot_type == "koch":
|
||||
if self.robot_type in ["koch", "koch_bimanual"]:
|
||||
self.set_koch_robot_preset()
|
||||
elif self.robot_type == "aloha":
|
||||
self.set_aloha_robot_preset()
|
||||
@@ -299,7 +292,7 @@ class ManipulatorRobot:
|
||||
self.follower_arms[name].write("Torque_Enable", 1)
|
||||
|
||||
if self.config.gripper_open_degree is not None:
|
||||
if self.robot_type in ["aloha", "so100", "moss"]:
|
||||
if self.robot_type not in ["koch", "koch_bimanual"]:
|
||||
raise NotImplementedError(
|
||||
f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open."
|
||||
)
|
||||
@@ -335,26 +328,20 @@ class ManipulatorRobot:
|
||||
with open(arm_calib_path) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "aloha"]:
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_auto_calibration,
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
# TODO(rcadene): better way to handle mocking + test run_arm_auto_calibration
|
||||
if arm_type == "leader" or arm.mock:
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
elif arm_type == "follower":
|
||||
calibration = run_arm_auto_calibration(arm, self.robot_type, name, arm_type)
|
||||
else:
|
||||
raise ValueError(arm_type)
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
10
lerobot/configs/env/moss_real.yaml
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
10
lerobot/configs/env/so100_real.yaml
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
102
lerobot/configs/policy/act_moss_real.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/moss_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
102
lerobot/configs/policy/act_so100_real.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/so100_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,5 +1,5 @@
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: koch
|
||||
robot_type: koch_bimanual
|
||||
calibration_dir: .cache/calibration/koch_bimanual
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
|
||||
@@ -106,6 +106,12 @@ from typing import List
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.populate_dataset import (
|
||||
create_lerobot_dataset,
|
||||
delete_current_episode,
|
||||
init_dataset,
|
||||
save_current_episode,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
control_loop,
|
||||
has_method,
|
||||
@@ -192,24 +198,23 @@ def record(
|
||||
robot: Robot,
|
||||
root: str,
|
||||
repo_id: str,
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
reset_time_s: int | float = 5,
|
||||
num_episodes: int = 50,
|
||||
video: bool = True,
|
||||
run_compute_stats: bool = True,
|
||||
push_to_hub: bool = True,
|
||||
num_image_writer_processes: int = 0,
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
tags: str = None,
|
||||
force_override: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
warmup_time_s=2,
|
||||
episode_time_s=10,
|
||||
reset_time_s=5,
|
||||
num_episodes=50,
|
||||
video=True,
|
||||
run_compute_stats=True,
|
||||
push_to_hub=True,
|
||||
tags=None,
|
||||
num_image_writer_processes=0,
|
||||
num_image_writer_threads_per_camera=4,
|
||||
force_override=False,
|
||||
display_cameras=True,
|
||||
play_sounds=True,
|
||||
):
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
events = None
|
||||
@@ -217,11 +222,6 @@ def record(
|
||||
device = None
|
||||
use_amp = None
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
else:
|
||||
raise NotImplementedError("Only single-task recording is supported for now")
|
||||
|
||||
# Load pretrained policy
|
||||
if pretrained_policy_name_or_path is not None:
|
||||
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
|
||||
@@ -236,14 +236,15 @@ def record(
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
dataset = init_dataset(
|
||||
repo_id,
|
||||
root,
|
||||
force_override,
|
||||
fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
|
||||
use_videos=video,
|
||||
video,
|
||||
write_images=robot.has_camera,
|
||||
num_image_writer_processes=num_image_writer_processes,
|
||||
num_image_writer_threads=num_image_writer_threads_per_camera * robot.num_cameras,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
@@ -262,17 +263,11 @@ def record(
|
||||
if has_method(robot, "teleop_safety_stop"):
|
||||
robot.teleop_safety_stop()
|
||||
|
||||
recorded_episodes = 0
|
||||
while True:
|
||||
if recorded_episodes >= num_episodes:
|
||||
if dataset["num_episodes"] >= num_episodes:
|
||||
break
|
||||
|
||||
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
|
||||
# input() messes with them.
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
episode_index = dataset.episode_buffer["episode_index"]
|
||||
episode_index = dataset["num_episodes"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
@@ -300,11 +295,11 @@ def record(
|
||||
log_say("Re-record episode", play_sounds)
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
delete_current_episode(dataset)
|
||||
continue
|
||||
|
||||
dataset.add_episode(task)
|
||||
recorded_episodes += 1
|
||||
# Increment by one dataset["current_episode_index"]
|
||||
save_current_episode(dataset)
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
@@ -312,47 +307,35 @@ def record(
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
if dataset.image_writer is not None:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
|
||||
log_say("Exiting", play_sounds)
|
||||
return dataset
|
||||
return lerobot_dataset
|
||||
|
||||
|
||||
@safe_disconnect
|
||||
def replay(
|
||||
robot: Robot,
|
||||
root: Path,
|
||||
repo_id: str,
|
||||
episode: int,
|
||||
fps: int | None = None,
|
||||
play_sounds: bool = True,
|
||||
local_files_only: bool = True,
|
||||
robot: Robot, episode: int, fps: int | None = None, root="data", repo_id="lerobot/debug", play_sounds=True
|
||||
):
|
||||
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
|
||||
# TODO(rcadene): Add option to record logs
|
||||
local_dir = Path(root) / repo_id
|
||||
if not local_dir.exists():
|
||||
raise ValueError(local_dir)
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
|
||||
actions = dataset.hf_dataset.select_columns("action")
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
items = dataset.hf_dataset.select_columns("action")
|
||||
from_idx = dataset.episode_data_index["from"][episode].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode].item()
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_samples):
|
||||
for idx in range(from_idx, to_idx):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
action = items[idx]["action"]
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
@@ -401,21 +384,9 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
parser_record = subparsers.add_parser("record", parents=[base_parser])
|
||||
task_args = parser_record.add_mutually_exclusive_group(required=True)
|
||||
parser_record.add_argument(
|
||||
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
|
||||
)
|
||||
task_args.add_argument(
|
||||
"--single-task",
|
||||
type=str,
|
||||
help="A short but accurate description of the task performed during the recording.",
|
||||
)
|
||||
# TODO(aliberts): add multi-task support
|
||||
# task_args.add_argument(
|
||||
# "--multi-task",
|
||||
# type=int,
|
||||
# help="You will need to enter the task performed at the start of each episode.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--root",
|
||||
type=Path,
|
||||
|
||||
@@ -260,7 +260,7 @@ def push_dataset_to_hub(
|
||||
episode_index = 0
|
||||
tests_videos_dir = tests_data_dir / repo_id / "videos"
|
||||
tests_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||
for key in lerobot_dataset.camera_keys:
|
||||
for key in lerobot_dataset.video_frame_keys:
|
||||
fname = f"{key}_episode_{episode_index:06d}.mp4"
|
||||
shutil.copy(videos_dir / fname, tests_videos_dir / fname)
|
||||
|
||||
|
||||
@@ -93,6 +93,18 @@ def make_optimizer_and_scheduler(cfg, policy):
|
||||
elif policy.name == "tdmpc":
|
||||
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif policy.name == "tdmpc2":
|
||||
params_group = [
|
||||
{"params": policy.model._encoder.parameters(), "lr": cfg.training.lr * cfg.training.enc_lr_scale},
|
||||
{"params": policy.model._dynamics.parameters()},
|
||||
{"params": policy.model._reward.parameters()},
|
||||
{"params": policy.model._Qs.parameters()},
|
||||
{"params": policy.model._pi.parameters(), "eps": 1e-5},
|
||||
]
|
||||
optimizer = torch.optim.Adam(params_group, lr=cfg.training.lr)
|
||||
lr_scheduler = None
|
||||
|
||||
elif cfg.policy.name == "vqbet":
|
||||
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
|
||||
|
||||
|
||||
@@ -97,13 +97,14 @@ def run_server(
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
|
||||
tasks = dataset.episode_dicts[episode_id]["tasks"]
|
||||
video_paths = get_episode_video_paths(dataset, episode_id)
|
||||
language_instruction = get_episode_language_instruction(dataset, episode_id)
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
{"url": url_for("static", filename=video_path), "filename": Path(video_path).name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
if language_instruction:
|
||||
videos_info[0]["language_instruction"] = language_instruction
|
||||
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
@@ -136,10 +137,10 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = dataset.shapes["observation.state"]
|
||||
dim_state = len(dataset.hf_dataset["observation.state"][0])
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = dataset.shapes["action"]
|
||||
dim_action = len(dataset.hf_dataset["action"][0])
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
@@ -170,7 +171,8 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.video_frame_keys
|
||||
]
|
||||
|
||||
|
||||
@@ -202,8 +204,8 @@ def visualize_dataset_html(
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if len(dataset.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
|
||||
if not dataset.video:
|
||||
raise NotImplementedError(f"Image datasets ({dataset.video=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
@@ -223,7 +225,7 @@ def visualize_dataset_html(
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
ln_videos_dir.symlink_to(dataset.videos_dir.resolve())
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
|
||||
BIN
media/gym/aloha_act.gif
Normal file
|
After Width: | Height: | Size: 2.9 MiB |
BIN
media/gym/pusht_diffusion.gif
Normal file
|
After Width: | Height: | Size: 185 KiB |
BIN
media/gym/simxarm_tdmpc.gif
Normal file
|
After Width: | Height: | Size: 464 KiB |
BIN
media/moss/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 153 KiB |
BIN
media/moss/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 208 KiB |
BIN
media/moss/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 296 KiB |
BIN
media/so100/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 145 KiB |
BIN
media/so100/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 95 KiB |
BIN
media/so100/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
media/so100/leader_follower.webp
Normal file
|
After Width: | Height: | Size: 117 KiB |
1481
poetry.lock
generated
@@ -43,8 +43,9 @@ opencv-python = ">=4.9.0"
|
||||
diffusers = ">=0.27.2"
|
||||
torchvision = ">=0.17.1"
|
||||
h5py = ">=3.10.0"
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.2"}
|
||||
gymnasium = "==0.29.1" # TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
huggingface-hub = {extras = ["hf-transfer", "cli"], version = ">=0.25.0"}
|
||||
# TODO(rcadene, aliberts): Make gym 1.0.0 work
|
||||
gymnasium = "==0.29.1"
|
||||
cmake = ">=3.29.0.1"
|
||||
gym-dora = { git = "https://github.com/dora-rs/dora-lerobot.git", subdirectory = "gym_dora", optional = true }
|
||||
gym-pusht = { version = ">=0.1.5", optional = true}
|
||||
@@ -70,7 +71,6 @@ pyrealsense2 = {version = ">=2.55.1.6486", markers = "sys_platform != 'darwin'",
|
||||
pyrender = {git = "https://github.com/mmatl/pyrender.git", markers = "sys_platform == 'linux'", optional = true}
|
||||
hello-robot-stretch-body = {version = ">=0.7.27", markers = "sys_platform == 'linux'", optional = true}
|
||||
pyserial = {version = ">=3.5", optional = true}
|
||||
jsonlines = ">=4.0.0"
|
||||
|
||||
|
||||
[tool.poetry.extras]
|
||||
|
||||
@@ -29,6 +29,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -92,9 +93,8 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides.append(f"calibration_dir={calibration_dir}")
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
record(
|
||||
@@ -102,7 +102,6 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
fps=30,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
@@ -133,18 +132,17 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
root = tmpdir / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = tmpdir / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=5,
|
||||
warmup_time_s=0.5,
|
||||
fps=1,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
num_episodes=2,
|
||||
@@ -155,18 +153,24 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.total_episodes == 2
|
||||
assert len(dataset) == 10
|
||||
assert dataset.num_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=5, root=root, repo_id=repo_id, play_sounds=False)
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
env_name = "aloha_real"
|
||||
policy_name = "act_aloha_real"
|
||||
elif robot_type in ["koch", "koch_bimanual", "so100", "moss"]:
|
||||
elif robot_type in ["koch", "koch_bimanual"]:
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
elif robot_type == "so100":
|
||||
env_name = "so100_real"
|
||||
policy_name = "act_so100_real"
|
||||
elif robot_type == "moss":
|
||||
env_name = "moss_real"
|
||||
policy_name = "act_moss_real"
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
@@ -221,14 +225,10 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
else:
|
||||
num_image_writer_processes = 0
|
||||
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
eval_root = tmpdir / "data" / eval_repo_id
|
||||
|
||||
dataset = record(
|
||||
record(
|
||||
robot,
|
||||
eval_root,
|
||||
root,
|
||||
eval_repo_id,
|
||||
single_task,
|
||||
pretrained_policy_name_or_path,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
@@ -265,15 +265,13 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -286,33 +284,32 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
)
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
# init_dataset_return_value = {}
|
||||
init_dataset_return_value = {}
|
||||
|
||||
# def wrapped_init_dataset(*args, **kwargs):
|
||||
# nonlocal init_dataset_return_value
|
||||
# init_dataset_return_value = init_dataset(*args, **kwargs)
|
||||
# return init_dataset_return_value
|
||||
def wrapped_init_dataset(*args, **kwargs):
|
||||
nonlocal init_dataset_return_value
|
||||
init_dataset_return_value = init_dataset(*args, **kwargs)
|
||||
return init_dataset_return_value
|
||||
|
||||
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||
# assert (
|
||||
# init_dataset_return_value["num_episodes"] == 2
|
||||
# ), "`init_dataset` should load the previous episode"
|
||||
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||
assert (
|
||||
init_dataset_return_value["num_episodes"] == 2
|
||||
), "`init_dataset` should load the previous episode"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@@ -331,22 +328,23 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = True
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -360,6 +358,7 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@@ -379,22 +378,23 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=2,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -407,6 +407,7 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@@ -428,22 +429,23 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = True
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -457,4 +459,5 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
@@ -42,28 +42,7 @@ from lerobot.common.datasets.utils import (
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||
|
||||
# TODO(aliberts): create proper test repo
|
||||
TEST_REPO_ID = "aliberts/koch_tutorial"
|
||||
|
||||
|
||||
def test_same_attributes_defined():
|
||||
# TODO(aliberts): test with keys, shapes, names etc. provided instead of robot
|
||||
robot = make_robot("koch", mock=True)
|
||||
|
||||
# Instantiate both ways
|
||||
dataset_init = LeRobotDataset(repo_id=TEST_REPO_ID)
|
||||
dataset_create = LeRobotDataset.create(repo_id=TEST_REPO_ID, fps=30, robot=robot)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init._hub_version
|
||||
_ = dataset_create._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
@@ -39,7 +39,6 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): add compatibility with other robots
|
||||
|
||||
robot_kwargs = {"robot_type": robot_type}
|
||||
|
||||
if robot_type == "aloha" and mock:
|
||||
|
||||