Compare commits
74 Commits
user/rcade
...
user/miche
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
16b905ee67 | ||
|
|
44d96a0811 | ||
|
|
4488e55e94 | ||
|
|
2945dca65b | ||
|
|
691d39aa32 | ||
|
|
56c01a2944 | ||
|
|
49bdcc094c | ||
|
|
96c7052777 | ||
|
|
6ad84a6561 | ||
|
|
3b5af7eb38 | ||
|
|
23f6c875b5 | ||
|
|
975c1c25c3 | ||
|
|
20f466768e | ||
|
|
8af693548e | ||
|
|
f56d769dfb | ||
|
|
36b9b60a0e | ||
|
|
93d9bf83c2 | ||
|
|
37da50b573 | ||
|
|
c6ad495176 | ||
|
|
f43e5d07f5 | ||
|
|
9ee8711504 | ||
|
|
6203641710 | ||
|
|
1f13bda25b | ||
|
|
acae4b49d2 | ||
|
|
c72ad49c9b | ||
|
|
eda02fade5 | ||
|
|
8546358bc5 | ||
|
|
a91b7c6163 | ||
|
|
4d15861872 | ||
|
|
f3630ad910 | ||
|
|
963738d983 | ||
|
|
aed9f4036a | ||
|
|
757ea175d3 | ||
|
|
b69a132737 | ||
|
|
7b159a6b22 | ||
|
|
74270c8c91 | ||
|
|
a6762ec316 | ||
|
|
fde29e0167 | ||
|
|
56e4603d5b | ||
|
|
f6c90ca35f | ||
|
|
c2d6fb6119 | ||
|
|
95a4b59b5b | ||
|
|
16103cb8b7 | ||
|
|
e4ba084e25 | ||
|
|
ac79e8cb36 | ||
|
|
df2cb51364 | ||
|
|
7a342db9c4 | ||
|
|
6b2ec1ed77 | ||
|
|
375abd3020 | ||
|
|
293bdc7f67 | ||
|
|
79d114cc1f | ||
|
|
2650872b76 | ||
|
|
cd1509d805 | ||
|
|
5ea7c78237 | ||
|
|
443a9eec88 | ||
|
|
ab23a4fd27 | ||
|
|
1267c3e955 | ||
|
|
e0df56de62 | ||
|
|
c70b8d0abc | ||
|
|
e1845d4dcc | ||
|
|
e69f0c5059 | ||
|
|
ff84024ee9 | ||
|
|
538455a965 | ||
|
|
ee51f54cb5 | ||
|
|
fee5fa5c2e | ||
|
|
8d57093ce7 | ||
|
|
4c22de20a6 | ||
|
|
51e87f6f97 | ||
|
|
172809a502 | ||
|
|
55e4ff6742 | ||
|
|
df3d2ec5df | ||
|
|
e210d795de | ||
|
|
07e8716315 | ||
|
|
18ffa4248b |
2
.github/workflows/test.yml
vendored
@@ -103,7 +103,7 @@ jobs:
|
||||
-W ignore::UserWarning:gymnasium.utils.env_checker:247 \
|
||||
&& rm -rf tests/outputs outputs
|
||||
|
||||
|
||||
# TODO(aliberts, rcadene): redesign after v2 migration / removing hydra
|
||||
end-to-end:
|
||||
name: End-to-end
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
18
README.md
@@ -23,15 +23,15 @@
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">Hot new tutorial: Getting started with real-world robots</a></p>
|
||||
<p><a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">New robot in town: SO-100</a></p>
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<img src="media/tutorial/koch_v1_1_leader_follower.webp?raw=true" alt="Koch v1.1 leader and follower arms" title="Koch v1.1 leader and follower arms" width="50%">
|
||||
<p>We just dropped an in-depth tutorial on how to build your own robot!</p>
|
||||
<img src="media/so100/leader_follower.webp?raw=true" alt="SO-100 leader and follower arms" title="SO-100 leader and follower arms" width="50%">
|
||||
<p>We just added a new tutorial on how to build a more affordable robot, at the price of $110 per arm!</p>
|
||||
<p>Teach it new skills by showing it a few moves with just a laptop.</p>
|
||||
<p>Then watch your homemade robot act autonomously 🤯</p>
|
||||
<p>For more info, see <a href="https://x.com/RemiCadene/status/1825455895561859185">our thread on X</a> or <a href="https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md">our tutorial page</a>.</p>
|
||||
<p>Follow the link to the <a href="https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md">full tutorial for SO-100</a>.</p>
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
@@ -55,9 +55,9 @@
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="http://remicadene.com/assets/gif/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="http://remicadene.com/assets/gif/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
<td><img src="media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
@@ -144,7 +144,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically download data from the Hugging Face hub.
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
```bash
|
||||
@@ -280,7 +280,7 @@ To use wandb for logging training and evaluation curves, make sure you've run `w
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explaination of some commonly used metrics in logs.
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ def benchmark_encoding_decoding(
|
||||
)
|
||||
|
||||
ep_num_images = dataset.episode_data_index["to"][0].item()
|
||||
width, height = tuple(dataset[0][dataset.camera_keys[0]].shape[-2:])
|
||||
width, height = tuple(dataset[0][dataset.meta.camera_keys[0]].shape[-2:])
|
||||
num_pixels = width * height
|
||||
video_size_bytes = video_path.stat().st_size
|
||||
images_size_bytes = get_directory_size(imgs_dir)
|
||||
|
||||
@@ -4,7 +4,7 @@ This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's install LeRobot. We will next provide a tutorial for assembly.
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
@@ -45,12 +45,46 @@ conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Run this script two times to find the ports (e.g. "/dev/tty.usbmodem58760432961") of your motor buses:
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Then plug your first motor, corresponding to "shoulder_pan" and run this script to set its ID to 1 and set its present position and offset to ~2048 (useful for calibration).
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
@@ -60,7 +94,9 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Then unplug your motor and plug the second motor, corresponding to "shoulder lift", and set its ID to 2.
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
@@ -70,23 +106,57 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until the gripper with ID 6. Do the same for the motors of the leader arm, starting for ID 1 up to 6.
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
TODO
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/follower_zero.webp?raw=true" alt="SO-100 follower arm zero position" title="SO-100 follower arm zero position" style="width:100%;"> | <img src="../media/so100/follower_rotated.webp?raw=true" alt="SO-100 follower arm rotated position" title="SO-100 follower arm rotated position" style="width:100%;"> | <img src="../media/so100/follower_rest.webp?raw=true" alt="SO-100 follower arm rest position" title="SO-100 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras'
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/so100/leader_zero.webp?raw=true" alt="SO-100 leader arm zero position" title="SO-100 leader arm zero position" style="width:100%;"> | <img src="../media/so100/leader_rotated.webp?raw=true" alt="SO-100 leader arm rotated position" title="SO-100 leader arm rotated position" style="width:100%;"> | <img src="../media/so100/leader_rest.webp?raw=true" alt="SO-100 leader arm rest position" title="SO-100 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
Without displaying the cameras:
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml \
|
||||
@@ -94,7 +164,9 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
With displaying the cameras:
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
@@ -102,7 +174,7 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with so100.
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
@@ -203,6 +275,6 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explaination.
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel `#so100-arm`.
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039).
|
||||
|
||||
280
examples/11_use_moss.md
Normal file
@@ -0,0 +1,280 @@
|
||||
This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-robot-arms) with LeRobot.
|
||||
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10 && conda activate lerobot
|
||||
```
|
||||
|
||||
4. Clone LeRobot:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
```
|
||||
|
||||
5. Install LeRobot with dependencies for the feetech motors:
|
||||
```bash
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## Configure the motors
|
||||
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below.
|
||||
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect leader arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751']
|
||||
Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
|
||||
[...Disconnect follower arm and press Enter...]
|
||||
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
--port /dev/tty.usbmodem58760432961 \
|
||||
--brand feetech \
|
||||
--model sts3215 \
|
||||
--baudrate 1000000 \
|
||||
--ID 2
|
||||
```
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). For Moss v1, you need to align the holes on the motor horn to the motor spline to be approximately 3, 6, 9 and 12 o'clock.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Moss v1 robot to work on another.
|
||||
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/follower_zero.webp?raw=true" alt="Moss v1 follower arm zero position" title="Moss v1 follower arm zero position" style="width:100%;"> | <img src="../media/moss/follower_rotated.webp?raw=true" alt="Moss v1 follower arm rotated position" title="Moss v1 follower arm rotated position" style="width:100%;"> | <img src="../media/moss/follower_rest.webp?raw=true" alt="Moss v1 follower arm rest position" title="Moss v1 follower arm rest position" style="width:100%;"> |
|
||||
|
||||
Make sure both arms are connected and run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
| <img src="../media/moss/leader_zero.webp?raw=true" alt="Moss v1 leader arm zero position" title="Moss v1 leader arm zero position" style="width:100%;"> | <img src="../media/moss/leader_rotated.webp?raw=true" alt="Moss v1 leader arm rotated position" title="Moss v1 leader arm rotated position" style="width:100%;"> | <img src="../media/moss/leader_rest.webp?raw=true" alt="Moss v1 leader arm rest position" title="Moss v1 leader arm rest position" style="width:100%;"> |
|
||||
|
||||
Run this script to launch manual calibration:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--robot-overrides '~cameras' \
|
||||
--display-cameras 0
|
||||
```
|
||||
|
||||
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/moss.yaml
|
||||
```
|
||||
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with Moss v1.
|
||||
|
||||
If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face repository name in a variable to run these commands:
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
Record 2 episodes and upload your dataset to the hub:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--tags moss tutorial \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 2 \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
echo ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
If you didn't upload with `--push-to-hub 0`, you can also visualize it locally with:
|
||||
```bash
|
||||
python lerobot/scripts/visualize_dataset_html.py \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test
|
||||
```
|
||||
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/control_robot.py replay \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/moss_test \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
DATA_DIR=data python lerobot/scripts/train.py \
|
||||
dataset_repo_id=${HF_USER}/moss_test \
|
||||
policy=act_moss_real \
|
||||
env=moss_real \
|
||||
hydra.run.dir=outputs/train/act_moss_test \
|
||||
hydra.job.name=act_moss_test \
|
||||
device=cuda \
|
||||
wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain it:
|
||||
1. We provided the dataset as argument with `dataset_repo_id=${HF_USER}/moss_test`.
|
||||
2. We provided the policy with `policy=act_moss_real`. This loads configurations from [`lerobot/configs/policy/act_moss_real.yaml`](../lerobot/configs/policy/act_moss_real.yaml). Importantly, this policy uses 2 cameras as input `laptop`, `phone`.
|
||||
3. We provided an environment as argument with `env=moss_real`. This loads configurations from [`lerobot/configs/env/moss_real.yaml`](../lerobot/configs/env/moss_real.yaml).
|
||||
4. We provided `device=cuda` since we are training on a Nvidia GPU, but you can also use `device=mps` if you are using a Mac with Apple silicon, or `device=cpu` otherwise.
|
||||
5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
6. We added `DATA_DIR=data` to access your dataset stored in your local `data` directory. If you dont provide `DATA_DIR`, your dataset will be downloaded from Hugging Face hub to your cache folder `$HOME/.cache/hugginface`. In future versions of `lerobot`, both directories will be in sync.
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`.
|
||||
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py record \
|
||||
--robot-path lerobot/configs/robot/moss.yaml \
|
||||
--fps 30 \
|
||||
--root data \
|
||||
--repo-id ${HF_USER}/eval_act_moss_test \
|
||||
--tags moss tutorial eval \
|
||||
--warmup-time-s 5 \
|
||||
--episode-time-s 40 \
|
||||
--reset-time-s 10 \
|
||||
--num-episodes 10 \
|
||||
-p outputs/train/act_moss_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
As you can see, it's almost the same command as previously used to record your training dataset. Two things changed:
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_moss_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_moss_test`).
|
||||
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
If you have any question or need help, please reach out on Discord in the channel [`#moss-arm`](https://discord.com/channels/1216765309076115607/1275374638985252925).
|
||||
@@ -3,78 +3,120 @@ This script demonstrates the use of `LeRobotDataset` class for handling and proc
|
||||
It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.
|
||||
|
||||
Features included in this script:
|
||||
- Loading a dataset and accessing its properties.
|
||||
- Filtering data by episode number.
|
||||
- Converting tensor data for visualization.
|
||||
- Saving video files from dataset frames.
|
||||
- Viewing a dataset's metadata and exploring its properties.
|
||||
- Loading an existing dataset from the hub or a subset of it.
|
||||
- Accessing frames by episode number.
|
||||
- Using advanced dataset features like timestamp-based frame selection.
|
||||
- Demonstrating compatibility with PyTorch DataLoader for batch processing.
|
||||
|
||||
The script ends with examples of how to batch process data using PyTorch's DataLoader.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
|
||||
import imageio
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
import lerobot
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# Let's take one for this example
|
||||
repo_id = "lerobot/pusht"
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# You can easily load a dataset from a Hugging Face repository
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets/index for more information).
|
||||
print(dataset)
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# And provides additional utilities for robotics and compatibility with Pytorch
|
||||
print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
|
||||
print(f"frames per second used during data collection: {dataset.fps=}")
|
||||
print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
|
||||
|
||||
# Access frame indexes associated to first episode
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using the episode_data_index. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.episode_data_index["from"][episode_index].item()
|
||||
to_idx = dataset.episode_data_index["to"][episode_index].item()
|
||||
|
||||
# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset. Here we grab all the image frames.
|
||||
frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
|
||||
# them, we convert to uint8 in range [0,255]
|
||||
frames = [(frame * 255).type(torch.uint8) for frame in frames]
|
||||
# and to channel last (h,w,c).
|
||||
frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
|
||||
# Finally, we save the frames to a mp4 video for visualization.
|
||||
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
|
||||
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
"observation.image": [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 8 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0]['observation.image'].shape=}") # (4,c,h,w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (8,c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64,c)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
# Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
|
||||
# PyTorch datasets.
|
||||
@@ -84,8 +126,9 @@ dataloader = torch.utils.data.DataLoader(
|
||||
batch_size=32,
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
print(f"{batch['observation.image'].shape=}") # (32,4,c,h,w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32,8,c)
|
||||
print(f"{batch['action'].shape=}") # (32,64,c)
|
||||
print(f"{batch[camera_key].shape=}") # (32, 4, c, h, w)
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 5, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
@@ -40,7 +40,7 @@ dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
|
||||
# For this example, no arguments need to be passed because the defaults are set up for PushT.
|
||||
# If you're doing something different, you will likely need to change at least some of the defaults.
|
||||
cfg = DiffusionConfig()
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.stats)
|
||||
policy = DiffusionPolicy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
This script demonstrates how to use torchvision's image transformation with LeRobotDataset for data
|
||||
augmentation purposes. The transformations are passed to the dataset as an argument upon creation, and
|
||||
transforms are applied to the observation images before they are returned in the dataset's __get_item__.
|
||||
transforms are applied to the observation images before they are returned in the dataset's __getitem__.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
@@ -20,7 +20,7 @@ dataset = LeRobotDataset(dataset_repo_id)
|
||||
first_idx = dataset.episode_data_index["from"][0].item()
|
||||
|
||||
# Get the frame corresponding to the first camera
|
||||
frame = dataset[first_idx][dataset.camera_keys[0]]
|
||||
frame = dataset[first_idx][dataset.meta.camera_keys[0]]
|
||||
|
||||
|
||||
# Define the transformations
|
||||
@@ -36,7 +36,7 @@ transforms = v2.Compose(
|
||||
transformed_dataset = LeRobotDataset(dataset_repo_id, image_transforms=transforms)
|
||||
|
||||
# Get a frame from the transformed dataset
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.camera_keys[0]]
|
||||
transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]]
|
||||
|
||||
# Create a directory to store output images
|
||||
output_dir = Path("outputs/image_transforms")
|
||||
|
||||
@@ -14,7 +14,7 @@ from pathlib import Path
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
|
||||
device = torch.device("cuda")
|
||||
@@ -41,26 +41,20 @@ delta_timestamps = {
|
||||
}
|
||||
|
||||
# Load the last 10% of episodes of the dataset as a validation set.
|
||||
# - Load full dataset
|
||||
full_dataset = LeRobotDataset("lerobot/pusht", split="train")
|
||||
# - Calculate train and val subsets
|
||||
num_train_episodes = math.floor(full_dataset.num_episodes * 90 / 100)
|
||||
num_val_episodes = full_dataset.num_episodes - num_train_episodes
|
||||
print(f"Number of episodes in full dataset: {full_dataset.num_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {num_train_episodes}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {num_val_episodes}")
|
||||
# - Get first frame index of the validation set
|
||||
first_val_frame_index = full_dataset.episode_data_index["from"][num_train_episodes].item()
|
||||
# - Load frames subset belonging to validation set using the `split` argument.
|
||||
# It utilizes the `datasets` library's syntax for slicing datasets.
|
||||
# For more information on the Slice API, please see:
|
||||
# https://huggingface.co/docs/datasets/v2.19.0/loading#slice-splits
|
||||
train_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[:{first_val_frame_index}]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
val_dataset = LeRobotDataset(
|
||||
"lerobot/pusht", split=f"train[{first_val_frame_index}:]", delta_timestamps=delta_timestamps
|
||||
)
|
||||
# - Load dataset metadata
|
||||
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
|
||||
# - Calculate train and val episodes
|
||||
total_episodes = dataset_metadata.total_episodes
|
||||
episodes = list(range(dataset_metadata.total_episodes))
|
||||
num_train_episodes = math.floor(total_episodes * 90 / 100)
|
||||
train_episodes = episodes[:num_train_episodes]
|
||||
val_episodes = episodes[num_train_episodes:]
|
||||
print(f"Number of episodes in full dataset: {total_episodes}")
|
||||
print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}")
|
||||
print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}")
|
||||
# - Load train an val datasets
|
||||
train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps)
|
||||
val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps)
|
||||
print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}")
|
||||
print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}")
|
||||
|
||||
|
||||
222
examples/port_datasets/pusht_zarr.py
Normal file
@@ -0,0 +1,222 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
PUSHT_TASK = "Push the T-shaped blue block onto the T-shaped green target surface."
|
||||
PUSHT_FEATURES = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {
|
||||
"axes": ["x", "y"],
|
||||
},
|
||||
},
|
||||
"next.reward": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"next.success": {
|
||||
"dtype": "bool",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
},
|
||||
"observation.environment_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (16,),
|
||||
"names": [
|
||||
"keypoints",
|
||||
],
|
||||
},
|
||||
"observation.image": {
|
||||
"dtype": None,
|
||||
"shape": (3, 96, 96),
|
||||
"names": [
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def build_features(mode: str) -> dict:
|
||||
features = PUSHT_FEATURES
|
||||
if mode == "keypoints":
|
||||
features.pop("observation.image")
|
||||
else:
|
||||
features.pop("observation.environment_state")
|
||||
features["observation.image"]["dtype"] = mode
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def load_raw_dataset(zarr_path: Path, load_images: bool = True):
|
||||
try:
|
||||
from lerobot.common.datasets.push_dataset_to_hub._diffusion_policy_replay_buffer import (
|
||||
ReplayBuffer as DiffusionPolicyReplayBuffer,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path)
|
||||
return zarr_data
|
||||
|
||||
|
||||
def calculate_coverage(zarr_data):
|
||||
try:
|
||||
import pymunk
|
||||
from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely
|
||||
except ModuleNotFoundError as e:
|
||||
print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`")
|
||||
raise e
|
||||
|
||||
block_pos = zarr_data["state"][:, 2:4]
|
||||
block_angle = zarr_data["state"][:, 4]
|
||||
|
||||
num_frames = len(block_pos)
|
||||
|
||||
coverage = np.zeros((num_frames,))
|
||||
# 8 keypoints with 2 coords each
|
||||
keypoints = np.zeros((num_frames, 16))
|
||||
|
||||
# Set x, y, theta (in radians)
|
||||
goal_pos_angle = np.array([256, 256, np.pi / 4])
|
||||
goal_body = PushTEnv.get_goal_pose_body(goal_pos_angle)
|
||||
|
||||
for i in range(num_frames):
|
||||
space = pymunk.Space()
|
||||
space.gravity = 0, 0
|
||||
space.damping = 0
|
||||
|
||||
# Add walls.
|
||||
walls = [
|
||||
PushTEnv.add_segment(space, (5, 506), (5, 5), 2),
|
||||
PushTEnv.add_segment(space, (5, 5), (506, 5), 2),
|
||||
PushTEnv.add_segment(space, (506, 5), (506, 506), 2),
|
||||
PushTEnv.add_segment(space, (5, 506), (506, 506), 2),
|
||||
]
|
||||
space.add(*walls)
|
||||
|
||||
block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item())
|
||||
goal_geom = pymunk_to_shapely(goal_body, block_body.shapes)
|
||||
block_geom = pymunk_to_shapely(block_body, block_body.shapes)
|
||||
intersection_area = goal_geom.intersection(block_geom).area
|
||||
goal_area = goal_geom.area
|
||||
coverage[i] = intersection_area / goal_area
|
||||
keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten())
|
||||
|
||||
return coverage, keypoints
|
||||
|
||||
|
||||
def calculate_success(coverage: float, success_threshold: float):
|
||||
return coverage > success_threshold
|
||||
|
||||
|
||||
def calculate_reward(coverage: float, success_threshold: float):
|
||||
return np.clip(coverage / success_threshold, 0, 1)
|
||||
|
||||
|
||||
def main(raw_dir: Path, repo_id: str, mode: str = "video", push_to_hub: bool = True):
|
||||
if mode not in ["video", "image", "keypoints"]:
|
||||
raise ValueError(mode)
|
||||
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id="lerobot-raw/pusht_raw")
|
||||
|
||||
zarr_data = load_raw_dataset(zarr_path=raw_dir / "pusht_cchi_v7_replay.zarr")
|
||||
|
||||
env_state = zarr_data["state"][:]
|
||||
agent_pos = env_state[:, :2]
|
||||
|
||||
action = zarr_data["action"][:]
|
||||
image = zarr_data["img"] # (b, h, w, c)
|
||||
|
||||
episode_data_index = {
|
||||
"from": np.concatenate(([0], zarr_data.meta["episode_ends"][:-1])),
|
||||
"to": zarr_data.meta["episode_ends"],
|
||||
}
|
||||
|
||||
# Calculate success and reward based on the overlapping area
|
||||
# of the T-object and the T-area.
|
||||
coverage, keypoints = calculate_coverage(zarr_data)
|
||||
success = calculate_success(coverage, success_threshold=0.95)
|
||||
reward = calculate_reward(coverage, success_threshold=0.95)
|
||||
|
||||
features = build_features(mode)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=10,
|
||||
robot_type="2d pointer",
|
||||
features=features,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
episodes = range(len(episode_data_index["from"]))
|
||||
for ep_idx in episodes:
|
||||
from_idx = episode_data_index["from"][ep_idx]
|
||||
to_idx = episode_data_index["to"][ep_idx]
|
||||
num_frames = to_idx - from_idx
|
||||
|
||||
for frame_idx in range(num_frames):
|
||||
i = from_idx + frame_idx
|
||||
frame = {
|
||||
"action": torch.from_numpy(action[i]),
|
||||
# Shift reward and success by +1 until the last item of the episode
|
||||
"next.reward": reward[i + (frame_idx < num_frames - 1)],
|
||||
"next.success": success[i + (frame_idx < num_frames - 1)],
|
||||
}
|
||||
|
||||
frame["observation.state"] = torch.from_numpy(agent_pos[i])
|
||||
|
||||
if mode == "keypoints":
|
||||
frame["observation.environment_state"] = torch.from_numpy(keypoints[i])
|
||||
else:
|
||||
frame["observation.image"] = torch.from_numpy(image[i])
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=PUSHT_TASK)
|
||||
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# To try this script, modify the repo id with your own HuggingFace user (e.g cadene/pusht)
|
||||
repo_id = "lerobot/pusht"
|
||||
|
||||
modes = ["video", "image", "keypoints"]
|
||||
# Uncomment if you want to try with a specific mode
|
||||
# modes = ["video"]
|
||||
# modes = ["image"]
|
||||
# modes = ["keypoints"]
|
||||
|
||||
raw_dir = Path("data/lerobot-raw/pusht_raw")
|
||||
for mode in modes:
|
||||
if mode in ["image", "keypoints"]:
|
||||
repo_id += f"_{mode}"
|
||||
|
||||
# download and load raw dataset, create LeRobotDataset, populate it, push to hub
|
||||
main(raw_dir, repo_id=repo_id, mode=mode)
|
||||
|
||||
# Uncomment if you want to load the local dataset and explore it
|
||||
# dataset = LeRobotDataset(repo_id=repo_id, local_files_only=True)
|
||||
# breakpoint()
|
||||
27
lerobot/common/datasets/card_template.md
Normal file
@@ -0,0 +1,27 @@
|
||||
---
|
||||
# For reference on dataset card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/datasetcard.md?plain=1
|
||||
# Doc / guide: https://huggingface.co/docs/hub/datasets-cards
|
||||
{{ card_data }}
|
||||
---
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
- **Homepage:** {{ url | default("[More Information Needed]", true)}}
|
||||
- **Paper:** {{ paper | default("[More Information Needed]", true)}}
|
||||
- **License:** {{ license | default("[More Information Needed]", true)}}
|
||||
|
||||
## Dataset Structure
|
||||
|
||||
{{ dataset_structure | default("[More Information Needed]", true)}}
|
||||
|
||||
## Citation
|
||||
|
||||
**BibTeX:**
|
||||
|
||||
```bibtex
|
||||
{{ citation_bibtex | default("[More Information Needed]", true)}}
|
||||
```
|
||||
@@ -19,9 +19,6 @@ from math import ceil
|
||||
import einops
|
||||
import torch
|
||||
import tqdm
|
||||
from datasets import Image
|
||||
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
@@ -39,15 +36,13 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
batch = next(iter(dataloader))
|
||||
|
||||
stats_patterns = {}
|
||||
for key, feats_type in dataset.features.items():
|
||||
# NOTE: skip language_instruction embedding in stats computation
|
||||
if key == "language_instruction":
|
||||
continue
|
||||
|
||||
for key in dataset.features:
|
||||
# sanity check that tensors are not float64
|
||||
assert batch[key].dtype != torch.float64
|
||||
|
||||
if isinstance(feats_type, (VideoFrame, Image)):
|
||||
# if isinstance(feats_type, (VideoFrame, Image)):
|
||||
if key in dataset.meta.camera_keys:
|
||||
# sanity check that images are channel first
|
||||
_, c, h, w = batch[key].shape
|
||||
assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}"
|
||||
@@ -63,7 +58,7 @@ def get_stats_einops_patterns(dataset, num_workers=0):
|
||||
elif batch[key].ndim == 1:
|
||||
stats_patterns[key] = "b -> 1"
|
||||
else:
|
||||
raise ValueError(f"{key}, {feats_type}, {batch[key].shape}")
|
||||
raise ValueError(f"{key}, {batch[key].shape}")
|
||||
|
||||
return stats_patterns
|
||||
|
||||
@@ -175,39 +170,45 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
data_keys = set()
|
||||
for dataset in ls_datasets:
|
||||
data_keys.update(dataset.stats.keys())
|
||||
data_keys.update(dataset.meta.stats.keys())
|
||||
stats = {k: {} for k in data_keys}
|
||||
for data_key in data_keys:
|
||||
for stat_key in ["min", "max"]:
|
||||
# compute `max(dataset_0["max"], dataset_1["max"], ...)`
|
||||
stats[data_key][stat_key] = einops.reduce(
|
||||
torch.stack([d.stats[data_key][stat_key] for d in ls_datasets if data_key in d.stats], dim=0),
|
||||
torch.stack(
|
||||
[ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats],
|
||||
dim=0,
|
||||
),
|
||||
"n ... -> ...",
|
||||
stat_key,
|
||||
)
|
||||
total_samples = sum(d.num_samples for d in ls_datasets if data_key in d.stats)
|
||||
total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats)
|
||||
# Compute the "sum" statistic by multiplying each mean by the number of samples in the respective
|
||||
# dataset, then divide by total_samples to get the overall "mean".
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["mean"] = sum(
|
||||
d.stats[data_key]["mean"] * (d.num_samples / total_samples)
|
||||
d.meta.stats[data_key]["mean"] * (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
# The derivation for standard deviation is a little more involved but is much in the same spirit as
|
||||
# the computation of the mean.
|
||||
# Given two sets of data where the statistics are known:
|
||||
# σ_combined = sqrt[ (n1 * (σ1^2 + d1^2) + n2 * (σ2^2 + d2^2)) / (n1 + n2) ]
|
||||
# where d1 = μ1 - μ_combined, d2 = μ2 - μ_combined
|
||||
# NOTE: the brackets around (d.num_samples / total_samples) are needed tor minimize the risk of
|
||||
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of
|
||||
# numerical overflow!
|
||||
stats[data_key]["std"] = torch.sqrt(
|
||||
sum(
|
||||
(d.stats[data_key]["std"] ** 2 + (d.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2)
|
||||
* (d.num_samples / total_samples)
|
||||
(
|
||||
d.meta.stats[data_key]["std"] ** 2
|
||||
+ (d.meta.stats[data_key]["mean"] - stats[data_key]["mean"]) ** 2
|
||||
)
|
||||
* (d.num_frames / total_samples)
|
||||
for d in ls_datasets
|
||||
if data_key in d.stats
|
||||
if data_key in d.meta.stats
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
@@ -111,6 +111,6 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData
|
||||
for stats_type, listconfig in stats_dict.items():
|
||||
# example of stats_type: min, max, mean, std
|
||||
stats = OmegaConf.to_container(listconfig, resolve=True)
|
||||
dataset.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -14,14 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import multiprocessing
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
import queue
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import tqdm
|
||||
from PIL import Image
|
||||
|
||||
DEFAULT_IMAGE_PATH = "{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
|
||||
def safe_stop_image_writer(func):
|
||||
@@ -39,8 +38,56 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
class ImageWriter:
|
||||
"""This class abstract away the initialisation of processes or/and threads to
|
||||
def image_array_to_image(image_array: np.ndarray) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim == 3 and image_array.shape[0] in [1, 3]:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
if image_array.dtype != np.uint8:
|
||||
# Assume the image is in [0, 1] range for floating-point data
|
||||
image_array = np.clip(image_array, 0, 1)
|
||||
image_array = (image_array * 255).astype(np.uint8)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
try:
|
||||
if isinstance(image, np.ndarray):
|
||||
img = image_array_to_image(image)
|
||||
elif isinstance(image, PIL.Image.Image):
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath)
|
||||
except Exception as e:
|
||||
print(f"Error writing image {fpath}: {e}")
|
||||
|
||||
|
||||
def worker_thread_loop(queue: queue.Queue):
|
||||
while True:
|
||||
item = queue.get()
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
queue.task_done()
|
||||
|
||||
|
||||
def worker_process(queue: queue.Queue, num_threads: int):
|
||||
threads = []
|
||||
for _ in range(num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
threads.append(t)
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
|
||||
class AsyncImageWriter:
|
||||
"""
|
||||
This class abstract away the initialisation of processes or/and threads to
|
||||
save images on disk asynchrounously, which is critical to control a robot and record data
|
||||
at a high frame rate.
|
||||
|
||||
@@ -53,82 +100,61 @@ class ImageWriter:
|
||||
the number of threads. If it is still not stable, try to use 1 subprocess, or more.
|
||||
"""
|
||||
|
||||
def __init__(self, write_dir: Path, num_processes: int = 0, num_threads: int = 1):
|
||||
self.dir = write_dir
|
||||
self.dir.mkdir(parents=True, exist_ok=True)
|
||||
self.image_path = DEFAULT_IMAGE_PATH
|
||||
def __init__(self, num_processes: int = 0, num_threads: int = 1):
|
||||
self.num_processes = num_processes
|
||||
self.num_threads = self.num_threads_per_process = num_threads
|
||||
self.num_threads = num_threads
|
||||
self.queue = None
|
||||
self.threads = []
|
||||
self.processes = []
|
||||
self._stopped = False
|
||||
|
||||
if self.num_processes <= 0:
|
||||
self.type = "threads"
|
||||
self.threads = ThreadPoolExecutor(max_workers=self.num_threads)
|
||||
self.futures = []
|
||||
if num_threads <= 0 and num_processes <= 0:
|
||||
raise ValueError("Number of threads and processes must be greater than zero.")
|
||||
|
||||
if self.num_processes == 0:
|
||||
# Use threading
|
||||
self.queue = queue.Queue()
|
||||
for _ in range(self.num_threads):
|
||||
t = threading.Thread(target=worker_thread_loop, args=(self.queue,))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
self.threads.append(t)
|
||||
else:
|
||||
self.type = "processes"
|
||||
self.num_threads_per_process = self.num_threads
|
||||
self.image_queue = multiprocessing.Queue()
|
||||
self.processes: list[multiprocessing.Process] = []
|
||||
for _ in range(num_processes):
|
||||
process = multiprocessing.Process(target=self._loop_to_save_images_in_threads)
|
||||
process.start()
|
||||
self.processes.append(process)
|
||||
# Use multiprocessing
|
||||
self.queue = multiprocessing.JoinableQueue()
|
||||
for _ in range(self.num_processes):
|
||||
p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads))
|
||||
p.daemon = True
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def _loop_to_save_images_in_threads(self) -> None:
|
||||
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
|
||||
futures = []
|
||||
while True:
|
||||
frame_data = self.image_queue.get()
|
||||
if frame_data is None:
|
||||
break
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath))
|
||||
|
||||
image, file_path = frame_data
|
||||
futures.append(executor.submit(self._save_image, image, file_path))
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
with tqdm.tqdm(total=len(futures), desc="Writing images") as progress_bar:
|
||||
wait(futures)
|
||||
progress_bar.update(len(futures))
|
||||
def stop(self):
|
||||
if self._stopped:
|
||||
return
|
||||
|
||||
def async_save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
"""Save an image asynchronously using threads or processes."""
|
||||
if self.type == "threads":
|
||||
self.futures.append(self.threads.submit(self._save_image, image, file_path))
|
||||
if self.num_processes == 0:
|
||||
for _ in self.threads:
|
||||
self.queue.put(None)
|
||||
for t in self.threads:
|
||||
t.join()
|
||||
else:
|
||||
self.image_queue.put((image, file_path))
|
||||
num_nones = self.num_processes * self.num_threads
|
||||
for _ in range(num_nones):
|
||||
self.queue.put(None)
|
||||
for p in self.processes:
|
||||
p.join()
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
self.queue.close()
|
||||
self.queue.join_thread()
|
||||
|
||||
def _save_image(self, image: torch.Tensor, file_path: Path) -> None:
|
||||
img = Image.fromarray(image.numpy())
|
||||
img.save(str(file_path), quality=100)
|
||||
|
||||
def get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = self.image_path.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self.dir / fpath
|
||||
|
||||
def get_episode_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self.get_image_file_path(
|
||||
episode_index=episode_index, image_key=image_key, frame_index=0
|
||||
).parent
|
||||
|
||||
def stop(self, timeout=20) -> None:
|
||||
"""Stop the image writer, waiting for all processes or threads to finish."""
|
||||
if self.type == "threads":
|
||||
with tqdm.tqdm(total=len(self.futures), desc="Writing images") as progress_bar:
|
||||
wait(self.futures, timeout=timeout)
|
||||
progress_bar.update(len(self.futures))
|
||||
else:
|
||||
self._stop_processes(timeout)
|
||||
|
||||
def _stop_processes(self, timeout) -> None:
|
||||
for _ in self.processes:
|
||||
self.image_queue.put(None)
|
||||
|
||||
for process in self.processes:
|
||||
process.join(timeout=timeout)
|
||||
|
||||
if process.is_alive():
|
||||
process.terminate()
|
||||
|
||||
self.image_queue.close()
|
||||
self.image_queue.join_thread()
|
||||
self._stopped = True
|
||||
|
||||
@@ -187,7 +187,7 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
assert data[OnlineBuffer.INDEX_KEY][0].item() == 0
|
||||
|
||||
# Shift the incoming indices if necessary.
|
||||
if self.num_samples > 0:
|
||||
if self.num_frames > 0:
|
||||
last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
|
||||
last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
|
||||
data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
|
||||
@@ -227,11 +227,11 @@ class OnlineBuffer(torch.utils.data.Dataset):
|
||||
)
|
||||
|
||||
@property
|
||||
def num_samples(self) -> int:
|
||||
def num_frames(self) -> int:
|
||||
return np.count_nonzero(self._data[OnlineBuffer.OCCUPANCY_MASK_KEY])
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
return self.num_frames
|
||||
|
||||
def _item_to_tensors(self, item: dict) -> dict:
|
||||
item_ = {}
|
||||
|
||||
@@ -30,12 +30,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -24,8 +24,11 @@ from datasets import Dataset, Features, Image, Value
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
|
||||
from lerobot.common.datasets.utils import calculate_episode_data_index, hf_transform_to_torch
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
)
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
|
||||
|
||||
@@ -26,8 +26,8 @@ import torch
|
||||
from datasets import Dataset, Features, Image, Sequence, Value
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame
|
||||
|
||||
@@ -42,12 +42,12 @@ from PIL import Image as PILImage
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.openx.transforms import OPENX_STANDARDIZATION_TRANSFORMS
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -28,12 +28,12 @@ from PIL import Image as PILImage
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -16,7 +16,9 @@
|
||||
import inspect
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy
|
||||
import PIL
|
||||
import torch
|
||||
@@ -72,3 +74,58 @@ def check_repo_id(repo_id: str) -> None:
|
||||
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
|
||||
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
@@ -27,12 +27,12 @@ from PIL import Image as PILImage
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import (
|
||||
calculate_episode_data_index,
|
||||
concatenate_episodes,
|
||||
get_default_encoding,
|
||||
save_images_concurrently,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame, encode_video_frames
|
||||
|
||||
@@ -14,16 +14,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import warnings
|
||||
import logging
|
||||
import textwrap
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Dict
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import pyarrow.compute as pc
|
||||
import torch
|
||||
from huggingface_hub import DatasetCard, HfApi
|
||||
from datasets.table import embed_table_storage
|
||||
from huggingface_hub import DatasetCard, DatasetCardData, HfApi
|
||||
from PIL import Image as PILImage
|
||||
from torchvision import transforms
|
||||
|
||||
@@ -38,6 +42,7 @@ TASKS_PATH = "meta/tasks.jsonl"
|
||||
|
||||
DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
|
||||
|
||||
DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
@@ -45,8 +50,18 @@ DATASET_CARD_TEMPLATE = """
|
||||
---
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
## {}
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"frame_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"episode_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
"task_index": {"dtype": "int64", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
|
||||
@@ -80,22 +95,80 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
return outdict
|
||||
|
||||
|
||||
def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(serialized_dict)
|
||||
|
||||
|
||||
def write_parquet(dataset: datasets.Dataset, fpath: Path) -> None:
|
||||
# Embed image bytes into the table before saving to parquet
|
||||
format = dataset.format
|
||||
dataset = dataset.with_format("arrow")
|
||||
dataset = dataset.map(embed_table_storage, batched=False)
|
||||
dataset = dataset.with_format(**format)
|
||||
dataset.to_parquet(fpath)
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> Any:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def append_jsonl(data: dict, fpath: Path) -> None:
|
||||
def load_jsonlines(fpath: Path) -> list[Any]:
|
||||
with jsonlines.open(fpath, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def append_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "a") as writer:
|
||||
writer.write(data)
|
||||
|
||||
|
||||
def write_stats(stats: dict[str, torch.Tensor | dict], fpath: Path) -> None:
|
||||
serialized_stats = {key: value.tolist() for key, value in flatten_dict(stats).items()}
|
||||
serialized_stats = unflatten_dict(serialized_stats)
|
||||
write_json(serialized_stats, fpath)
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
info = load_json(local_dir / INFO_PATH)
|
||||
for ft in info["features"].values():
|
||||
ft["shape"] = tuple(ft["shape"])
|
||||
return info
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
if not (local_dir / STATS_PATH).exists():
|
||||
return None
|
||||
stats = load_json(local_dir / STATS_PATH)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
tasks = load_jsonlines(local_dir / TASKS_PATH)
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episodes(local_dir: Path) -> dict:
|
||||
return load_jsonlines(local_dir / EPISODES_PATH)
|
||||
|
||||
|
||||
def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
|
||||
img = PILImage.open(fpath).convert("RGB")
|
||||
img_array = np.array(img, dtype=dtype)
|
||||
if channel_first: # (H, W, C) -> (C, H, W)
|
||||
img_array = np.transpose(img_array, (2, 0, 1))
|
||||
if "float" in dtype:
|
||||
img_array /= 255.0
|
||||
return img_array
|
||||
|
||||
|
||||
def hf_transform_to_torch(items_dict: dict[torch.Tensor | None]):
|
||||
@@ -121,46 +194,62 @@ def _get_major_minor(version: str) -> tuple[int]:
|
||||
return int(split[0]), int(split[1])
|
||||
|
||||
|
||||
class BackwardCompatibilityError(Exception):
|
||||
def __init__(self, repo_id, version):
|
||||
message = textwrap.dedent(f"""
|
||||
BackwardCompatibilityError: The dataset you requested ({repo_id}) is in {version} format.
|
||||
|
||||
We introduced a new format since v2.0 which is not backward compatible with v1.x.
|
||||
Please, use our conversion script. Modify the following command with your own task description:
|
||||
```
|
||||
python lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py \\
|
||||
--repo-id {repo_id} \\
|
||||
--single-task "TASK DESCRIPTION." # <---- /!\\ Replace TASK DESCRIPTION /!\\
|
||||
```
|
||||
|
||||
A few examples to replace TASK DESCRIPTION: "Pick up the blue cube and place it into the bin.",
|
||||
"Insert the peg into the socket.", "Slide open the ziploc bag.", "Take the elevator to the 1st floor.",
|
||||
"Open the top cabinet, store the pot inside it then close the cabinet.", "Push the T-shaped block onto the T-shaped target.",
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog.", "Fold the sweatshirt.", ...
|
||||
|
||||
If you encounter a problem, contact LeRobot maintainers on [Discord](https://discord.com/invite/s3KuuzsPFb)
|
||||
or open an [issue on GitHub](https://github.com/huggingface/lerobot/issues/new/choose).
|
||||
""")
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def check_version_compatibility(
|
||||
repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
|
||||
) -> None:
|
||||
current_major, _ = _get_major_minor(current_version)
|
||||
major_to_check, _ = _get_major_minor(version_to_check)
|
||||
if major_to_check < current_major and enforce_breaking_major:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version_to_check} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
raise BackwardCompatibilityError(repo_id, version_to_check)
|
||||
elif float(version_to_check.strip("v")) < float(current_version.strip("v")):
|
||||
warnings.warn(
|
||||
logging.warning(
|
||||
f"""The dataset you requested ({repo_id}) was created with a previous version ({version_to_check}) of the
|
||||
codebase. The current codebase version is {current_version}. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) -> str:
|
||||
num_version = float(version.strip("v"))
|
||||
if num_version < 2 and enforce_v2:
|
||||
raise ValueError(
|
||||
f"""The dataset you requested ({repo_id}) is in {version} format. We introduced a new
|
||||
format with v2.0 that is not backward compatible. Please use our conversion script
|
||||
first (convert_dataset_v1_to_v2.py) to convert your dataset to this new format."""
|
||||
)
|
||||
def get_hub_safe_version(repo_id: str, version: str) -> str:
|
||||
api = HfApi()
|
||||
dataset_info = api.list_repo_refs(repo_id, repo_type="dataset")
|
||||
branches = [b.name for b in dataset_info.branches]
|
||||
if version not in branches:
|
||||
warnings.warn(
|
||||
num_version = float(version.strip("v"))
|
||||
hub_num_versions = [float(v.strip("v")) for v in branches if v.startswith("v")]
|
||||
if num_version >= 2.0 and all(v < 2.0 for v in hub_num_versions):
|
||||
raise BackwardCompatibilityError(repo_id, version)
|
||||
|
||||
logging.warning(
|
||||
f"""You are trying to load a dataset from {repo_id} created with a previous version of the
|
||||
codebase. The following versions are available: {branches}.
|
||||
The requested version ('{version}') is not found. You should be fine since
|
||||
backward compatibility is maintained. If you encounter a problem, contact LeRobot maintainers on
|
||||
Discord ('https://discord.com/invite/s3KuuzsPFb') or open an issue on github.""",
|
||||
stacklevel=1,
|
||||
)
|
||||
if "main" not in branches:
|
||||
raise ValueError(f"Version 'main' not found on {repo_id}")
|
||||
@@ -169,63 +258,43 @@ def get_hub_safe_version(repo_id: str, version: str, enforce_v2: bool = True) ->
|
||||
return version
|
||||
|
||||
|
||||
def load_info(local_dir: Path) -> dict:
|
||||
with open(local_dir / INFO_PATH) as f:
|
||||
return json.load(f)
|
||||
def get_hf_features_from_features(features: dict) -> datasets.Features:
|
||||
hf_features = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "video":
|
||||
continue
|
||||
elif ft["dtype"] == "image":
|
||||
hf_features[key] = datasets.Image()
|
||||
elif ft["shape"] == (1,):
|
||||
hf_features[key] = datasets.Value(dtype=ft["dtype"])
|
||||
else:
|
||||
assert len(ft["shape"]) == 1
|
||||
hf_features[key] = datasets.Sequence(
|
||||
length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
|
||||
)
|
||||
|
||||
return datasets.Features(hf_features)
|
||||
|
||||
|
||||
def load_stats(local_dir: Path) -> dict:
|
||||
with open(local_dir / STATS_PATH) as f:
|
||||
stats = json.load(f)
|
||||
stats = {key: torch.tensor(value) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
def load_tasks(local_dir: Path) -> dict:
|
||||
with jsonlines.open(local_dir / TASKS_PATH, "r") as reader:
|
||||
tasks = list(reader)
|
||||
|
||||
return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
|
||||
|
||||
|
||||
def load_episode_dicts(local_dir: Path) -> dict:
|
||||
with jsonlines.open(local_dir / EPISODES_PATH, "r") as reader:
|
||||
return list(reader)
|
||||
|
||||
|
||||
def _get_info_from_robot(robot: Robot, use_videos: bool) -> tuple[list | dict]:
|
||||
shapes = {key: len(names) for key, names in robot.names.items()}
|
||||
camera_shapes = {}
|
||||
for key, cam in robot.cameras.items():
|
||||
video_key = f"observation.images.{key}"
|
||||
camera_shapes[video_key] = {
|
||||
"width": cam.width,
|
||||
"height": cam.height,
|
||||
"channels": cam.channels,
|
||||
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
|
||||
camera_ft = {}
|
||||
if robot.cameras:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video" if use_videos else "image", **ft}
|
||||
for key, ft in robot.camera_features.items()
|
||||
}
|
||||
keys = list(robot.names)
|
||||
image_keys = [] if use_videos else list(camera_shapes)
|
||||
video_keys = list(camera_shapes) if use_videos else []
|
||||
shapes = {**shapes, **camera_shapes}
|
||||
names = robot.names
|
||||
robot_type = robot.robot_type
|
||||
|
||||
return robot_type, keys, image_keys, video_keys, shapes, names
|
||||
return {**robot.motor_features, **camera_ft, **DEFAULT_FEATURES}
|
||||
|
||||
|
||||
def create_empty_dataset_info(
|
||||
codebase_version: str,
|
||||
fps: int,
|
||||
robot_type: str,
|
||||
keys: list[str],
|
||||
image_keys: list[str],
|
||||
video_keys: list[str],
|
||||
shapes: dict,
|
||||
names: dict,
|
||||
features: dict,
|
||||
use_videos: bool,
|
||||
) -> dict:
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": 0,
|
||||
"total_frames": 0,
|
||||
@@ -235,16 +304,15 @@ def create_empty_dataset_info(
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"keys": keys,
|
||||
"video_keys": video_keys,
|
||||
"image_keys": image_keys,
|
||||
"shapes": shapes,
|
||||
"names": names,
|
||||
"videos": {"videos_path": DEFAULT_VIDEO_PATH} if len(video_keys) > 0 else None,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
|
||||
def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[str, torch.Tensor]:
|
||||
def get_episode_data_index(
|
||||
episode_dicts: list[dict], episodes: list[int] | None = None
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
|
||||
if episodes is not None:
|
||||
episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes}
|
||||
@@ -256,6 +324,31 @@ def get_episode_data_index(episodes: list, episode_dicts: list[dict]) -> dict[st
|
||||
}
|
||||
|
||||
|
||||
def calculate_total_episode(
|
||||
hf_dataset: datasets.Dataset, raise_if_not_contiguous: bool = True
|
||||
) -> dict[str, torch.Tensor]:
|
||||
episode_indices = sorted(hf_dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
if raise_if_not_contiguous and episode_indices != list(range(total_episodes)):
|
||||
raise ValueError("episode_index values are not sorted and contiguous.")
|
||||
return total_episodes
|
||||
|
||||
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
|
||||
episode_lengths = []
|
||||
table = hf_dataset.data.table
|
||||
total_episodes = calculate_total_episode(hf_dataset)
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
episode_lengths.insert(ep_idx, len(ep_table))
|
||||
|
||||
cumulative_lenghts = list(accumulate(episode_lengths))
|
||||
return {
|
||||
"from": torch.LongTensor([0] + cumulative_lenghts[:-1]),
|
||||
"to": torch.LongTensor(cumulative_lenghts),
|
||||
}
|
||||
|
||||
|
||||
def check_timestamps_sync(
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
@@ -268,13 +361,11 @@ def check_timestamps_sync(
|
||||
account for possible numerical error.
|
||||
"""
|
||||
timestamps = torch.stack(hf_dataset["timestamp"])
|
||||
# timestamps[2] += tolerance_s # TODO delete
|
||||
# timestamps[-2] += tolerance_s/2 # TODO delete
|
||||
diffs = torch.diff(timestamps)
|
||||
within_tolerance = torch.abs(diffs - 1 / fps) <= tolerance_s
|
||||
|
||||
# We mask differences between the timestamp at the end of an episode
|
||||
# and the one the start of the next episode since these are expected
|
||||
# and the one at the start of the next episode since these are expected
|
||||
# to be outside tolerance.
|
||||
mask = torch.ones(len(diffs), dtype=torch.bool)
|
||||
ignored_diffs = episode_data_index["to"][:-1] - 1
|
||||
@@ -318,7 +409,7 @@ def check_delta_timestamps(
|
||||
"""
|
||||
outside_tolerance = {}
|
||||
for key, delta_ts in delta_timestamps.items():
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) <= tolerance_s for ts in delta_ts]
|
||||
within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts]
|
||||
if not all(within_tolerance):
|
||||
outside_tolerance[key] = [
|
||||
ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within
|
||||
@@ -347,180 +438,6 @@ def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dic
|
||||
return delta_indices
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def load_previous_and_future_frames(
|
||||
item: dict[str, torch.Tensor],
|
||||
hf_dataset: datasets.Dataset,
|
||||
episode_data_index: dict[str, torch.Tensor],
|
||||
delta_timestamps: dict[str, list[float]],
|
||||
tolerance_s: float,
|
||||
) -> dict[torch.Tensor]:
|
||||
"""
|
||||
Given a current item in the dataset containing a timestamp (e.g. 0.6 seconds), and a list of time differences of
|
||||
some modalities (e.g. delta_timestamps={"observation.image": [-0.8, -0.2, 0, 0.2]}), this function computes for each
|
||||
given modality (e.g. "observation.image") a list of query timestamps (e.g. [-0.2, 0.4, 0.6, 0.8]) and loads the closest
|
||||
frames in the dataset.
|
||||
|
||||
Importantly, when no frame can be found around a query timestamp within a specified tolerance window, this function
|
||||
raises an AssertionError. When a timestamp is queried before the first available timestamp of the episode or after
|
||||
the last available timestamp, the violation of the tolerance doesnt raise an AssertionError, and the function
|
||||
populates a boolean array indicating which frames are outside of the episode range. For instance, this boolean array
|
||||
is useful during batched training to not supervise actions associated to timestamps coming after the end of the
|
||||
episode, or to pad the observations in a specific way. Note that by default the observation frames before the start
|
||||
of the episode are the same as the first frame of the episode.
|
||||
|
||||
Parameters:
|
||||
- item (dict): A dictionary containing all the data related to a frame. It is the result of `dataset[idx]`. Each key
|
||||
corresponds to a different modality (e.g., "timestamp", "observation.image", "action").
|
||||
- hf_dataset (datasets.Dataset): A dictionary containing the full dataset. Each key corresponds to a different
|
||||
modality (e.g., "timestamp", "observation.image", "action").
|
||||
- episode_data_index (dict): A dictionary containing two keys ("from" and "to") associated to dataset indices.
|
||||
They indicate the start index and end index of each episode in the dataset.
|
||||
- delta_timestamps (dict): A dictionary containing lists of delta timestamps for each possible modality to be
|
||||
retrieved. These deltas are added to the item timestamp to form the query timestamps.
|
||||
- tolerance_s (float, optional): The tolerance level (in seconds) used to determine if a data point is close enough to the query
|
||||
timestamp by asserting `tol > difference`. It is suggested to set `tol` to a smaller value than the
|
||||
smallest expected inter-frame period, but large enough to account for jitter.
|
||||
|
||||
Returns:
|
||||
- The same item with the queried frames for each modality specified in delta_timestamps, with an additional key for
|
||||
each modality (e.g. "observation.image_is_pad").
|
||||
|
||||
Raises:
|
||||
- AssertionError: If any of the frames unexpectedly violate the tolerance level. This could indicate synchronization
|
||||
issues with timestamps during data collection.
|
||||
"""
|
||||
# get indices of the frames associated to the episode, and their timestamps
|
||||
ep_id = item["episode_index"].item()
|
||||
ep_data_id_from = episode_data_index["from"][ep_id].item()
|
||||
ep_data_id_to = episode_data_index["to"][ep_id].item()
|
||||
ep_data_ids = torch.arange(ep_data_id_from, ep_data_id_to, 1)
|
||||
|
||||
# load timestamps
|
||||
ep_timestamps = hf_dataset.select_columns("timestamp")[ep_data_id_from:ep_data_id_to]["timestamp"]
|
||||
ep_timestamps = torch.stack(ep_timestamps)
|
||||
|
||||
# we make the assumption that the timestamps are sorted
|
||||
ep_first_ts = ep_timestamps[0]
|
||||
ep_last_ts = ep_timestamps[-1]
|
||||
current_ts = item["timestamp"].item()
|
||||
|
||||
for key in delta_timestamps:
|
||||
# get timestamps used as query to retrieve data of previous/future frames
|
||||
delta_ts = delta_timestamps[key]
|
||||
query_ts = current_ts + torch.tensor(delta_ts)
|
||||
|
||||
# compute distances between each query timestamp and all timestamps of all the frames belonging to the episode
|
||||
dist = torch.cdist(query_ts[:, None], ep_timestamps[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
# TODO(rcadene): synchronize timestamps + interpolation if needed
|
||||
|
||||
is_pad = min_ > tolerance_s
|
||||
|
||||
# check violated query timestamps are all outside the episode range
|
||||
assert ((query_ts[is_pad] < ep_first_ts) | (ep_last_ts < query_ts[is_pad])).all(), (
|
||||
f"One or several timestamps unexpectedly violate the tolerance ({min_} > {tolerance_s=}) inside episode range."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
)
|
||||
|
||||
# get dataset indices corresponding to frames to be loaded
|
||||
data_ids = ep_data_ids[argmin_]
|
||||
|
||||
# load frames modality
|
||||
item[key] = hf_dataset.select_columns(key)[data_ids][key]
|
||||
|
||||
if isinstance(item[key][0], dict) and "path" in item[key][0]:
|
||||
# video mode where frame are expressed as dict of path and timestamp
|
||||
item[key] = item[key]
|
||||
else:
|
||||
item[key] = torch.stack(item[key])
|
||||
|
||||
item[f"{key}_is_pad"] = is_pad
|
||||
|
||||
return item
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset.
|
||||
|
||||
Parameters:
|
||||
- hf_dataset (datasets.Dataset): A HuggingFace dataset containing the episode index.
|
||||
|
||||
Returns:
|
||||
- episode_data_index: A dictionary containing the data index for each episode. The dictionary has two keys:
|
||||
- "from": A tensor containing the starting index of each episode.
|
||||
- "to": A tensor containing the ending index of each episode.
|
||||
"""
|
||||
episode_data_index = {"from": [], "to": []}
|
||||
|
||||
current_episode = None
|
||||
"""
|
||||
The episode_index is a list of integers, each representing the episode index of the corresponding example.
|
||||
For instance, the following is a valid episode_index:
|
||||
[0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]
|
||||
|
||||
Below, we iterate through the episode_index and populate the episode_data_index dictionary with the starting and
|
||||
ending index of each episode. For the episode_index above, the episode_data_index dictionary will look like this:
|
||||
{
|
||||
"from": [0, 3, 7],
|
||||
"to": [3, 7, 12]
|
||||
}
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([]),
|
||||
"to": torch.tensor([]),
|
||||
}
|
||||
return episode_data_index
|
||||
for idx, episode_idx in enumerate(hf_dataset["episode_index"]):
|
||||
if episode_idx != current_episode:
|
||||
# We encountered a new episode, so we append its starting location to the "from" list
|
||||
episode_data_index["from"].append(idx)
|
||||
# If this is not the first episode, we append the ending location of the previous episode to the "to" list
|
||||
if current_episode is not None:
|
||||
episode_data_index["to"].append(idx)
|
||||
# Let's keep track of the current episode index
|
||||
current_episode = episode_idx
|
||||
else:
|
||||
# We are still in the same episode, so there is nothing for us to do here
|
||||
pass
|
||||
# We have reached the end of the dataset, so we append the ending location of the last episode to the "to" list
|
||||
episode_data_index["to"].append(idx + 1)
|
||||
|
||||
for k in ["from", "to"]:
|
||||
episode_data_index[k] = torch.tensor(episode_data_index[k])
|
||||
|
||||
return episode_data_index
|
||||
|
||||
|
||||
# TODO(aliberts): remove
|
||||
def reset_episode_index(hf_dataset: datasets.Dataset) -> datasets.Dataset:
|
||||
"""Reset the `episode_index` of the provided HuggingFace Dataset.
|
||||
|
||||
`episode_data_index` (and related functionality such as `load_previous_and_future_frames`) requires the
|
||||
`episode_index` to be sorted, continuous (1,1,1 and not 1,2,1) and start at 0.
|
||||
|
||||
This brings the `episode_index` to the required format.
|
||||
"""
|
||||
if len(hf_dataset) == 0:
|
||||
return hf_dataset
|
||||
unique_episode_idxs = torch.stack(hf_dataset["episode_index"]).unique().tolist()
|
||||
episode_idx_to_reset_idx_mapping = {
|
||||
ep_id: reset_ep_id for reset_ep_id, ep_id in enumerate(unique_episode_idxs)
|
||||
}
|
||||
|
||||
def modify_ep_idx_func(example):
|
||||
example["episode_index"] = episode_idx_to_reset_idx_mapping[example["episode_index"].item()]
|
||||
return example
|
||||
|
||||
hf_dataset = hf_dataset.map(modify_ep_idx_func)
|
||||
|
||||
return hf_dataset
|
||||
|
||||
|
||||
def cycle(iterable):
|
||||
"""The equivalent of itertools.cycle, but safe for Pytorch dataloaders.
|
||||
|
||||
@@ -550,16 +467,34 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
tags: list | None = None, text: str | None = None, info: dict | None = None
|
||||
tags: list | None = None,
|
||||
dataset_info: dict | None = None,
|
||||
**kwargs,
|
||||
) -> DatasetCard:
|
||||
card = DatasetCard(DATASET_CARD_TEMPLATE)
|
||||
card.data.task_categories = ["robotics"]
|
||||
card.data.tags = ["LeRobot"]
|
||||
if tags is not None:
|
||||
card.data.tags += tags
|
||||
if text is not None:
|
||||
card.text += f"{text}\n"
|
||||
if info is not None:
|
||||
card.text += "[meta/info.json](meta/info.json)\n"
|
||||
card.text += f"```json\n{json.dumps(info, indent=4)}\n```"
|
||||
return card
|
||||
"""
|
||||
Keyword arguments will be used to replace values in ./lerobot/common/datasets/card_template.md.
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
if tags:
|
||||
card_tags += tags
|
||||
if dataset_info:
|
||||
dataset_structure = "[meta/info.json](meta/info.json):\n"
|
||||
dataset_structure += f"```json\n{json.dumps(dataset_info, indent=4)}\n```\n"
|
||||
kwargs = {**kwargs, "dataset_structure": dataset_structure}
|
||||
card_data = DatasetCardData(
|
||||
license=kwargs.get("license"),
|
||||
tags=card_tags,
|
||||
task_categories=["robotics"],
|
||||
configs=[
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
],
|
||||
)
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
template_path="./lerobot/common/datasets/card_template.md",
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -14,84 +14,860 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This script is for internal use to convert all datasets under the 'lerobot' hub user account to v2.
|
||||
|
||||
Note: Since the original Aloha datasets don't use shadow motors, you need to comment those out in
|
||||
lerobot/configs/robot/aloha.yaml before running this script.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from textwrap import dedent
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config
|
||||
|
||||
LOCAL_DIR = Path("data/")
|
||||
ALOHA_SINGLE_TASKS_REAL = {
|
||||
"aloha_mobile_cabinet": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
||||
"aloha_mobile_chair": "Push the chairs in front of the desk to place them against it.",
|
||||
"aloha_mobile_elevator": "Take the elevator to the 1st floor.",
|
||||
"aloha_mobile_shrimp": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
||||
"aloha_mobile_wash_pan": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
||||
"aloha_mobile_wipe_wine": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
||||
"aloha_static_battery": "Place the battery into the slot of the remote controller.",
|
||||
"aloha_static_candy": "Pick up the candy and unwrap it.",
|
||||
"aloha_static_coffee": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
"aloha_static_coffee_new": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
||||
"aloha_static_cups_open": "Pick up the plastic cup and open its lid.",
|
||||
"aloha_static_fork_pick_up": "Pick up the fork and place it on the plate.",
|
||||
"aloha_static_pingpong_test": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
||||
"aloha_static_pro_pencil": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
||||
"aloha_static_screw_driver": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
||||
"aloha_static_tape": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
||||
"aloha_static_thread_velcro": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
||||
"aloha_static_towel": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
||||
"aloha_static_vinh_cup": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
|
||||
"aloha_static_vinh_cup_left": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
|
||||
"aloha_static_ziploc_slide": "Slide open the ziploc bag.",
|
||||
}
|
||||
|
||||
ALOHA_CONFIG = Path("lerobot/configs/robot/aloha.yaml")
|
||||
ALOHA_MOBILE_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://mobile-aloha.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.02117",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{fu2024mobile,
|
||||
author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea},
|
||||
title = {Mobile ALOHA: Learning Bimanual Mobile Manipulation with Low-Cost Whole-Body Teleoperation},
|
||||
booktitle = {arXiv},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
ALOHA_STATIC_INFO = {
|
||||
"robot_config": parse_robot_config(ALOHA_CONFIG),
|
||||
"license": "mit",
|
||||
"url": "https://tonyzhaozh.github.io/aloha/",
|
||||
"paper": "https://arxiv.org/abs/2304.13705",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Zhao2023LearningFB,
|
||||
title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
|
||||
author={Tony Zhao and Vikash Kumar and Sergey Levine and Chelsea Finn},
|
||||
journal={RSS},
|
||||
year={2023},
|
||||
volume={abs/2304.13705},
|
||||
url={https://arxiv.org/abs/2304.13705}
|
||||
}""").lstrip(),
|
||||
}
|
||||
PUSHT_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
}
|
||||
XARM_INFO = {
|
||||
"license": "mit",
|
||||
"url": "https://www.nicklashansen.com/td-mpc/",
|
||||
"paper": "https://arxiv.org/abs/2203.04955",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
"""),
|
||||
}
|
||||
UNITREEH_INFO = {
|
||||
"license": "apache-2.0",
|
||||
}
|
||||
|
||||
DATASETS = {
|
||||
"aloha_mobile_cabinet": {
|
||||
"single_task": "Open the top cabinet, store the pot inside it then close the cabinet.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_chair": {
|
||||
"single_task": "Push the chairs in front of the desk to place them against it.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_elevator": {
|
||||
"single_task": "Take the elevator to the 1st floor.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_shrimp": {
|
||||
"single_task": "Sauté the raw shrimp on both sides, then serve it in the bowl.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wash_pan": {
|
||||
"single_task": "Pick up the pan, rinse it in the sink and then place it in the drying rack.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_mobile_wipe_wine": {
|
||||
"single_task": "Pick up the wet cloth on the faucet and use it to clean the spilled wine on the table and underneath the glass.",
|
||||
**ALOHA_MOBILE_INFO,
|
||||
},
|
||||
"aloha_static_battery": {
|
||||
"single_task": "Place the battery into the slot of the remote controller.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO},
|
||||
"aloha_static_coffee": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_coffee_new": {
|
||||
"single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_cups_open": {
|
||||
"single_task": "Pick up the plastic cup and open its lid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_fork_pick_up": {
|
||||
"single_task": "Pick up the fork and place it on the plate.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pingpong_test": {
|
||||
"single_task": "Transfer one of the two balls in the right glass into the left glass, then transfer it back to the right glass.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_pro_pencil": {
|
||||
"single_task": "Pick up the pencil with the right arm, hand it over to the left arm then place it back onto the table.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_screw_driver": {
|
||||
"single_task": "Pick up the screwdriver with the right arm, hand it over to the left arm then place it into the cup.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_tape": {
|
||||
"single_task": "Cut a small piece of tape from the tape dispenser then place it on the cardboard box's edge.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_thread_velcro": {
|
||||
"single_task": "Pick up the velcro cable tie with the left arm, then insert the end of the velcro tie into the other end's loop with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_towel": {
|
||||
"single_task": "Pick up a piece of paper towel and place it on the spilled liquid.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup": {
|
||||
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup_left": {
|
||||
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_scripted_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO},
|
||||
"aloha_sim_insertion_human_image": {
|
||||
"single_task": "Insert the peg into the socket.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_scripted_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_sim_transfer_cube_human_image": {
|
||||
"single_task": "Pick up the cube with the right arm and transfer it to the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO},
|
||||
"unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO},
|
||||
"unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO},
|
||||
"unitreeh1_two_robot_greeting": {
|
||||
"single_task": "Greet the other robot with a high five.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"unitreeh1_warehouse": {
|
||||
"single_task": "Grab the spray paint on the shelf and place it in the bin on top of the robot dog.",
|
||||
**UNITREEH_INFO,
|
||||
},
|
||||
"xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO},
|
||||
"xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO},
|
||||
"umi_cup_in_the_wild": {
|
||||
"single_task": "Put the cup on the plate.",
|
||||
"license": "apache-2.0",
|
||||
},
|
||||
"asu_table_top": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://link.springer.com/article/10.1007/s10514-023-10129-1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023modularity,
|
||||
title={Modularity through Attention: Efficient Training and Transfer of Language-Conditioned Policies for Robot Manipulation},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Stepputtis, Simon and Amor, Heni},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={1684--1695},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}
|
||||
@article{zhou2023learning,
|
||||
title={Learning modular language-conditioned robot policies through attention},
|
||||
author={Zhou, Yifan and Sonawani, Shubham and Phielipp, Mariano and Ben Amor, Heni and Stepputtis, Simon},
|
||||
journal={Autonomous Robots},
|
||||
pages={1--21},
|
||||
year={2023},
|
||||
publisher={Springer}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_buds_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/BUDS-website/",
|
||||
"paper": "https://arxiv.org/abs/2109.13841",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022bottom,
|
||||
title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation},
|
||||
author={Zhu, Yifeng and Stone, Peter and Zhu, Yuke},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
volume={7},
|
||||
number={2},
|
||||
pages={4126--4133},
|
||||
year={2022},
|
||||
publisher={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sailor_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sailor/",
|
||||
"paper": "https://arxiv.org/abs/2210.11435",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{nasiriany2022sailor,
|
||||
title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning},
|
||||
author={Soroush Nasiriany and Tian Gao and Ajay Mandlekar and Yuke Zhu},
|
||||
booktitle={Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"austin_sirius_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/sirius/",
|
||||
"paper": "https://arxiv.org/abs/2211.08416",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{liu2022robot,
|
||||
title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment},
|
||||
author = {Huihan Liu and Soroush Nasiriany and Lance Zhang and Zhiyao Bao and Yuke Zhu},
|
||||
booktitle = {Robotics: Science and Systems (RSS)},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_autolab_ur5": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/berkeley-ur5/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{BerkeleyUR5Website,
|
||||
title = {Berkeley {UR5} Demonstration Dataset},
|
||||
author = {Lawrence Yunliang Chen and Simeon Adebola and Ken Goldberg},
|
||||
howpublished = {https://sites.google.com/view/berkeley-ur5/home},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_cable_routing": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://sites.google.com/view/cablerouting/home",
|
||||
"paper": "https://arxiv.org/abs/2307.08927",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2023multistage,
|
||||
author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine},
|
||||
title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning},
|
||||
journal = {arXiv pre-print},
|
||||
year = {2023},
|
||||
url = {https://arxiv.org/abs/2307.08927},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_fanuc_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/berkeley.edu/fanuc-manipulation",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{fanuc_manipulation2023,
|
||||
title={Fanuc Manipulation: A Dataset for Learning-based Manipulation with FANUC Mate 200iD Robot},
|
||||
author={Zhu, Xinghao and Tian, Ran and Xu, Chenfeng and Ding, Mingyu and Zhan, Wei and Tomizuka, Masayoshi},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_cory_hall": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/1709.10489",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{kahn2018self,
|
||||
title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation},
|
||||
author={Kahn, Gregory and Villaflor, Adam and Ding, Bosen and Abbeel, Pieter and Levine, Sergey},
|
||||
booktitle={2018 IEEE international conference on robotics and automation (ICRA)},
|
||||
pages={5129--5136},
|
||||
year={2018},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_recon": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/recon-robot",
|
||||
"paper": "https://arxiv.org/abs/2104.05859",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2021rapid,
|
||||
title={Rapid Exploration for Open-World Navigation with Latent Goal Models},
|
||||
author={Dhruv Shah and Benjamin Eysenbach and Nicholas Rhinehart and Sergey Levine},
|
||||
booktitle={5th Annual Conference on Robot Learning },
|
||||
year={2021},
|
||||
url={https://openreview.net/forum?id=d_SWJhyKfVw}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_gnm_sac_son": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/SACSoN-review",
|
||||
"paper": "https://arxiv.org/abs/2306.01874",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{hirose2023sacson,
|
||||
title={SACSoN: Scalable Autonomous Data Collection for Social Navigation},
|
||||
author={Hirose, Noriaki and Shah, Dhruv and Sridhar, Ajay and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2306.01874},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_mvp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2203.06173",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@InProceedings{Radosavovic2022,
|
||||
title = {Real-World Robot Learning with Masked Visual Pre-training},
|
||||
author = {Ilija Radosavovic and Tete Xiao and Stephen James and Pieter Abbeel and Jitendra Malik and Trevor Darrell},
|
||||
booktitle = {CoRL},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"berkeley_rpt": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://arxiv.org/abs/2306.10007",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{Radosavovic2023,
|
||||
title={Robot Learning with Sensorimotor Pre-training},
|
||||
author={Ilija Radosavovic and Baifeng Shi and Letian Fu and Ken Goldberg and Trevor Darrell and Jitendra Malik},
|
||||
year={2023},
|
||||
journal={arXiv:2306.10007}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_franka_exploration_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://human-world-model.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2308.10901",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={RSS},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_play_fusion": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-fusion.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2312.04549",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chen2023playfusion,
|
||||
title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play},
|
||||
author={Chen, Lili and Bahl, Shikhar and Pathak, Deepak},
|
||||
booktitle={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"cmu_stretch": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robo-affordances.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2304.08488",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{bahl2023affordances,
|
||||
title={Affordances from Human Videos as a Versatile Representation for Robotics},
|
||||
author={Bahl, Shikhar and Mendonca, Russell and Chen, Lili and Jain, Unnat and Pathak, Deepak},
|
||||
booktitle={CVPR},
|
||||
year={2023}
|
||||
}
|
||||
@article{mendonca2023structured,
|
||||
title={Structured World Models from Human Videos},
|
||||
author={Mendonca, Russell and Bahl, Shikhar and Pathak, Deepak},
|
||||
journal={CoRL},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"columbia_cairlab_pusht_real": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://diffusion-policy.cs.columbia.edu/",
|
||||
"paper": "https://arxiv.org/abs/2303.04137v5",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{chi2023diffusionpolicy,
|
||||
title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
author={Chi, Cheng and Feng, Siyuan and Du, Yilun and Xu, Zhenjia and Cousineau, Eric and Burchfiel, Benjamin and Song, Shuran},
|
||||
booktitle={Proceedings of Robotics: Science and Systems (RSS)},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"conq_hose_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/conq-hose-manipulation-dataset/home",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{ConqHoseManipData,
|
||||
author={Peter Mitrano and Dmitry Berenson},
|
||||
title={Conq Hose Manipulation Dataset, v1.15.0},
|
||||
year={2024},
|
||||
howpublished={https://sites.google.com/view/conq-hose-manipulation-dataset}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_edan_shared_control": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://ieeexplore.ieee.org/document/9341156",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{vogel_edan_2020,
|
||||
title = {EDAN - an EMG-Controlled Daily Assistant to Help People with Physical Disabilities},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE}/{RSJ} {International} {Conference} on {Intelligent} {Robots} and {Systems} ({IROS})},
|
||||
author = {Vogel, Jörn and Hagengruber, Annette and Iskandar, Maged and Quere, Gabriel and Leipscher, Ulrike and Bustamante, Samuel and Dietrich, Alexander and Hoeppner, Hannes and Leidner, Daniel and Albu-Schäffer, Alin},
|
||||
year = {2020}
|
||||
}
|
||||
@inproceedings{quere_shared_2020,
|
||||
address = {Paris, France},
|
||||
title = {Shared {Control} {Templates} for {Assistive} {Robotics}},
|
||||
language = {en},
|
||||
booktitle = {2020 {IEEE} {International} {Conference} on {Robotics} and {Automation} ({ICRA})},
|
||||
author = {Quere, Gabriel and Hagengruber, Annette and Iskandar, Maged and Bustamante, Samuel and Leidner, Daniel and Stulp, Freek and Vogel, Joern},
|
||||
year = {2020},
|
||||
pages = {7},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_grid_clamp": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://www.researchsquare.com/article/rs-3289569/v1",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{padalkar2023guided,
|
||||
title={A guided reinforcement learning approach using shared control templates for learning manipulation skills in the real world},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Raffin, Antonin and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
journal={Research square preprint rs-3289569/v1},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"dlr_sara_pour": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"paper": "https://elib.dlr.de/193739/1/padalkar2023rlsct.pdf",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{padalkar2023guiding,
|
||||
title={Guiding Reinforcement Learning with Shared Control Templates},
|
||||
author={Padalkar, Abhishek and Quere, Gabriel and Steinmetz, Franz and Raffin, Antonin and Nieuwenhuisen, Matthias and Silv{\'e}rio, Jo{\~a}o and Stulp, Freek},
|
||||
booktitle={40th IEEE International Conference on Robotics and Automation, ICRA 2023},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"droid_100": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://droid-dataset.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2403.12945",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{khazatsky2024droid,
|
||||
title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset},
|
||||
author = {Alexander Khazatsky and Karl Pertsch and Suraj Nair and Ashwin Balakrishna and Sudeep Dasari and Siddharth Karamcheti and Soroush Nasiriany and Mohan Kumar Srirama and Lawrence Yunliang Chen and Kirsty Ellis and Peter David Fagan and Joey Hejna and Masha Itkina and Marion Lepert and Yecheng Jason Ma and Patrick Tree Miller and Jimmy Wu and Suneel Belkhale and Shivin Dass and Huy Ha and Arhan Jain and Abraham Lee and Youngwoon Lee and Marius Memmel and Sungjae Park and Ilija Radosavovic and Kaiyuan Wang and Albert Zhan and Kevin Black and Cheng Chi and Kyle Beltran Hatch and Shan Lin and Jingpei Lu and Jean Mercat and Abdul Rehman and Pannag R Sanketi and Archit Sharma and Cody Simpson and Quan Vuong and Homer Rich Walke and Blake Wulfe and Ted Xiao and Jonathan Heewon Yang and Arefeh Yavary and Tony Z. Zhao and Christopher Agia and Rohan Baijal and Mateo Guaman Castro and Daphne Chen and Qiuyu Chen and Trinity Chung and Jaimyn Drake and Ethan Paul Foster and Jensen Gao and David Antonio Herrera and Minho Heo and Kyle Hsu and Jiaheng Hu and Donovon Jackson and Charlotte Le and Yunshuang Li and Kevin Lin and Roy Lin and Zehan Ma and Abhiram Maddukuri and Suvir Mirchandani and Daniel Morton and Tony Nguyen and Abigail O'Neill and Rosario Scalise and Derick Seale and Victor Son and Stephen Tian and Emi Tran and Andrew E. Wang and Yilin Wu and Annie Xie and Jingyun Yang and Patrick Yin and Yunchu Zhang and Osbert Bastani and Glen Berseth and Jeannette Bohg and Ken Goldberg and Abhinav Gupta and Abhishek Gupta and Dinesh Jayaraman and Joseph J Lim and Jitendra Malik and Roberto Martín-Martín and Subramanian Ramamoorthy and Dorsa Sadigh and Shuran Song and Jiajun Wu and Michael C. Yip and Yuke Zhu and Thomas Kollar and Sergey Levine and Chelsea Finn},
|
||||
year = {2024},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"fmb": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://functional-manipulation-benchmark.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2401.08553",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{luo2024fmb,
|
||||
title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning},
|
||||
author={Luo, Jianlan and Xu, Charles and Liu, Fangchen and Tan, Liam and Lin, Zipeng and Wu, Jeffrey and Abbeel, Pieter and Levine, Sergey},
|
||||
journal={arXiv preprint arXiv:2401.08553},
|
||||
year={2024}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"iamlab_cmu_pickup_insert": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://openreview.net/forum?id=WuBv9-IGDUA",
|
||||
"paper": "https://arxiv.org/abs/2401.14502",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{saxena2023multiresolution,
|
||||
title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models},
|
||||
author={Saumya Saxena and Mohit Sharma and Oliver Kroemer},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=WuBv9-IGDUA}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"imperialcollege_sawyer_wrist_cam": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
},
|
||||
"jaco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/clvrai/clvr_jaco_play_dataset",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@software{dass2023jacoplay,
|
||||
author = {Dass, Shivin and Yapeter, Jullian and Zhang, Jesse and Zhang, Jiahui
|
||||
and Pertsch, Karl and Nikolaidis, Stefanos and Lim, Joseph J.},
|
||||
title = {CLVR Jaco Play Dataset},
|
||||
url = {https://github.com/clvrai/clvr_jaco_play_dataset},
|
||||
version = {1.0.0},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"kaist_nonprehensile": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://github.com/JaeHyung-Kim/rlds_dataset_builder",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{kimpre,
|
||||
title={Pre-and post-contact policy decomposition for non-prehensile manipulation with zero-shot sim-to-real transfer},
|
||||
author={Kim, Minchan and Han, Junhyek and Kim, Jaehyung and Kim, Beomjoon},
|
||||
booktitle={2023 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
year={2023},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_door_opening_surprising_effectiveness": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://jyopari.github.io/VINN/",
|
||||
"paper": "https://arxiv.org/abs/2112.01511",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{pari2021surprising,
|
||||
title={The Surprising Effectiveness of Representation Learning for Visual Imitation},
|
||||
author={Jyothish Pari and Nur Muhammad Shafiullah and Sridhar Pandian Arunachalam and Lerrel Pinto},
|
||||
year={2021},
|
||||
eprint={2112.01511},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_franka_play_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://play-to-policy.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2210.10047",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{cui2022play,
|
||||
title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data},
|
||||
author = {Cui, Zichen Jeff and Wang, Yibin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal = {arXiv preprint arXiv:2210.10047},
|
||||
year = {2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"nyu_rot_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://rot-robot.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2206.15469",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{haldar2023watch,
|
||||
title={Watch and match: Supercharging imitation with regularized optimal transport},
|
||||
author={Haldar, Siddhant and Mathur, Vaibhav and Yarats, Denis and Pinto, Lerrel},
|
||||
booktitle={Conference on Robot Learning},
|
||||
pages={32--43},
|
||||
year={2023},
|
||||
organization={PMLR}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"roboturk": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://roboturk.stanford.edu/dataset_real.html",
|
||||
"paper": "PAPER",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{mandlekar2019scaling,
|
||||
title={Scaling robot supervision to hundreds of hours with roboturk: Robotic manipulation dataset through human reasoning and dexterity},
|
||||
author={Mandlekar, Ajay and Booher, Jonathan and Spero, Max and Tung, Albert and Gupta, Anchit and Zhu, Yuke and Garg, Animesh and Savarese, Silvio and Fei-Fei, Li},
|
||||
booktitle={2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
|
||||
pages={1048--1055},
|
||||
year={2019},
|
||||
organization={IEEE}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_hydra_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/hydra-il-2023",
|
||||
"paper": "https://arxiv.org/abs/2306.17237",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{belkhale2023hydra,
|
||||
title={HYDRA: Hybrid Robot Actions for Imitation Learning},
|
||||
author={Belkhale, Suneel and Cui, Yuchen and Sadigh, Dorsa},
|
||||
journal={arxiv},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_kuka_multimodal_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://sites.google.com/view/visionandtouch",
|
||||
"paper": "https://arxiv.org/abs/1810.10191",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{lee2019icra,
|
||||
title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks},
|
||||
author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette},
|
||||
booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2019},
|
||||
url={https://arxiv.org/abs/1810.10191}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"stanford_robocook": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://hshi74.github.io/robocook/",
|
||||
"paper": "https://arxiv.org/abs/2306.14447",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{shi2023robocook,
|
||||
title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools},
|
||||
author={Shi, Haochen and Xu, Huazhe and Clarke, Samuel and Li, Yunzhu and Wu, Jiajun},
|
||||
journal={arXiv preprint arXiv:2306.14447},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"taco_play": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"url": "https://www.kaggle.com/datasets/oiermees/taco-robot",
|
||||
"paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{rosete2022tacorl,
|
||||
author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard},
|
||||
title = {Latent Plans for Task Agnostic Offline Reinforcement Learning},
|
||||
journal = {Proceedings of the 6th Conference on Robot Learning (CoRL)},
|
||||
year = {2022}
|
||||
}
|
||||
@inproceedings{mees23hulc2,
|
||||
title={Grounding Language with Visual Affordances over Unstructured Data},
|
||||
author={Oier Mees and Jessica Borja-Diaz and Wolfram Burgard},
|
||||
booktitle = {Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
year={2023},
|
||||
address = {London, UK}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"tokyo_u_lsmo": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "URL",
|
||||
"paper": "https://arxiv.org/abs/2107.05842",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@Article{Osa22,
|
||||
author = {Takayuki Osa},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
title = {Motion Planning by Learning the Solution Manifold in Trajectory Optimization},
|
||||
year = {2022},
|
||||
number = {3},
|
||||
pages = {291--311},
|
||||
volume = {41},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"toto": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://toto-benchmark.org/",
|
||||
"paper": "https://arxiv.org/abs/2306.00942",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{zhou2023train,
|
||||
author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav},
|
||||
booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
|
||||
title={Train Offline, Test Online: A Real Robot Learning Benchmark},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_kitchen_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@ARTICLE{ucsd_kitchens,
|
||||
author = {Ge Yan, Kris Wu, and Xiaolong Wang},
|
||||
title = {{ucsd kitchens Dataset}},
|
||||
year = {2023},
|
||||
month = {August}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"ucsd_pick_and_place_dataset": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://owmcorl.github.io/#",
|
||||
"paper": "https://arxiv.org/abs/2310.16029",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@preprint{Feng2023Finetuning,
|
||||
title={Finetuning Offline World Models in the Real World},
|
||||
author={Yunhai Feng, Nicklas Hansen, Ziyan Xiong, Chandramouli Rajagopalan, Xiaolong Wang},
|
||||
year={2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"uiuc_d3field": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://robopil.github.io/d3fields/",
|
||||
"paper": "https://arxiv.org/abs/2309.16118",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{wang2023d3field,
|
||||
title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation},
|
||||
author={Wang, Yixuan and Li, Zhuoran and Zhang, Mingtong and Driggs-Campbell, Katherine and Wu, Jiajun and Fei-Fei, Li and Li, Yunzhu},
|
||||
journal={arXiv preprint arXiv:},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"usc_cloth_sim": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://uscresl.github.io/dmfd/",
|
||||
"paper": "https://arxiv.org/abs/2207.10148",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{salhotra2022dmfd,
|
||||
author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.},
|
||||
journal={IEEE Robotics and Automation Letters},
|
||||
title={Learning Deformable Object Manipulation From Expert Demonstrations},
|
||||
year={2022},
|
||||
volume={7},
|
||||
number={4},
|
||||
pages={8775-8782},
|
||||
doi={10.1109/LRA.2022.3187843}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utaustin_mutex": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/MUTEX/",
|
||||
"paper": "https://arxiv.org/abs/2309.14320",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@inproceedings{shah2023mutex,
|
||||
title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications},
|
||||
author={Rutav Shah and Roberto Mart{\'\i}n-Mart{\'\i}n and Yuke Zhu},
|
||||
booktitle={7th Annual Conference on Robot Learning},
|
||||
year={2023},
|
||||
url={https://openreview.net/forum?id=PwqiqaaEzJ}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_opening_fridge": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_pr2_tabletop_manipulation": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{oh2023pr2utokyodatasets,
|
||||
author={Jihoon Oh and Naoaki Kanazawa and Kento Kawaharazuka},
|
||||
title={X-Embodiment U-Tokyo PR2 Datasets},
|
||||
year={2023},
|
||||
url={https://github.com/ojh6404/rlds_dataset_builder},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_saytap": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://saytap.github.io/",
|
||||
"paper": "https://arxiv.org/abs/2306.07580",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{saytap2023,
|
||||
author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and
|
||||
Tatsuya Harada},
|
||||
title = {SayTap: Language to Quadrupedal Locomotion},
|
||||
eprint = {arXiv:2306.07580},
|
||||
url = {https://saytap.github.io},
|
||||
note = {https://saytap.github.io},
|
||||
year = {2023}
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_bimanual": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"utokyo_xarm_pick_and_place": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "cc-by-4.0",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@misc{matsushima2023weblab,
|
||||
title={Weblab xArm Dataset},
|
||||
author={Tatsuya Matsushima and Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo},
|
||||
year={2023},
|
||||
}""").lstrip(),
|
||||
},
|
||||
"viola": {
|
||||
"tasks_col": "language_instruction",
|
||||
"license": "mit",
|
||||
"url": "https://ut-austin-rpl.github.io/VIOLA/",
|
||||
"paper": "https://arxiv.org/abs/2210.11339",
|
||||
"citation_bibtex": dedent(r"""
|
||||
@article{zhu2022viola,
|
||||
title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors},
|
||||
author={Zhu, Yifeng and Joshi, Abhishek and Stone, Peter and Zhu, Yuke},
|
||||
journal={6th Annual Conference on Robot Learning (CoRL)},
|
||||
year={2022}
|
||||
}""").lstrip(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def batch_convert():
|
||||
status = {}
|
||||
logfile = LOCAL_DIR / "conversion_log.txt"
|
||||
for num, repo_id in enumerate(available_datasets):
|
||||
print(f"\nConverting {repo_id} ({num}/{len(available_datasets)})")
|
||||
assert set(DATASETS) == {id_.split("/")[1] for id_ in available_datasets}
|
||||
for num, (name, kwargs) in enumerate(DATASETS.items()):
|
||||
repo_id = f"lerobot/{name}"
|
||||
print(f"\nConverting {repo_id} ({num}/{len(DATASETS)})")
|
||||
print("---------------------------------------------------------")
|
||||
name = repo_id.split("/")[1]
|
||||
single_task, tasks_col, robot_config = None, None, None
|
||||
|
||||
if "aloha" in name:
|
||||
robot_config = parse_robot_config(ALOHA_CONFIG)
|
||||
if "sim_insertion" in name:
|
||||
single_task = "Insert the peg into the socket."
|
||||
elif "sim_transfer" in name:
|
||||
single_task = "Pick up the cube with the right arm and transfer it to the left arm."
|
||||
else:
|
||||
single_task = ALOHA_SINGLE_TASKS_REAL[name]
|
||||
elif "unitreeh1" in name:
|
||||
if "fold_clothes" in name:
|
||||
single_task = "Fold the sweatshirt."
|
||||
elif "rearrange_objects" in name or "rearrange_objects" in name:
|
||||
single_task = "Put the object into the bin."
|
||||
elif "two_robot_greeting" in name:
|
||||
single_task = "Greet the other robot with a high five."
|
||||
elif "warehouse" in name:
|
||||
single_task = (
|
||||
"Grab the spray paint on the shelf and place it in the bin on top of the robot dog."
|
||||
)
|
||||
elif name != "columbia_cairlab_pusht_real" and "pusht" in name:
|
||||
single_task = "Push the T-shaped block onto the T-shaped target."
|
||||
elif "xarm_lift" in name or "xarm_push" in name:
|
||||
single_task = "Pick up the cube and lift it."
|
||||
elif name == "umi_cup_in_the_wild":
|
||||
single_task = "Put the cup on the plate."
|
||||
else:
|
||||
tasks_col = "language_instruction"
|
||||
|
||||
try:
|
||||
convert_dataset(
|
||||
repo_id=repo_id,
|
||||
local_dir=LOCAL_DIR,
|
||||
single_task=single_task,
|
||||
tasks_col=tasks_col,
|
||||
robot_config=robot_config,
|
||||
)
|
||||
convert_dataset(repo_id, LOCAL_DIR, **kwargs)
|
||||
status = f"{repo_id}: success."
|
||||
with open(logfile, "a") as file:
|
||||
file.write(status + "\n")
|
||||
|
||||
@@ -103,21 +103,20 @@ import argparse
|
||||
import contextlib
|
||||
import filecmp
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import shutil
|
||||
import subprocess
|
||||
import warnings
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError
|
||||
from PIL import Image
|
||||
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
@@ -132,9 +131,16 @@ from lerobot.common.datasets.utils import (
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_hub_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
write_json,
|
||||
write_jsonlines,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import (
|
||||
VideoFrame, # noqa: F401
|
||||
get_image_pixel_channels,
|
||||
get_video_info,
|
||||
)
|
||||
from lerobot.common.datasets.video_utils import VideoFrame # noqa: F401
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
|
||||
V16 = "v1.6"
|
||||
@@ -170,28 +176,12 @@ def parse_robot_config(config_path: Path, config_overrides: list[str] | None = N
|
||||
"robot_type": robot_cfg["robot_type"],
|
||||
"names": {
|
||||
"observation.state": state_names,
|
||||
"observation.effort": state_names,
|
||||
"action": action_names,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_json(fpath: Path) -> dict:
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_jsonlines(data: dict, fpath: Path) -> None:
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(data)
|
||||
|
||||
|
||||
def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
safetensor_path = v1_dir / V1_STATS_PATH
|
||||
stats = load_file(safetensor_path)
|
||||
@@ -213,21 +203,40 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
torch.testing.assert_close(stats_json[key], stats[key])
|
||||
|
||||
|
||||
def get_keys(dataset: Dataset) -> dict[str, list]:
|
||||
sequence_keys, image_keys, video_keys = [], [], []
|
||||
def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]:
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
dtype = ft.dtype
|
||||
shape = (1,)
|
||||
names = None
|
||||
if isinstance(ft, datasets.Sequence):
|
||||
sequence_keys.append(key)
|
||||
assert isinstance(ft.feature, datasets.Value)
|
||||
dtype = ft.feature.dtype
|
||||
shape = (ft.length,)
|
||||
motor_names = (
|
||||
robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)]
|
||||
)
|
||||
assert len(motor_names) == shape[0]
|
||||
names = {"motors": motor_names}
|
||||
elif isinstance(ft, datasets.Image):
|
||||
image_keys.append(key)
|
||||
dtype = "image"
|
||||
image = dataset[0][key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
shape = (image.height, image.width, channels)
|
||||
names = ["height", "width", "channel"]
|
||||
elif ft._type == "VideoFrame":
|
||||
video_keys.append(key)
|
||||
dtype = "video"
|
||||
shape = None # Add shape later
|
||||
names = ["height", "width", "channel"]
|
||||
|
||||
return {
|
||||
"sequence": sequence_keys,
|
||||
"image": image_keys,
|
||||
"video": video_keys,
|
||||
}
|
||||
features[key] = {
|
||||
"dtype": dtype,
|
||||
"shape": shape,
|
||||
"names": names,
|
||||
}
|
||||
|
||||
return features
|
||||
|
||||
|
||||
def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]:
|
||||
@@ -270,17 +279,15 @@ def add_task_index_from_tasks_col(
|
||||
|
||||
def split_parquet_by_episodes(
|
||||
dataset: Dataset,
|
||||
keys: dict[str, list],
|
||||
total_episodes: int,
|
||||
total_chunks: int,
|
||||
output_dir: Path,
|
||||
) -> list:
|
||||
table = dataset.remove_columns(keys["video"])._data.table
|
||||
table = dataset.data.table
|
||||
episode_lengths = []
|
||||
for ep_chunk in range(total_chunks):
|
||||
ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk
|
||||
ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes)
|
||||
|
||||
chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk)
|
||||
(output_dir / chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
for ep_idx in range(ep_chunk_start, ep_chunk_end):
|
||||
@@ -406,156 +413,23 @@ def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[st
|
||||
return [f for f in video_files if f not in lfs_tracked_files]
|
||||
|
||||
|
||||
def _get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def _get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**_get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict:
|
||||
hub_api = HfApi()
|
||||
videos_info_dict = {"videos_path": DEFAULT_VIDEO_PATH}
|
||||
|
||||
# Assumes first episode
|
||||
video_files = [
|
||||
DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0)
|
||||
for vid_key in video_keys
|
||||
]
|
||||
hub_api = HfApi()
|
||||
hub_api.snapshot_download(
|
||||
repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files
|
||||
)
|
||||
videos_info_dict = {}
|
||||
for vid_key, vid_path in zip(video_keys, video_files, strict=True):
|
||||
videos_info_dict[vid_key] = _get_video_info(local_dir / vid_path)
|
||||
videos_info_dict[vid_key] = get_video_info(local_dir / vid_path)
|
||||
|
||||
return videos_info_dict
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_video_shapes(videos_info: dict, video_keys: list) -> dict:
|
||||
video_shapes = {}
|
||||
for img_key in video_keys:
|
||||
channels = get_video_pixel_channels(videos_info[img_key]["video.pix_fmt"])
|
||||
video_shapes[img_key] = {
|
||||
"width": videos_info[img_key]["video.width"],
|
||||
"height": videos_info[img_key]["video.height"],
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return video_shapes
|
||||
|
||||
|
||||
def get_image_shapes(dataset: Dataset, image_keys: list) -> dict:
|
||||
image_shapes = {}
|
||||
for img_key in image_keys:
|
||||
image = dataset[0][img_key] # Assuming first row
|
||||
channels = get_image_pixel_channels(image)
|
||||
image_shapes[img_key] = {
|
||||
"width": image.width,
|
||||
"height": image.height,
|
||||
"channels": channels,
|
||||
}
|
||||
|
||||
return image_shapes
|
||||
|
||||
|
||||
def get_generic_motor_names(sequence_shapes: dict) -> dict:
|
||||
return {key: [f"motor_{i}" for i in range(length)] for key, length in sequence_shapes.items()}
|
||||
|
||||
|
||||
def convert_dataset(
|
||||
repo_id: str,
|
||||
local_dir: Path,
|
||||
@@ -564,8 +438,9 @@ def convert_dataset(
|
||||
tasks_col: Path | None = None,
|
||||
robot_config: dict | None = None,
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_hub_safe_version(repo_id, V16, enforce_v2=False)
|
||||
v1 = get_hub_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -582,12 +457,12 @@ def convert_dataset(
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
keys = get_keys(dataset)
|
||||
features = get_features_from_hf_dataset(dataset, robot_config)
|
||||
video_keys = [key for key, ft in features.items() if ft["dtype"] == "video"]
|
||||
|
||||
if single_task and "language_instruction" in dataset.column_names:
|
||||
warnings.warn(
|
||||
logging.warning(
|
||||
"'single_task' provided but 'language_instruction' tasks_col found. Using 'language_instruction'.",
|
||||
stacklevel=1,
|
||||
)
|
||||
single_task = None
|
||||
tasks_col = "language_instruction"
|
||||
@@ -596,7 +471,7 @@ def convert_dataset(
|
||||
episode_indices = sorted(dataset.unique("episode_index"))
|
||||
total_episodes = len(episode_indices)
|
||||
assert episode_indices == list(range(total_episodes))
|
||||
total_videos = total_episodes * len(keys["video"])
|
||||
total_videos = total_episodes * len(video_keys)
|
||||
total_chunks = total_episodes // DEFAULT_CHUNK_SIZE
|
||||
if total_episodes % DEFAULT_CHUNK_SIZE != 0:
|
||||
total_chunks += 1
|
||||
@@ -609,7 +484,6 @@ def convert_dataset(
|
||||
elif tasks_path:
|
||||
tasks_by_episodes = load_json(tasks_path)
|
||||
tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()}
|
||||
# tasks = list(set(tasks_by_episodes.values()))
|
||||
dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes)
|
||||
tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()}
|
||||
elif tasks_col:
|
||||
@@ -620,56 +494,50 @@ def convert_dataset(
|
||||
assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks}
|
||||
tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)]
|
||||
write_jsonlines(tasks, v20_dir / TASKS_PATH)
|
||||
|
||||
# Shapes
|
||||
sequence_shapes = {key: dataset.features[key].length for key in keys["sequence"]}
|
||||
image_shapes = get_image_shapes(dataset, keys["image"]) if len(keys["image"]) > 0 else {}
|
||||
features["task_index"] = {
|
||||
"dtype": "int64",
|
||||
"shape": (1,),
|
||||
"names": None,
|
||||
}
|
||||
|
||||
# Videos
|
||||
if len(keys["video"]) > 0:
|
||||
if video_keys:
|
||||
assert metadata_v1.get("video", False)
|
||||
tmp_video_dir = local_dir / "videos" / V20 / repo_id
|
||||
tmp_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
dataset = dataset.remove_columns(video_keys)
|
||||
clean_gitattr = Path(
|
||||
hub_api.hf_hub_download(
|
||||
repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes"
|
||||
)
|
||||
).absolute()
|
||||
move_videos(
|
||||
repo_id, keys["video"], total_episodes, total_chunks, tmp_video_dir, clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=keys["video"], branch=branch)
|
||||
video_shapes = get_video_shapes(videos_info, keys["video"])
|
||||
for img_key in keys["video"]:
|
||||
assert math.isclose(videos_info[img_key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
with tempfile.TemporaryDirectory() as tmp_video_dir:
|
||||
move_videos(
|
||||
repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch
|
||||
)
|
||||
videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch)
|
||||
for key in video_keys:
|
||||
features[key]["shape"] = (
|
||||
videos_info[key].pop("video.height"),
|
||||
videos_info[key].pop("video.width"),
|
||||
videos_info[key].pop("video.channels"),
|
||||
)
|
||||
features[key]["video_info"] = videos_info[key]
|
||||
assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3)
|
||||
if "encoding" in metadata_v1:
|
||||
assert videos_info[img_key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"]
|
||||
else:
|
||||
assert metadata_v1.get("video", 0) == 0
|
||||
videos_info = None
|
||||
video_shapes = {}
|
||||
|
||||
# Split data into 1 parquet file by episode
|
||||
episode_lengths = split_parquet_by_episodes(dataset, keys, total_episodes, total_chunks, v20_dir)
|
||||
episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir)
|
||||
|
||||
# Names
|
||||
if robot_config is not None:
|
||||
robot_type = robot_config["robot_type"]
|
||||
names = robot_config["names"]
|
||||
if "observation.effort" in keys["sequence"]:
|
||||
names["observation.effort"] = names["observation.state"]
|
||||
if "observation.velocity" in keys["sequence"]:
|
||||
names["observation.velocity"] = names["observation.state"]
|
||||
repo_tags = [robot_type]
|
||||
else:
|
||||
robot_type = "unknown"
|
||||
names = get_generic_motor_names(sequence_shapes)
|
||||
repo_tags = None
|
||||
|
||||
assert set(names) == set(keys["sequence"])
|
||||
for key in sequence_shapes:
|
||||
assert len(names[key]) == sequence_shapes[key]
|
||||
|
||||
# Episodes
|
||||
episodes = [
|
||||
{"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]}
|
||||
@@ -680,7 +548,6 @@ def convert_dataset(
|
||||
# Assemble metadata v2.0
|
||||
metadata_v2_0 = {
|
||||
"codebase_version": V20,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": len(dataset),
|
||||
@@ -690,23 +557,21 @@ def convert_dataset(
|
||||
"chunks_size": DEFAULT_CHUNK_SIZE,
|
||||
"fps": metadata_v1["fps"],
|
||||
"splits": {"train": f"0:{total_episodes}"},
|
||||
"keys": keys["sequence"],
|
||||
"video_keys": keys["video"],
|
||||
"image_keys": keys["image"],
|
||||
"shapes": {**sequence_shapes, **video_shapes, **image_shapes},
|
||||
"names": names,
|
||||
"videos": videos_info,
|
||||
"data_path": DEFAULT_PARQUET_PATH,
|
||||
"video_path": DEFAULT_VIDEO_PATH if video_keys else None,
|
||||
"features": features,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch)
|
||||
|
||||
with contextlib.suppress(EntryNotFoundError):
|
||||
with contextlib.suppress(EntryNotFoundError, HfHubHTTPError):
|
||||
hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch)
|
||||
|
||||
hub_api.upload_folder(
|
||||
@@ -724,28 +589,11 @@ def convert_dataset(
|
||||
revision=branch,
|
||||
)
|
||||
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, info=metadata_v2_0)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
if not test_branch:
|
||||
create_branch(repo_id=repo_id, branch=V20, repo_type="dataset")
|
||||
|
||||
# TODO:
|
||||
# - [X] Add shapes
|
||||
# - [X] Add keys
|
||||
# - [X] Add paths
|
||||
# - [X] convert stats.json
|
||||
# - [X] Add task.json
|
||||
# - [X] Add names
|
||||
# - [X] Add robot_type
|
||||
# - [X] Add splits
|
||||
# - [X] Push properly to branch v2.0 and delete v1.6 stuff from that branch
|
||||
# - [X] Handle multitask datasets
|
||||
# - [X] Handle hf hub repo limits (add chunks logic)
|
||||
# - [X] Add test-branch
|
||||
# - [X] Use jsonlines for episodes
|
||||
# - [X] Add sanity checks (encoding, shapes)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -790,6 +638,12 @@ def main():
|
||||
default=None,
|
||||
help="Local directory to store the dataset during conversion. Defaults to /tmp/lerobot_dataset_v2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--license",
|
||||
type=str,
|
||||
default="apache-2.0",
|
||||
help="Repo license. Must be one of https://huggingface.co/docs/hub/repositories-licenses. Defaults to mit.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-branch",
|
||||
type=str,
|
||||
@@ -808,7 +662,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from time import sleep
|
||||
|
||||
sleep(1)
|
||||
main()
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import warnings
|
||||
@@ -25,6 +26,7 @@ import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def decode_video_frames_torchvision(
|
||||
@@ -210,3 +212,104 @@ with warnings.catch_warnings():
|
||||
)
|
||||
# to make VideoFrame available in HuggingFace `datasets`
|
||||
register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
ffprobe_audio_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"a:0",
|
||||
"-show_entries",
|
||||
"stream=channels,codec_name,bit_rate,sample_rate,bit_depth,channel_layout,duration",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
audio_stream_info = info["streams"][0] if info.get("streams") else None
|
||||
if audio_stream_info is None:
|
||||
return {"has_audio": False}
|
||||
|
||||
# Return the information, defaulting to None if no audio stream is present
|
||||
return {
|
||||
"has_audio": True,
|
||||
"audio.channels": audio_stream_info.get("channels", None),
|
||||
"audio.codec": audio_stream_info.get("codec_name", None),
|
||||
"audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None,
|
||||
"audio.sample_rate": int(audio_stream_info["sample_rate"])
|
||||
if audio_stream_info.get("sample_rate")
|
||||
else None,
|
||||
"audio.bit_depth": audio_stream_info.get("bit_depth", None),
|
||||
"audio.channel_layout": audio_stream_info.get("channel_layout", None),
|
||||
}
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
ffprobe_video_cmd = [
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-select_streams",
|
||||
"v:0",
|
||||
"-show_entries",
|
||||
"stream=r_frame_rate,width,height,codec_name,nb_frames,duration,pix_fmt",
|
||||
"-of",
|
||||
"json",
|
||||
str(video_path),
|
||||
]
|
||||
result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Error running ffprobe: {result.stderr}")
|
||||
|
||||
info = json.loads(result.stdout)
|
||||
video_stream_info = info["streams"][0]
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
r_frame_rate = video_stream_info["r_frame_rate"]
|
||||
num, denom = map(int, r_frame_rate.split("/"))
|
||||
fps = num / denom
|
||||
|
||||
pixel_channels = get_video_pixel_channels(video_stream_info["pix_fmt"])
|
||||
|
||||
video_info = {
|
||||
"video.fps": fps,
|
||||
"video.height": video_stream_info["height"],
|
||||
"video.width": video_stream_info["width"],
|
||||
"video.channels": pixel_channels,
|
||||
"video.codec": video_stream_info["codec_name"],
|
||||
"video.pix_fmt": video_stream_info["pix_fmt"],
|
||||
"video.is_depth_map": False,
|
||||
**get_audio_info(video_path),
|
||||
}
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
elif "rgba" in pix_fmt or "yuva" in pix_fmt:
|
||||
return 4
|
||||
elif "rgb" in pix_fmt or "yuv" in pix_fmt:
|
||||
return 3
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
|
||||
def get_image_pixel_channels(image: Image):
|
||||
if image.mode == "L":
|
||||
return 1 # Grayscale
|
||||
elif image.mode == "LA":
|
||||
return 2 # Grayscale + Alpha
|
||||
elif image.mode == "RGB":
|
||||
return 3 # RGB
|
||||
elif image.mode == "RGBA":
|
||||
return 4 # RGBA
|
||||
else:
|
||||
raise ValueError("Unknown format")
|
||||
|
||||
@@ -67,6 +67,7 @@ class DiffusionConfig:
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
The group sizes are set to be about 16 (to be precise, feature_dim // 16).
|
||||
spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax.
|
||||
use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view.
|
||||
down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet.
|
||||
You may provide a variable number of dimensions, therefore also controlling the degree of
|
||||
downsampling.
|
||||
@@ -130,6 +131,7 @@ class DiffusionConfig:
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
spatial_softmax_num_keypoints: int = 32
|
||||
use_separate_rgb_encoder_per_camera: bool = False
|
||||
# Unet.
|
||||
down_dims: tuple[int, ...] = (512, 1024, 2048)
|
||||
kernel_size: int = 5
|
||||
|
||||
@@ -182,8 +182,13 @@ class DiffusionModel(nn.Module):
|
||||
self._use_env_state = False
|
||||
if num_images > 0:
|
||||
self._use_images = True
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
encoders = [DiffusionRgbEncoder(config) for _ in range(num_images)]
|
||||
self.rgb_encoder = nn.ModuleList(encoders)
|
||||
global_cond_dim += encoders[0].feature_dim * num_images
|
||||
else:
|
||||
self.rgb_encoder = DiffusionRgbEncoder(config)
|
||||
global_cond_dim += self.rgb_encoder.feature_dim * num_images
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self._use_env_state = True
|
||||
global_cond_dim += config.input_shapes["observation.environment_state"][0]
|
||||
@@ -239,16 +244,32 @@ class DiffusionModel(nn.Module):
|
||||
"""Encode image features and concatenate them all together along with the state vector."""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
global_cond_feats = [batch["observation.state"]]
|
||||
# Extract image feature (first combine batch, sequence, and camera index dims).
|
||||
# Extract image features.
|
||||
if self._use_images:
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
if self.config.use_separate_rgb_encoder_per_camera:
|
||||
# Combine batch and sequence dims while rearranging to make the camera index dimension first.
|
||||
images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...")
|
||||
img_features_list = torch.cat(
|
||||
[
|
||||
encoder(images)
|
||||
for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True)
|
||||
]
|
||||
)
|
||||
# Separate batch and sequence dims back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
else:
|
||||
# Combine batch, sequence, and "which camera" dims before passing to shared encoder.
|
||||
img_features = self.rgb_encoder(
|
||||
einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...")
|
||||
)
|
||||
# Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the
|
||||
# feature dim (effectively concatenating the camera features).
|
||||
img_features = einops.rearrange(
|
||||
img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps
|
||||
)
|
||||
global_cond_feats.append(img_features)
|
||||
|
||||
if self._use_env_state:
|
||||
|
||||
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
30
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
class HILSerlPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
23
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
discount = 0.99
|
||||
156
lerobot/common/policies/sac/modeling_sac.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
|
||||
self.critic_ensemble = ...
|
||||
self.critic_target = ...
|
||||
self.actor_network = ...
|
||||
|
||||
self.temperature = ...
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
"""
|
||||
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch['observations'])###
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
observation_batch =
|
||||
next_obaservation_batch =
|
||||
action_batch =
|
||||
reward_batch =
|
||||
dones_batch =
|
||||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias
|
||||
# from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
next_actions = ..
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_obaservation_batch, next_actions)
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
|
||||
# backup entropy
|
||||
td_target = reward_batch + self.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observation_batch, action_batch)
|
||||
|
||||
# 4- Calculate loss
|
||||
critics_loss = F.mse_loss(q_preds,
|
||||
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])) # dones masks
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observation_batch)
|
||||
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observation_batch, actions, return_type="mean")
|
||||
actor_loss = -(q_preds - temperature * log_probs).mean()
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = temperature * (entropy - self.target_entropy).mean()
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"Q_value_loss": critics_loss.item(),
|
||||
"pi_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
@@ -203,8 +203,7 @@ class IntelRealSenseCamera:
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py \
|
||||
--images-dir outputs/images_from_intelrealsense_cameras
|
||||
python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras
|
||||
```
|
||||
|
||||
When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
|
||||
@@ -219,8 +219,7 @@ class OpenCVCamera:
|
||||
|
||||
To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera:
|
||||
```bash
|
||||
python lerobot/common/robot_devices/cameras/opencv.py \
|
||||
--images-dir outputs/images_from_opencv_cameras
|
||||
python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras
|
||||
```
|
||||
|
||||
When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode
|
||||
@@ -328,7 +327,7 @@ class OpenCVCamera:
|
||||
if self.camera_index not in available_cam_ids:
|
||||
raise ValueError(
|
||||
f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. "
|
||||
"To find the camera index you should use, run `python lerobot/lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
"To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`."
|
||||
)
|
||||
|
||||
raise OSError(f"Can't access OpenCVCamera({camera_idx}).")
|
||||
|
||||
@@ -13,10 +13,12 @@ from functools import cache
|
||||
import cv2
|
||||
import torch
|
||||
import tqdm
|
||||
from deepdiff import DeepDiff
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import get_features_from_robot
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
@@ -326,7 +328,36 @@ def sanity_check_dataset_name(repo_id, policy):
|
||||
_, dataset_name = repo_id.split("/")
|
||||
# either repo_id doesnt start with "eval_" and there is no policy
|
||||
# or repo_id starts with "eval_" and there is a policy
|
||||
if dataset_name.startswith("eval_") == (policy is None):
|
||||
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins by 'eval_' ({dataset_name}) but no policy is provided ({policy})."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
if not dataset_name.startswith("eval_") and policy is not None:
|
||||
raise ValueError(
|
||||
f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy})."
|
||||
)
|
||||
|
||||
|
||||
def sanity_check_dataset_robot_compatibility(
|
||||
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
|
||||
) -> None:
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
("features", dataset.features, get_features_from_robot(robot, use_videos)),
|
||||
]
|
||||
|
||||
mismatches = []
|
||||
for field, dataset_value, present_value in fields:
|
||||
diff = DeepDiff(dataset_value, present_value)
|
||||
if diff:
|
||||
mismatches.append(f"{field}: expected {present_value}, got {dataset_value}")
|
||||
|
||||
if mismatches:
|
||||
raise ValueError(
|
||||
"Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches)
|
||||
)
|
||||
|
||||
@@ -298,16 +298,6 @@ class FeetechMotorsBus:
|
||||
self.logs = {}
|
||||
|
||||
self.track_positions = {}
|
||||
self.present_pos = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
"below_zero": [None] * len(self.motor_names),
|
||||
"above_max": [None] * len(self.motor_names),
|
||||
}
|
||||
self.goal_pos = {
|
||||
"prev": [None] * len(self.motor_names),
|
||||
"below_zero": [None] * len(self.motor_names),
|
||||
"above_max": [None] * len(self.motor_names),
|
||||
}
|
||||
|
||||
def connect(self):
|
||||
if self.is_connected:
|
||||
|
||||
@@ -64,7 +64,7 @@ def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=N
|
||||
# print(f"{present_voltage=}")
|
||||
# print(f"{present_temperature=}")
|
||||
|
||||
if present_speed == 0 and present_current > 50:
|
||||
if present_speed == 0 and present_current > 40:
|
||||
count += 1
|
||||
if count > 100 or present_current > 300:
|
||||
return present_pos
|
||||
@@ -306,16 +306,16 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
calib = {}
|
||||
|
||||
print("Calibrate shoulder_pan")
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan", load_threshold=350, count_threshold=200)
|
||||
calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan")
|
||||
arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan")
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate gripper")
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True, count_threshold=200)
|
||||
calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True)
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_flex")
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True, count_threshold=200)
|
||||
calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True)
|
||||
calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024)
|
||||
|
||||
wr_pos = arm.read("Present_Position", "wrist_roll")
|
||||
@@ -329,7 +329,7 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
time.sleep(1)
|
||||
|
||||
print("Calibrate wrist_roll")
|
||||
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True, count_threshold=200)
|
||||
calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True)
|
||||
calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790)
|
||||
|
||||
arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll")
|
||||
@@ -348,7 +348,6 @@ def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str
|
||||
arm,
|
||||
"elbow_flex",
|
||||
invert_drive_mode=True,
|
||||
count_threshold=200,
|
||||
in_between_move_hook=in_between_move_elbow_flex_hook,
|
||||
)
|
||||
arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex")
|
||||
|
||||
@@ -226,13 +226,42 @@ class ManipulatorRobot:
|
||||
self.is_connected = False
|
||||
self.logs = {}
|
||||
|
||||
action_names = [f"{arm}_{motor}" for arm, bus in self.leader_arms.items() for motor in bus.motors]
|
||||
state_names = [f"{arm}_{motor}" for arm, bus in self.follower_arms.items() for motor in bus.motors]
|
||||
self.names = {
|
||||
"action": action_names,
|
||||
"observation.state": state_names,
|
||||
def get_motor_names(self, arm: dict[str, MotorsBus]) -> list:
|
||||
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
||||
|
||||
@property
|
||||
def camera_features(self) -> dict:
|
||||
cam_ft = {}
|
||||
for cam_key, cam in self.cameras.items():
|
||||
key = f"observation.images.{cam_key}"
|
||||
cam_ft[key] = {
|
||||
"shape": (cam.height, cam.width, cam.channels),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
}
|
||||
return cam_ft
|
||||
|
||||
@property
|
||||
def motor_features(self) -> dict:
|
||||
action_names = self.get_motor_names(self.leader_arms)
|
||||
state_names = self.get_motor_names(self.leader_arms)
|
||||
return {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(action_names),),
|
||||
"names": action_names,
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(state_names),),
|
||||
"names": state_names,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def features(self):
|
||||
return {**self.motor_features, **self.camera_features}
|
||||
|
||||
@property
|
||||
def has_camera(self):
|
||||
return len(self.cameras) > 0
|
||||
@@ -271,7 +300,7 @@ class ManipulatorRobot:
|
||||
print(f"Connecting {name} leader arm.")
|
||||
self.leader_arms[name].connect()
|
||||
|
||||
if self.robot_type in ["koch", "aloha"]:
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.motors.dynamixel import TorqueMode
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.motors.feetech import TorqueMode
|
||||
@@ -286,7 +315,7 @@ class ManipulatorRobot:
|
||||
self.activate_calibration()
|
||||
|
||||
# Set robot preset (e.g. torque in leader gripper for Koch v1.1)
|
||||
if self.robot_type == "koch":
|
||||
if self.robot_type in ["koch", "koch_bimanual"]:
|
||||
self.set_koch_robot_preset()
|
||||
elif self.robot_type == "aloha":
|
||||
self.set_aloha_robot_preset()
|
||||
@@ -299,7 +328,7 @@ class ManipulatorRobot:
|
||||
self.follower_arms[name].write("Torque_Enable", 1)
|
||||
|
||||
if self.config.gripper_open_degree is not None:
|
||||
if self.robot_type in ["aloha", "so100", "moss"]:
|
||||
if self.robot_type not in ["koch", "koch_bimanual"]:
|
||||
raise NotImplementedError(
|
||||
f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open."
|
||||
)
|
||||
@@ -335,26 +364,20 @@ class ManipulatorRobot:
|
||||
with open(arm_calib_path) as f:
|
||||
calibration = json.load(f)
|
||||
else:
|
||||
# TODO(rcadene): display a warning in __init__ if calibration file not available
|
||||
print(f"Missing calibration file '{arm_calib_path}'")
|
||||
|
||||
if self.robot_type in ["koch", "aloha"]:
|
||||
if self.robot_type in ["koch", "koch_bimanual", "aloha"]:
|
||||
from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration
|
||||
|
||||
calibration = run_arm_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
elif self.robot_type in ["so100", "moss"]:
|
||||
from lerobot.common.robot_devices.robots.feetech_calibration import (
|
||||
run_arm_auto_calibration,
|
||||
run_arm_manual_calibration,
|
||||
)
|
||||
|
||||
# TODO(rcadene): better way to handle mocking + test run_arm_auto_calibration
|
||||
if arm_type == "leader" or arm.mock:
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
elif arm_type == "follower":
|
||||
calibration = run_arm_auto_calibration(arm, self.robot_type, name, arm_type)
|
||||
else:
|
||||
raise ValueError(arm_type)
|
||||
calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type)
|
||||
|
||||
print(f"Calibration is done! Saving calibration file '{arm_calib_path}'")
|
||||
arm_calib_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -11,6 +11,7 @@ def get_arm_id(name, arm_type):
|
||||
class Robot(Protocol):
|
||||
# TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes
|
||||
robot_type: str
|
||||
features: dict
|
||||
|
||||
def connect(self): ...
|
||||
def run_calibration(self): ...
|
||||
|
||||
10
lerobot/configs/env/moss_real.yaml
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
10
lerobot/configs/env/so100_real.yaml
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# @package _global_
|
||||
|
||||
fps: 30
|
||||
|
||||
env:
|
||||
name: real_world
|
||||
task: null
|
||||
state_dim: 6
|
||||
action_dim: 6
|
||||
fps: ${fps}
|
||||
@@ -114,7 +114,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
@@ -95,7 +95,7 @@ policy:
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_momentum: null
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
|
||||
102
lerobot/configs/policy/act_moss_real.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/moss_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
102
lerobot/configs/policy/act_so100_real.yaml
Normal file
@@ -0,0 +1,102 @@
|
||||
# @package _global_
|
||||
|
||||
# Use `act_koch_real.yaml` to train on real-world datasets collected on Alexander Koch's robots.
|
||||
# Compared to `act.yaml`, it contains 2 cameras (i.e. laptop, phone) instead of 1 camera (i.e. top).
|
||||
# Also, `training.eval_freq` is set to -1. This config is used to evaluate checkpoints at a certain frequency of training steps.
|
||||
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
|
||||
# Look at the documentation in header of `control_robot.py` for more information on how to collect data , train and evaluate a policy.
|
||||
#
|
||||
# Example of usage for training:
|
||||
# ```bash
|
||||
# python lerobot/scripts/train.py \
|
||||
# policy=act_koch_real \
|
||||
# env=koch_real
|
||||
# ```
|
||||
|
||||
seed: 1000
|
||||
dataset_repo_id: lerobot/so100_pick_place_lego
|
||||
|
||||
override_dataset_stats:
|
||||
observation.images.laptop:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
observation.images.phone:
|
||||
# stats from imagenet, since we use a pretrained vision model
|
||||
mean: [[[0.485]], [[0.456]], [[0.406]]] # (c,1,1)
|
||||
std: [[[0.229]], [[0.224]], [[0.225]]] # (c,1,1)
|
||||
|
||||
training:
|
||||
offline_steps: 80000
|
||||
online_steps: 0
|
||||
eval_freq: -1
|
||||
save_freq: 10000
|
||||
log_freq: 100
|
||||
save_checkpoint: true
|
||||
|
||||
batch_size: 8
|
||||
lr: 1e-5
|
||||
lr_backbone: 1e-5
|
||||
weight_decay: 1e-4
|
||||
grad_clip_norm: 10
|
||||
online_steps_between_rollouts: 1
|
||||
|
||||
delta_timestamps:
|
||||
action: "[i / ${fps} for i in range(${policy.chunk_size})]"
|
||||
|
||||
eval:
|
||||
n_episodes: 50
|
||||
batch_size: 50
|
||||
|
||||
# See `configuration_act.py` for more details.
|
||||
policy:
|
||||
name: act
|
||||
|
||||
# Input / output structure.
|
||||
n_obs_steps: 1
|
||||
chunk_size: 100
|
||||
n_action_steps: 100
|
||||
|
||||
input_shapes:
|
||||
# TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
|
||||
observation.images.laptop: [3, 480, 640]
|
||||
observation.images.phone: [3, 480, 640]
|
||||
observation.state: ["${env.state_dim}"]
|
||||
output_shapes:
|
||||
action: ["${env.action_dim}"]
|
||||
|
||||
# Normalization / Unnormalization
|
||||
input_normalization_modes:
|
||||
observation.images.laptop: mean_std
|
||||
observation.images.phone: mean_std
|
||||
observation.state: mean_std
|
||||
output_normalization_modes:
|
||||
action: mean_std
|
||||
|
||||
# Architecture.
|
||||
# Vision backbone.
|
||||
vision_backbone: resnet18
|
||||
pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
|
||||
replace_final_stride_with_dilation: false
|
||||
# Transformer layers.
|
||||
pre_norm: false
|
||||
dim_model: 512
|
||||
n_heads: 8
|
||||
dim_feedforward: 3200
|
||||
feedforward_activation: relu
|
||||
n_encoder_layers: 4
|
||||
# Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
|
||||
# that means only the first layer is used. Here we match the original implementation by setting this to 1.
|
||||
# See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
|
||||
n_decoder_layers: 1
|
||||
# VAE.
|
||||
use_vae: true
|
||||
latent_dim: 32
|
||||
n_vae_encoder_layers: 4
|
||||
|
||||
# Inference.
|
||||
temporal_ensemble_coeff: null
|
||||
|
||||
# Training and loss computation.
|
||||
dropout: 0.1
|
||||
kl_weight: 10.0
|
||||
@@ -1,5 +1,5 @@
|
||||
_target_: lerobot.common.robot_devices.robots.manipulator.ManipulatorRobot
|
||||
robot_type: koch
|
||||
robot_type: koch_bimanual
|
||||
calibration_dir: .cache/calibration/koch_bimanual
|
||||
|
||||
# `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes.
|
||||
|
||||
@@ -115,6 +115,7 @@ from lerobot.common.robot_devices.control_utils import (
|
||||
record_episode,
|
||||
reset_environment,
|
||||
sanity_check_dataset_name,
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
stop_recording,
|
||||
warmup_record,
|
||||
)
|
||||
@@ -207,8 +208,9 @@ def record(
|
||||
num_image_writer_threads_per_camera: int = 4,
|
||||
display_cameras: bool = True,
|
||||
play_sounds: bool = True,
|
||||
tags: str = None,
|
||||
force_override: bool = False,
|
||||
resume: bool = False,
|
||||
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDataset:
|
||||
# TODO(rcadene): Add option to record logs
|
||||
listener = None
|
||||
@@ -234,17 +236,29 @@ def record(
|
||||
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
|
||||
)
|
||||
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads_per_camera=num_image_writer_threads_per_camera,
|
||||
use_videos=video,
|
||||
)
|
||||
if resume:
|
||||
dataset = LeRobotDataset(
|
||||
repo_id,
|
||||
root=root,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
dataset.start_image_writer(
|
||||
num_processes=num_image_writer_processes,
|
||||
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
|
||||
else:
|
||||
# Create empty dataset or load existing saved episodes
|
||||
sanity_check_dataset_name(repo_id, policy)
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id,
|
||||
fps,
|
||||
root=root,
|
||||
robot=robot,
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
@@ -272,8 +286,7 @@ def record(
|
||||
# if multi_task:
|
||||
# task = input("Enter your task description: ")
|
||||
|
||||
episode_index = dataset.episode_buffer["episode_index"]
|
||||
log_say(f"Recording episode {episode_index}", play_sounds)
|
||||
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
|
||||
record_episode(
|
||||
dataset=dataset,
|
||||
robot=robot,
|
||||
@@ -291,7 +304,7 @@ def record(
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(episode_index < num_episodes - 1) or events["rerecord_episode"]
|
||||
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
@@ -303,7 +316,7 @@ def record(
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
dataset.add_episode(task)
|
||||
dataset.save_episode(task)
|
||||
recorded_episodes += 1
|
||||
|
||||
if events["stop_recording"]:
|
||||
@@ -312,16 +325,11 @@ def record(
|
||||
log_say("Stop recording", play_sounds, blocking=True)
|
||||
stop_recording(robot, listener, display_cameras)
|
||||
|
||||
if dataset.image_writer is not None:
|
||||
logging.info("Waiting for image writer to terminate...")
|
||||
dataset.image_writer.stop()
|
||||
|
||||
if run_compute_stats:
|
||||
logging.info("Computing dataset statistics")
|
||||
|
||||
dataset.consolidate(run_compute_stats)
|
||||
|
||||
# lerobot_dataset = create_lerobot_dataset(dataset, run_compute_stats, push_to_hub, tags, play_sounds)
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -349,7 +357,7 @@ def replay(
|
||||
robot.connect()
|
||||
|
||||
log_say("Replaying episode", play_sounds, blocking=True)
|
||||
for idx in range(dataset.num_samples):
|
||||
for idx in range(dataset.num_frames):
|
||||
start_episode_t = time.perf_counter()
|
||||
|
||||
action = actions[idx]["action"]
|
||||
|
||||
@@ -484,7 +484,7 @@ def main(
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
|
||||
else:
|
||||
# Note: We need the dataset stats to pass to the policy's normalization modules.
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).stats)
|
||||
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
policy.eval()
|
||||
|
||||
@@ -1,30 +1,36 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from serial.tools import list_ports # Part of pyserial library
|
||||
|
||||
|
||||
def find_available_ports():
|
||||
ports = []
|
||||
for path in Path("/dev").glob("tty*"):
|
||||
ports.append(str(path))
|
||||
if os.name == "nt": # Windows
|
||||
# List COM ports using pyserial
|
||||
ports = [port.device for port in list_ports.comports()]
|
||||
else: # Linux/macOS
|
||||
# List /dev/tty* ports for Unix-based systems
|
||||
ports = [str(path) for path in Path("/dev").glob("tty*")]
|
||||
return ports
|
||||
|
||||
|
||||
def find_port():
|
||||
print("Finding all available ports for the MotorsBus.")
|
||||
ports_before = find_available_ports()
|
||||
print(ports_before)
|
||||
print("Ports before disconnecting:", ports_before)
|
||||
|
||||
print("Remove the usb cable from your MotorsBus and press Enter when done.")
|
||||
input()
|
||||
print("Remove the USB cable from your MotorsBus and press Enter when done.")
|
||||
input() # Wait for user to disconnect the device
|
||||
|
||||
time.sleep(0.5)
|
||||
time.sleep(0.5) # Allow some time for port to be released
|
||||
ports_after = find_available_ports()
|
||||
ports_diff = list(set(ports_before) - set(ports_after))
|
||||
|
||||
if len(ports_diff) == 1:
|
||||
port = ports_diff[0]
|
||||
print(f"The port of this MotorsBus is '{port}'")
|
||||
print("Reconnect the usb cable.")
|
||||
print("Reconnect the USB cable.")
|
||||
elif len(ports_diff) == 0:
|
||||
raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).")
|
||||
else:
|
||||
@@ -32,5 +38,5 @@ def find_port():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Helper to find the usb port associated to all your MotorsBus.
|
||||
# Helper to find the USB port associated with your MotorsBus.
|
||||
find_port()
|
||||
|
||||
@@ -117,10 +117,14 @@ def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str
|
||||
|
||||
|
||||
def push_dataset_card_to_hub(
|
||||
repo_id: str, revision: str | None, tags: list | None = None, text: str | None = None
|
||||
repo_id: str,
|
||||
revision: str | None,
|
||||
tags: list | None = None,
|
||||
license: str = "apache-2.0",
|
||||
**card_kwargs,
|
||||
):
|
||||
"""Creates and pushes a LeRobotDataset Card with appropriate tags to easily find it on the hub."""
|
||||
card = create_lerobot_dataset_card(tags=tags, text=text)
|
||||
card = create_lerobot_dataset_card(tags=tags, license=license, **card_kwargs)
|
||||
card.push_to_hub(repo_id=repo_id, repo_type="dataset", revision=revision)
|
||||
|
||||
|
||||
|
||||
@@ -171,9 +171,9 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_samples
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
log_items = [
|
||||
f"step:{format_big_number(step)}",
|
||||
# number of samples seen during training
|
||||
@@ -208,9 +208,9 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size`` number of samples.
|
||||
num_samples = (step + 1) * cfg.training.batch_size
|
||||
avg_samples_per_ep = dataset.num_samples / dataset.num_episodes
|
||||
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
|
||||
num_episodes = num_samples / avg_samples_per_ep
|
||||
num_epochs = num_samples / dataset.num_samples
|
||||
num_epochs = num_samples / dataset.num_frames
|
||||
log_items = [
|
||||
f"step:{format_big_number(step)}",
|
||||
# number of samples seen during training
|
||||
@@ -328,7 +328,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info("make_policy")
|
||||
policy = make_policy(
|
||||
hydra_cfg=cfg,
|
||||
dataset_stats=offline_dataset.stats if not cfg.resume else None,
|
||||
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
|
||||
)
|
||||
assert isinstance(policy, nn.Module)
|
||||
@@ -349,7 +349,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
|
||||
logging.info(f"{cfg.training.online_steps=}")
|
||||
logging.info(f"{offline_dataset.num_samples=} ({format_big_number(offline_dataset.num_samples)})")
|
||||
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
|
||||
logging.info(f"{offline_dataset.num_episodes=}")
|
||||
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
|
||||
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
|
||||
@@ -573,7 +573,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
|
||||
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
|
||||
online_sampling_ratio=cfg.training.online_sampling_ratio,
|
||||
)
|
||||
sampler.num_samples = len(concat_dataset)
|
||||
sampler.num_frames = len(concat_dataset)
|
||||
|
||||
update_online_buffer_s = time.perf_counter() - start_update_buffer_time
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
|
||||
|
||||
def visualize_dataset(
|
||||
repo_id: str,
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
batch_size: int = 32,
|
||||
num_workers: int = 0,
|
||||
@@ -108,7 +108,6 @@ def visualize_dataset(
|
||||
web_port: int = 9090,
|
||||
ws_port: int = 9087,
|
||||
save: bool = False,
|
||||
root: Path | None = None,
|
||||
output_dir: Path | None = None,
|
||||
) -> Path | None:
|
||||
if save:
|
||||
@@ -116,8 +115,7 @@ def visualize_dataset(
|
||||
output_dir is not None
|
||||
), "Set an output directory where to write .rrd files with `--output-dir path/to/directory`."
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
@@ -153,7 +151,7 @@ def visualize_dataset(
|
||||
rr.set_time_seconds("timestamp", batch["timestamp"][i].item())
|
||||
|
||||
# display each camera image
|
||||
for key in dataset.camera_keys:
|
||||
for key in dataset.meta.camera_keys:
|
||||
# TODO(rcadene): add `.compress()`? is it lossless?
|
||||
rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i])))
|
||||
|
||||
@@ -268,7 +266,14 @@ def main():
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset(**vars(args))
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
|
||||
logging.info("Loading dataset")
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
||||
|
||||
visualize_dataset(dataset, **vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -93,12 +93,12 @@ def run_server(
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_samples,
|
||||
"num_samples": dataset.num_frames,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
video_paths = [dataset.get_video_file_path(episode_id, key) for key in dataset.video_keys]
|
||||
tasks = dataset.episode_dicts[episode_id]["tasks"]
|
||||
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
for video_path in video_paths
|
||||
@@ -130,16 +130,16 @@ def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
has_state = "observation.state" in dataset.hf_dataset.features
|
||||
has_action = "action" in dataset.hf_dataset.features
|
||||
has_state = "observation.state" in dataset.features
|
||||
has_action = "action" in dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = dataset.shapes["observation.state"]
|
||||
dim_state = dataset.meta.shapes["observation.state"][0]
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = dataset.shapes["action"]
|
||||
dim_action = dataset.meta.shapes["action"][0]
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
columns = ["timestamp"]
|
||||
@@ -170,27 +170,13 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
# get first frame of episode (hack to get video_path of the episode)
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
return [
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"] for key in dataset.video_keys
|
||||
dataset.hf_dataset.select_columns(key)[first_frame_idx][key]["path"]
|
||||
for key in dataset.meta.video_keys
|
||||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.hf_dataset.features:
|
||||
return None
|
||||
|
||||
# get first frame index
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
# with the tf.tensor appearing in the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
repo_id: str,
|
||||
root: Path | None = None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
@@ -200,13 +186,11 @@ def visualize_dataset_html(
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
dataset = LeRobotDataset(repo_id, root=root)
|
||||
|
||||
if len(dataset.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.image_keys=}) are currently not supported.")
|
||||
if len(dataset.meta.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
output_dir = f"outputs/visualize_dataset_html/{repo_id}"
|
||||
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
@@ -295,7 +279,11 @@ def main():
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
visualize_dataset_html(**vars(args))
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
root = kwargs.pop("root")
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=True)
|
||||
visualize_dataset_html(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -157,7 +157,7 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get 1st frame from 1st camera of 1st episode
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
to_pil(original_frame).save(output_dir / "original_frame.png", quality=100)
|
||||
print("\nOriginal frame saved to:")
|
||||
print(f" {output_dir / 'original_frame.png'}.")
|
||||
|
||||
@@ -35,7 +35,7 @@
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
Number of samples/frames: {{ dataset_info.num_frames }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
|
||||
BIN
media/gym/aloha_act.gif
Normal file
|
After Width: | Height: | Size: 2.9 MiB |
BIN
media/gym/pusht_diffusion.gif
Normal file
|
After Width: | Height: | Size: 185 KiB |
BIN
media/gym/simxarm_tdmpc.gif
Normal file
|
After Width: | Height: | Size: 464 KiB |
BIN
media/moss/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 153 KiB |
BIN
media/moss/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 208 KiB |
BIN
media/moss/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 296 KiB |
BIN
media/so100/follower_rest.webp
Normal file
|
After Width: | Height: | Size: 145 KiB |
BIN
media/so100/follower_rotated.webp
Normal file
|
After Width: | Height: | Size: 95 KiB |
BIN
media/so100/follower_zero.webp
Normal file
|
After Width: | Height: | Size: 134 KiB |
BIN
media/so100/leader_follower.webp
Normal file
|
After Width: | Height: | Size: 117 KiB |
47
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@@ -935,36 +935,29 @@ tests = ["pytest", "pytest-cov", "pytest-xdist"]
|
||||
|
||||
[[package]]
|
||||
name = "dash"
|
||||
version = "2.18.1"
|
||||
version = "2.9.3"
|
||||
description = "A Python framework for building reactive web-apps. Developed by Plotly."
|
||||
optional = true
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.6"
|
||||
files = [
|
||||
{file = "dash-2.18.1-py3-none-any.whl", hash = "sha256:07c4513bb5f79a4b936847a0b49afc21dbd4b001ff77ea78d4d836043e211a07"},
|
||||
{file = "dash-2.18.1.tar.gz", hash = "sha256:ffdf89690d734f6851ef1cb344222826ffb11ad2214ab9172668bf8aadd75d12"},
|
||||
{file = "dash-2.9.3-py3-none-any.whl", hash = "sha256:a749ae1ea9de3fe7b785353a818ec9b629d39c6b7e02462954203bd1e296fd0e"},
|
||||
{file = "dash-2.9.3.tar.gz", hash = "sha256:47392f8d6455dc989a697407eb5941f3bad80604df985ab1ac9d4244568ffb34"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
dash-core-components = "2.0.0"
|
||||
dash-html-components = "2.0.0"
|
||||
dash-table = "5.0.0"
|
||||
Flask = ">=1.0.4,<3.1"
|
||||
importlib-metadata = "*"
|
||||
nest-asyncio = "*"
|
||||
Flask = ">=1.0.4"
|
||||
plotly = ">=5.0.0"
|
||||
requests = "*"
|
||||
retrying = "*"
|
||||
setuptools = "*"
|
||||
typing-extensions = ">=4.1.1"
|
||||
Werkzeug = "<3.1"
|
||||
|
||||
[package.extras]
|
||||
celery = ["celery[redis] (>=5.1.2)", "redis (>=3.5.3)"]
|
||||
ci = ["black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==7.0.0)", "flaky (==3.8.1)", "flask-talisman (==1.0.0)", "jupyterlab (<4.0.0)", "mimesis (<=11.1.0)", "mock (==4.0.3)", "numpy (<=1.26.3)", "openpyxl", "orjson (==3.10.3)", "pandas (>=1.4.0)", "pyarrow", "pylint (==3.0.3)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "pyzmq (==25.1.2)", "xlrd (>=2.0.1)"]
|
||||
celery = ["celery[redis] (>=5.1.2)", "importlib-metadata (<5)", "redis (>=3.5.3)"]
|
||||
ci = ["black (==21.6b0)", "black (==22.3.0)", "dash-dangerously-set-inner-html", "dash-flow-example (==0.0.5)", "flake8 (==3.9.2)", "flaky (==3.7.0)", "flask-talisman (==1.0.0)", "isort (==4.3.21)", "mimesis", "mock (==4.0.3)", "numpy", "openpyxl", "orjson (==3.5.4)", "orjson (==3.6.7)", "pandas (==1.1.5)", "pandas (>=1.4.0)", "preconditions", "pyarrow", "pyarrow (<3)", "pylint (==2.13.5)", "pytest-mock", "pytest-rerunfailures", "pytest-sugar (==0.9.6)", "xlrd (<2)", "xlrd (>=2.0.1)"]
|
||||
compress = ["flask-compress"]
|
||||
dev = ["PyYAML (>=5.4.1)", "coloredlogs (>=15.0.1)", "fire (>=0.4.0)"]
|
||||
diskcache = ["diskcache (>=5.2.1)", "multiprocess (>=0.70.12)", "psutil (>=5.8.0)"]
|
||||
testing = ["beautifulsoup4 (>=4.8.2)", "cryptography", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"]
|
||||
testing = ["beautifulsoup4 (>=4.8.2)", "cryptography (<3.4)", "dash-testing-stub (>=0.0.2)", "lxml (>=4.6.2)", "multiprocess (>=0.70.12)", "percy (>=2.0.2)", "psutil (>=5.8.0)", "pytest (>=6.0.2)", "requests[security] (>=2.21.0)", "selenium (>=3.141.0,<=4.2.0)", "waitress (>=1.4.4)"]
|
||||
|
||||
[[package]]
|
||||
name = "dash-core-components"
|
||||
@@ -5920,20 +5913,6 @@ typing-extensions = ">=4.5"
|
||||
notebook = ["rerun-notebook (==0.18.2)"]
|
||||
tests = ["pytest (==7.1.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "retrying"
|
||||
version = "1.3.4"
|
||||
description = "Retrying"
|
||||
optional = true
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "retrying-1.3.4-py3-none-any.whl", hash = "sha256:8cc4d43cb8e1125e0ff3344e9de678fefd85db3b750b81b2240dc0183af37b35"},
|
||||
{file = "retrying-1.3.4.tar.gz", hash = "sha256:345da8c5765bd982b1d1915deb9102fd3d1f7ad16bd84a9700b85f64d24e8f3e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
six = ">=1.7.0"
|
||||
|
||||
[[package]]
|
||||
name = "rfc3339-validator"
|
||||
version = "0.1.4"
|
||||
@@ -7258,13 +7237,13 @@ test = ["websockets"]
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.0.4"
|
||||
version = "3.1.1"
|
||||
description = "The comprehensive WSGI web application library."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "werkzeug-3.0.4-py3-none-any.whl", hash = "sha256:02c9eb92b7d6c06f31a782811505d2157837cea66aaede3e217c7c27c039476c"},
|
||||
{file = "werkzeug-3.0.4.tar.gz", hash = "sha256:34f2371506b250df4d4f84bfe7b0921e4762525762bbd936614909fe25cd7306"},
|
||||
{file = "werkzeug-3.1.1-py3-none-any.whl", hash = "sha256:a71124d1ef06008baafa3d266c02f56e1836a5984afd6dd6c9230669d60d9fb5"},
|
||||
{file = "werkzeug-3.1.1.tar.gz", hash = "sha256:8cd39dfbdfc1e051965f156163e2974e52c210f130810e9ad36858f0fd3edad4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
||||
@@ -23,6 +23,13 @@ from lerobot import available_cameras, available_motors, available_robots
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
|
||||
|
||||
# Import fixture modules as plugins
|
||||
pytest_plugins = [
|
||||
"tests.fixtures.dataset_factories",
|
||||
"tests.fixtures.files",
|
||||
"tests.fixtures.hub",
|
||||
]
|
||||
|
||||
|
||||
def pytest_collection_finish():
|
||||
print(f"\nTesting with {DEVICE=}")
|
||||
|
||||
29
tests/fixtures/constants.py
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
|
||||
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
|
||||
DUMMY_REPO_ID = "dummy/repo"
|
||||
DUMMY_ROBOT_TYPE = "dummy_robot"
|
||||
DUMMY_MOTOR_FEATURES = {
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
|
||||
},
|
||||
}
|
||||
DUMMY_CAMERA_FEATURES = {
|
||||
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
|
||||
}
|
||||
DEFAULT_FPS = 30
|
||||
DUMMY_VIDEO_INFO = {
|
||||
"video.fps": DEFAULT_FPS,
|
||||
"video.codec": "av1",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": False,
|
||||
"has_audio": False,
|
||||
}
|
||||
396
tests/fixtures/dataset_factories.py
vendored
Normal file
@@ -0,0 +1,396 @@
|
||||
import random
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
DUMMY_VIDEO_INFO,
|
||||
)
|
||||
|
||||
|
||||
def get_task_index(task_dicts: dict, task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
else:
|
||||
raise ValueError(dtype)
|
||||
return img_array
|
||||
|
||||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(height=height, width=width)
|
||||
return PIL.Image.fromarray(img_array)
|
||||
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
}
|
||||
else:
|
||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
robot_type: str = DUMMY_ROBOT_TYPE,
|
||||
total_episodes: int = 0,
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": chunks_size,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
def _create_stats(
|
||||
features: dict[str] | None = None,
|
||||
) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
dtype = ft["dtype"]
|
||||
if dtype in ["image", "video"]:
|
||||
stats[key] = {
|
||||
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
|
||||
"min": np.full(shape, 0, dtype=dtype).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
|
||||
}
|
||||
return stats
|
||||
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> int:
|
||||
tasks_list = []
|
||||
for i in range(total_tasks):
|
||||
task_dict = {"task_index": i, "task": f"Perform action {i}."}
|
||||
tasks_list.append(task_dict)
|
||||
return tasks_list
|
||||
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
tasks: dict | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
raise ValueError("num_episodes and total_length must be positive integers.")
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not tasks:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random lengths that sum up to total_length
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks]
|
||||
num_tasks_available = len(tasks_list)
|
||||
|
||||
episodes_list = []
|
||||
remaining_tasks = tasks_list.copy()
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||
if remaining_tasks:
|
||||
for task in episode_tasks:
|
||||
remaining_tasks.remove(task)
|
||||
|
||||
episodes_list.append(
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": episode_tasks,
|
||||
"length": lengths[ep_idx],
|
||||
}
|
||||
)
|
||||
|
||||
return episodes_list
|
||||
|
||||
return _create_episodes
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
features = features_factory()
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
index_col = np.arange(len(episode_index_col))
|
||||
|
||||
robot_cols = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
dataset = datasets.Dataset.from_dict(
|
||||
{
|
||||
**robot_cols,
|
||||
"timestamp": timestamp_col,
|
||||
"frame_index": frame_index_col,
|
||||
"episode_index": episode_index_col,
|
||||
"index": index_col,
|
||||
"task_index": task_index,
|
||||
},
|
||||
features=hf_features,
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
return dataset
|
||||
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
hf_dataset=hf_dataset,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
114
tests/fixtures/files.py
vendored
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_path(info_factory):
|
||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
fpath = dir / INFO_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_info_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_path(stats_factory):
|
||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
||||
if not stats:
|
||||
stats = stats_factory()
|
||||
fpath = dir / STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_stats_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks_factory):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_path(episodes_factory):
|
||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
fpath = dir / EPISODES_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes)
|
||||
return fpath
|
||||
|
||||
return _create_episodes_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return fpath
|
||||
|
||||
return _create_single_episode_parquet
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_multi_episode_parquet(
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
total_episodes = info["total_episodes"]
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return dir / "data"
|
||||
|
||||
return _create_multi_episode_parquet
|
||||
105
tests/fixtures/hub.py
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info_factory,
|
||||
info_path,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
hf_dataset_factory,
|
||||
):
|
||||
"""
|
||||
This factory allows to patch snapshot_download such that when called, it will create expected files rather
|
||||
than making calls to the hub api. Its design allows to pass explicitly files which you want to be created.
|
||||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
):
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
for rel_path in allowed_files:
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
||||
return _mock_snapshot_download
|
||||
|
||||
return _mock_snapshot_download_func
|
||||
@@ -76,7 +76,7 @@ def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
)
|
||||
set_global_seed(1337)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=5,
|
||||
fps=1,
|
||||
warmup_time_s=0.5,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
@@ -155,18 +155,24 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.total_episodes == 2
|
||||
assert len(dataset) == 10
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=5, root=root, repo_id=repo_id, play_sounds=False)
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
|
||||
# TODO(rcadene, aliberts): rethink this design
|
||||
if robot_type == "aloha":
|
||||
env_name = "aloha_real"
|
||||
policy_name = "act_aloha_real"
|
||||
elif robot_type in ["koch", "koch_bimanual", "so100", "moss"]:
|
||||
elif robot_type in ["koch", "koch_bimanual"]:
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
elif robot_type == "so100":
|
||||
env_name = "so100_real"
|
||||
policy_name = "act_so100_real"
|
||||
elif robot_type == "moss":
|
||||
env_name = "moss_real"
|
||||
policy_name = "act_moss_real"
|
||||
else:
|
||||
raise NotImplementedError(robot_type)
|
||||
|
||||
@@ -187,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
@@ -269,22 +275,25 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
record_kwargs = {
|
||||
"robot": robot,
|
||||
"root": root,
|
||||
"repo_id": repo_id,
|
||||
"single_task": single_task,
|
||||
"fps": 1,
|
||||
"warmup_time_s": 0,
|
||||
"episode_time_s": 1,
|
||||
"push_to_hub": False,
|
||||
"video": False,
|
||||
"display_cameras": False,
|
||||
"play_sounds": False,
|
||||
"run_compute_stats": False,
|
||||
"local_files_only": True,
|
||||
"num_episodes": 1,
|
||||
}
|
||||
|
||||
dataset = record(**record_kwargs)
|
||||
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
||||
|
||||
# init_dataset_return_value = {}
|
||||
|
||||
@@ -294,22 +303,13 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
# return init_dataset_return_value
|
||||
|
||||
# with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||
|
||||
with pytest.raises(FileExistsError):
|
||||
# Dataset already exists, but resume=False by default
|
||||
record(**record_kwargs)
|
||||
|
||||
dataset = record(**record_kwargs, resume=True)
|
||||
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
||||
# assert (
|
||||
# init_dataset_return_value["num_episodes"] == 2
|
||||
# ), "`init_dataset` should load the previous episode"
|
||||
|
||||
@@ -33,39 +33,72 @@ from lerobot.common.datasets.compute_stats import (
|
||||
get_stats_einops_patterns,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
load_previous_and_future_frames,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||
|
||||
# TODO(aliberts): create proper test repo
|
||||
TEST_REPO_ID = "aliberts/koch_tutorial"
|
||||
|
||||
|
||||
def test_same_attributes_defined():
|
||||
# TODO(aliberts): test with keys, shapes, names etc. provided instead of robot
|
||||
robot = make_robot("koch", mock=True)
|
||||
|
||||
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
objects have the same sets of attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
dataset_init = LeRobotDataset(repo_id=TEST_REPO_ID)
|
||||
dataset_create = LeRobotDataset.create(repo_id=TEST_REPO_ID, fps=30, robot=robot)
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init._hub_version
|
||||
_ = dataset_create._hub_version
|
||||
_ = dataset_init.meta._hub_version
|
||||
_ = dataset_create.meta._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
assert init_attr == create_attr, "Attribute sets do not match between __init__ and .create()"
|
||||
assert init_attr == create_attr
|
||||
|
||||
|
||||
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
kwargs = {
|
||||
"repo_id": DUMMY_REPO_ID,
|
||||
"total_episodes": 10,
|
||||
"total_frames": 400,
|
||||
"episodes": [2, 5, 6],
|
||||
}
|
||||
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.meta.total_frames == kwargs["total_frames"]
|
||||
assert dataset.episodes == kwargs["episodes"]
|
||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
# - [ ] test add_frame
|
||||
# - [ ] test add_episode
|
||||
# - [ ] test consolidate
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
lerobot.env_dataset_policy_triplets
|
||||
@@ -88,7 +121,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.camera_keys
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -138,6 +171,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_multilerobotdataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
@@ -151,7 +185,7 @@ def test_multilerobotdataset_frames():
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
@@ -170,6 +204,8 @@ def test_multilerobotdataset_frames():
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_compute_stats_on_xarm():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
@@ -218,7 +254,7 @@ def test_compute_stats_on_xarm():
|
||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||
loaded_stats = dataset.stats # noqa: F841
|
||||
loaded_stats = dataset.meta.stats # noqa: F841
|
||||
|
||||
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
|
||||
# # test loaded stats match expected stats
|
||||
@@ -229,72 +265,7 @@ def test_compute_stats_on_xarm():
|
||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_within_tolerance():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
with pytest.raises(AssertionError):
|
||||
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
@@ -318,6 +289,7 @@ def test_flatten_unflatten_dict():
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -389,6 +361,7 @@ def test_backward_compatibility(repo_id):
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_aggregate_stats():
|
||||
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
|
||||
with seeded_context(0):
|
||||
|
||||
256
tests/test_delta_timestamps.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_hf_dataset_factory(hf_dataset_factory):
|
||||
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
|
||||
return hf_dataset_factory(fps=fps)
|
||||
|
||||
return _create_synced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just outside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
|
||||
return _create_unsynced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just inside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
|
||||
return _create_slightly_off_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
for key in keys:
|
||||
delta_timestamps[key][3] += tolerance_s * 1.1
|
||||
return delta_timestamps
|
||||
|
||||
return _create_invalid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
for key in delta_timestamps:
|
||||
delta_timestamps[key][3] += tolerance_s * 0.9
|
||||
delta_timestamps[key][-3] += tolerance_s * 0.9
|
||||
return delta_timestamps
|
||||
|
||||
return _create_slightly_off_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
synced_hf_dataset = synced_hf_dataset_factory(fps)
|
||||
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=synced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
with pytest.raises(ValueError):
|
||||
check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
raise_value_error=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=slightly_off_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
|
||||
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
|
||||
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=single_timestamp_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
|
||||
@pytest.mark.skip("TODO: fix")
|
||||
def test_check_timestamps_sync_empty_dataset():
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
|
||||
empty_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"to": torch.tensor([], dtype=torch.int64),
|
||||
"from": torch.tensor([], dtype=torch.int64),
|
||||
}
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=empty_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=valid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=slightly_off_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
||||
with pytest.raises(ValueError):
|
||||
check_delta_timestamps(
|
||||
delta_timestamps=invalid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=invalid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
raise_value_error=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_delta_timestamps_empty():
|
||||
delta_timestamps = {}
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
|
||||
fps = 30
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps)
|
||||
expected_delta_indices = delta_indices
|
||||
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
|
||||
assert expected_delta_indices == actual_delta_indices
|
||||
@@ -13,12 +13,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO(aliberts): Mute logging for these tests
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -29,6 +32,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
|
||||
return text
|
||||
|
||||
|
||||
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
|
||||
def _run_script(path):
|
||||
subprocess.run([sys.executable, path], check=True)
|
||||
|
||||
@@ -38,12 +42,26 @@ def _read_file(path):
|
||||
return file.read()
|
||||
|
||||
|
||||
def test_example_1():
|
||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
||||
def test_example_1(tmp_path, lerobot_dataset_factory):
|
||||
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
|
||||
path = "examples/1_load_lerobot_dataset.py"
|
||||
_run_script(path)
|
||||
file_contents = _read_file(path)
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
[
|
||||
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
|
||||
(
|
||||
"LeRobotDataset(repo_id",
|
||||
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
|
||||
),
|
||||
],
|
||||
)
|
||||
exec(file_contents, {})
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
||||
@require_package("gym_pusht")
|
||||
def test_examples_basic2_basic3_advanced1():
|
||||
"""
|
||||
@@ -111,7 +129,8 @@ def test_examples_basic2_basic3_advanced1():
|
||||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
),
|
||||
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
|
||||
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
|
||||
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
|
||||
("num_workers=4", "num_workers=0"),
|
||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||
("batch_size=64", "batch_size=1"),
|
||||
|
||||
@@ -15,15 +15,12 @@
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image
|
||||
from safetensors.torch import load_file
|
||||
from torchvision.transforms import v2
|
||||
from torchvision.transforms.v2 import functional as F # noqa: N812
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from lerobot.scripts.visualize_image_transforms import visualize_transforms
|
||||
@@ -33,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
|
||||
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
|
||||
|
||||
|
||||
def load_png_to_tensor(path: Path):
|
||||
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID)
|
||||
return dataset[0][dataset.camera_keys[0]]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def img_random():
|
||||
return torch.rand(3, 480, 640)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def color_jitters():
|
||||
return [
|
||||
@@ -67,47 +49,54 @@ def default_transforms():
|
||||
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
|
||||
|
||||
|
||||
def test_get_image_transforms_no_transform(img):
|
||||
def test_get_image_transforms_no_transform(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
|
||||
torch.testing.assert_close(tf_actual(img), img)
|
||||
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_brightness(img, min_max):
|
||||
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(brightness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_contrast(img, min_max):
|
||||
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(contrast=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_saturation(img, min_max):
|
||||
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(saturation=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
|
||||
def test_get_image_transforms_hue(img, min_max):
|
||||
def test_get_image_transforms_hue(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
|
||||
tf_expected = v2.ColorJitter(hue=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
|
||||
def test_get_image_transforms_sharpness(img, min_max):
|
||||
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
|
||||
tf_expected = SharpnessJitter(sharpness=min_max)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
def test_get_image_transforms_max_num_transforms(img):
|
||||
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf_actual = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
@@ -125,12 +114,13 @@ def test_get_image_transforms_max_num_transforms(img):
|
||||
SharpnessJitter(sharpness=(0.5, 0.5)),
|
||||
]
|
||||
)
|
||||
torch.testing.assert_close(tf_actual(img), tf_expected(img))
|
||||
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
|
||||
|
||||
|
||||
@require_x86_64_kernel
|
||||
def test_get_image_transforms_random_order(img):
|
||||
def test_get_image_transforms_random_order(img_tensor_factory):
|
||||
out_imgs = []
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = get_image_transforms(
|
||||
brightness_min_max=(0.5, 0.5),
|
||||
contrast_min_max=(0.5, 0.5),
|
||||
@@ -141,13 +131,14 @@ def test_get_image_transforms_random_order(img):
|
||||
)
|
||||
with seeded_context(1337):
|
||||
for _ in range(10):
|
||||
out_imgs.append(tf(img))
|
||||
out_imgs.append(tf(img_tensor))
|
||||
|
||||
for i in range(1, len(out_imgs)):
|
||||
with pytest.raises(AssertionError):
|
||||
torch.testing.assert_close(out_imgs[0], out_imgs[i])
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"transform, min_max_values",
|
||||
[
|
||||
@@ -158,21 +149,24 @@ def test_get_image_transforms_random_order(img):
|
||||
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
|
||||
],
|
||||
)
|
||||
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
|
||||
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
for min_max in min_max_values:
|
||||
kwargs = {
|
||||
f"{transform}_weight": 1.0,
|
||||
f"{transform}_min_max": min_max,
|
||||
}
|
||||
tf = get_image_transforms(**kwargs)
|
||||
actual = tf(img)
|
||||
actual = tf(img_tensor)
|
||||
key = f"{transform}_{min_max[0]}_{min_max[1]}"
|
||||
expected = single_transforms[key]
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@require_x86_64_kernel
|
||||
def test_backward_compatibility_default_config(img, default_transforms):
|
||||
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
|
||||
img_tensor = img_tensor_factory()
|
||||
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
|
||||
cfg_tf = cfg.training.image_transforms
|
||||
default_tf = get_image_transforms(
|
||||
@@ -191,7 +185,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
|
||||
)
|
||||
|
||||
with seeded_context(1337):
|
||||
actual = default_tf(img)
|
||||
actual = default_tf(img_tensor)
|
||||
|
||||
expected = default_transforms["default"]
|
||||
|
||||
@@ -199,33 +193,36 @@ def test_backward_compatibility_default_config(img, default_transforms):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
|
||||
def test_random_subset_apply_single_choice(p, img):
|
||||
def test_random_subset_apply_single_choice(img_tensor_factory, p):
|
||||
img_tensor = img_tensor_factory()
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
|
||||
actual = random_choice(img)
|
||||
actual = random_choice(img_tensor)
|
||||
|
||||
p_horz, _ = p
|
||||
if p_horz:
|
||||
torch.testing.assert_close(actual, F.horizontal_flip(img))
|
||||
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
|
||||
else:
|
||||
torch.testing.assert_close(actual, F.vertical_flip(img))
|
||||
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
|
||||
|
||||
|
||||
def test_random_subset_apply_random_order(img):
|
||||
def test_random_subset_apply_random_order(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
|
||||
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
|
||||
# We can't really check whether the transforms are actually applied in random order. However,
|
||||
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
|
||||
# applies them in random order, we can use a fixed order to compute the expected value.
|
||||
actual = random_order(img)
|
||||
expected = v2.Compose(flips)(img)
|
||||
actual = random_order(img_tensor)
|
||||
expected = v2.Compose(flips)(img_tensor)
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
def test_random_subset_apply_valid_transforms(color_jitters, img):
|
||||
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
|
||||
img_tensor = img_tensor_factory()
|
||||
transform = RandomSubsetApply(color_jitters)
|
||||
output = transform(img)
|
||||
assert output.shape == img.shape
|
||||
output = transform(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_random_subset_apply_probability_length_mismatch(color_jitters):
|
||||
@@ -239,16 +236,18 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
|
||||
RandomSubsetApply(color_jitters, n_subset=n_subset)
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_tuple(img):
|
||||
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = SharpnessJitter((0.1, 2.0))
|
||||
output = tf(img)
|
||||
assert output.shape == img.shape
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_sharpness_jitter_valid_range_float(img):
|
||||
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
|
||||
img_tensor = img_tensor_factory()
|
||||
tf = SharpnessJitter(0.5)
|
||||
output = tf(img)
|
||||
assert output.shape == img.shape
|
||||
output = tf(img_tensor)
|
||||
assert output.shape == img_tensor.shape
|
||||
|
||||
|
||||
def test_sharpness_jitter_invalid_range_min_negative():
|
||||
@@ -261,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
|
||||
SharpnessJitter((2.0, 0.1))
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id, n_examples",
|
||||
[
|
||||
|
||||
359
tests/test_image_writer.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import queue
|
||||
import time
|
||||
from multiprocessing import queues
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.common.datasets.image_writer import (
|
||||
AsyncImageWriter,
|
||||
image_array_to_image,
|
||||
safe_stop_image_writer,
|
||||
write_image,
|
||||
)
|
||||
|
||||
DUMMY_IMAGE = "test_image.png"
|
||||
|
||||
|
||||
def test_init_threading():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 0
|
||||
assert writer.num_threads == 2
|
||||
assert isinstance(writer.queue, queue.Queue)
|
||||
assert len(writer.threads) == 2
|
||||
assert len(writer.processes) == 0
|
||||
assert all(t.is_alive() for t in writer.threads)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_init_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
assert writer.num_processes == 2
|
||||
assert writer.num_threads == 2
|
||||
assert isinstance(writer.queue, queues.JoinableQueue)
|
||||
assert len(writer.threads) == 0
|
||||
assert len(writer.processes) == 2
|
||||
assert all(p.is_alive() for p in writer.processes)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_zero_threads():
|
||||
with pytest.raises(ValueError):
|
||||
AsyncImageWriter(num_processes=0, num_threads=0)
|
||||
|
||||
|
||||
def test_image_array_to_image_rgb(img_array_factory):
|
||||
img_array = img_array_factory(100, 100)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
|
||||
|
||||
def test_image_array_to_image_pytorch_format(img_array_factory):
|
||||
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO: implement")
|
||||
def test_image_array_to_image_single_channel(img_array_factory):
|
||||
img_array = img_array_factory(channels=1)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "L"
|
||||
|
||||
|
||||
def test_image_array_to_image_float_array(img_array_factory):
|
||||
img_array = img_array_factory(dtype=np.float32)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
assert np.array(result_image).dtype == np.uint8
|
||||
|
||||
|
||||
def test_image_array_to_image_out_of_bounds_float():
|
||||
# Float array with values out of [0, 1]
|
||||
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
|
||||
result_image = image_array_to_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "RGB"
|
||||
assert np.array(result_image).dtype == np.uint8
|
||||
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
|
||||
|
||||
|
||||
def test_write_image_numpy(tmp_path, img_array_factory):
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
assert np.array_equal(image_array, saved_image)
|
||||
|
||||
|
||||
def test_write_image_image(tmp_path, img_factory):
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
write_image(image_pil, fpath)
|
||||
assert fpath.exists()
|
||||
saved_image = Image.open(fpath)
|
||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||
assert np.array_equal(image_pil, saved_image)
|
||||
|
||||
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
with patch("builtins.print") as mock_print:
|
||||
write_image(image_array, fpath)
|
||||
mock_print.assert_called()
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
def test_save_image_numpy(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
assert np.array_equal(image_array, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
assert np.array_equal(image_array, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_torch(tmp_path, img_tensor_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_tensor, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_tensor, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert np.array_equal(expected_image, saved_image)
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_pil(tmp_path, img_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_pil, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = Image.open(fpath)
|
||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
image_pil = img_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_pil, fpath)
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
saved_image = Image.open(fpath)
|
||||
assert list(saved_image.getdata()) == list(image_pil.getdata())
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_invalid_data(tmp_path):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with patch("builtins.print") as mock_print:
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
mock_print.assert_called()
|
||||
assert not fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_save_image_after_stop(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
image_array = img_array_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
writer.save_image(image_array, fpath)
|
||||
time.sleep(1)
|
||||
assert not fpath.exists()
|
||||
|
||||
|
||||
def test_stop():
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_stop_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
assert not any(p.is_alive() for p in writer.processes)
|
||||
|
||||
|
||||
def test_multiple_stops():
|
||||
writer = AsyncImageWriter()
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_multiple_stops_multiprocessing():
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
writer.stop()
|
||||
writer.stop() # Should not raise an exception
|
||||
assert not any(t.is_alive() for t in writer.threads)
|
||||
|
||||
|
||||
def test_wait_until_done(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=0, num_threads=4)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
for i, fpath in enumerate(fpaths):
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
assert np.array_equal(saved_image, image_arrays[i])
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter(num_processes=2, num_threads=2)
|
||||
try:
|
||||
num_images = 100
|
||||
image_arrays = [img_array_factory() for _ in range(num_images)]
|
||||
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
|
||||
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
writer.save_image(image_array, fpath)
|
||||
writer.wait_until_done()
|
||||
for i, fpath in enumerate(fpaths):
|
||||
assert fpath.exists()
|
||||
saved_image = np.array(Image.open(fpath))
|
||||
assert np.array_equal(saved_image, image_arrays[i])
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_exception_handling(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
with (
|
||||
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
|
||||
pytest.raises(queue.Full) as exc_info,
|
||||
):
|
||||
writer.save_image(image_array, tmp_path / "test.png")
|
||||
assert str(exc_info.value) == "Queue is full"
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_with_different_image_formats(tmp_path, img_array_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_array = img_array_factory()
|
||||
formats = ["png", "jpeg", "bmp"]
|
||||
for fmt in formats:
|
||||
fpath = tmp_path / f"test_image.{fmt}"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
|
||||
|
||||
def test_safe_stop_image_writer_decorator():
|
||||
class MockDataset:
|
||||
def __init__(self):
|
||||
self.image_writer = MagicMock(spec=AsyncImageWriter)
|
||||
|
||||
@safe_stop_image_writer
|
||||
def function_that_raises_exception(dataset=None):
|
||||
raise Exception("Test exception")
|
||||
|
||||
dataset = MockDataset()
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
function_that_raises_exception(dataset=dataset)
|
||||
|
||||
assert str(exc_info.value) == "Test exception"
|
||||
dataset.image_writer.stop.assert_called_once()
|
||||
|
||||
|
||||
def test_main_process_time(tmp_path, img_tensor_factory):
|
||||
writer = AsyncImageWriter()
|
||||
try:
|
||||
image_tensor = img_tensor_factory()
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
start_time = time.perf_counter()
|
||||
writer.save_image(image_tensor, fpath)
|
||||
end_time = time.perf_counter()
|
||||
time_spent = end_time - start_time
|
||||
# Might need to adjust this threshold depending on hardware
|
||||
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
|
||||
writer.wait_until_done()
|
||||
assert fpath.exists()
|
||||
finally:
|
||||
writer.stop()
|
||||
@@ -19,11 +19,8 @@ from uuid import uuid4
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
|
||||
from lerobot.common.datasets.utils import hf_transform_to_torch
|
||||
|
||||
# Some constants for OnlineBuffer tests.
|
||||
data_key = "data"
|
||||
@@ -212,29 +209,17 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
|
||||
|
||||
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
|
||||
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
|
||||
@pytest.mark.parametrize("online_dataset_size", [0, 4])
|
||||
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
|
||||
def test_compute_sampler_weights_trivial(
|
||||
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
|
||||
lerobot_dataset_factory,
|
||||
tmp_path,
|
||||
offline_dataset_size: int,
|
||||
online_dataset_size: int,
|
||||
online_sampling_ratio: float,
|
||||
):
|
||||
# Pass/skip the test if both datasets sizes are zero.
|
||||
if offline_dataset_size + online_dataset_size == 0:
|
||||
return
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(
|
||||
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
|
||||
)
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
if offline_dataset_size == 0:
|
||||
offline_dataset.episode_data_index = {}
|
||||
else:
|
||||
# Set up an episode_data_index with at least two episodes.
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, offline_dataset_size // 2]),
|
||||
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
if online_dataset_size > 0:
|
||||
online_dataset.add_data(
|
||||
@@ -254,16 +239,9 @@ def test_compute_sampler_weights_trivial(
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio():
|
||||
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0, 2]),
|
||||
"to": torch.tensor([2, 4]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
online_sampling_ratio = 0.8
|
||||
@@ -275,16 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio():
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
|
||||
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
|
||||
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
|
||||
# Create spoof offline dataset.
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([4]),
|
||||
}
|
||||
# Create spoof online datset.
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
weights = compute_sampler_weights(
|
||||
@@ -295,18 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
|
||||
)
|
||||
|
||||
|
||||
def test_compute_sampler_weights_drop_n_last_frames():
|
||||
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
|
||||
"""Note: test copied from test_sampler."""
|
||||
data_dict = {
|
||||
"timestamp": [0, 0.1],
|
||||
"index": [0, 1],
|
||||
"episode_index": [0, 0],
|
||||
"frame_index": [0, 1],
|
||||
}
|
||||
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
|
||||
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
|
||||
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
|
||||
|
||||
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
|
||||
online_dataset, _ = make_new_buffer()
|
||||
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
|
||||
|
||||
# TODO(aliberts): refactor using lerobot/__init__.py variables
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Check that we can make the policy object.
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
env.step(action)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_act_backbone_lr():
|
||||
"""
|
||||
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
|
||||
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
|
||||
assert cfg.training.lr_backbone == 0.001
|
||||
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
assert len(optimizer.param_groups) == 2
|
||||
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
|
||||
@@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||
[
|
||||
|
||||
@@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"required_packages, raw_format, repo_id, make_test_data",
|
||||
[
|
||||
|
||||
@@ -39,7 +39,6 @@ def test_robot(tmpdir, request, robot_type, mock):
|
||||
# TODO(rcadene): measure fps in nightly?
|
||||
# TODO(rcadene): test logs
|
||||
# TODO(rcadene): add compatibility with other robots
|
||||
|
||||
robot_kwargs = {"robot_type": robot_type}
|
||||
|
||||
if robot_type == "aloha" and mock:
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
# limitations under the License.
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,10 +7,9 @@ import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
hf_transform_to_torch,
|
||||
reset_episode_index,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
get_global_random_state,
|
||||
@@ -73,20 +72,6 @@ def test_calculate_episode_data_index():
|
||||
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
|
||||
|
||||
|
||||
def test_reset_episode_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
|
||||
"index": [0, 1, 2, 3, 4, 5],
|
||||
"episode_index": [10, 10, 11, 12, 12, 12],
|
||||
},
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
correct_episode_index = [0, 0, 1, 2, 2, 2]
|
||||
dataset = reset_episode_index(dataset)
|
||||
assert dataset["episode_index"] == correct_episode_index
|
||||
|
||||
|
||||
def test_init_hydra_config_empty():
|
||||
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
|
||||
with open(test_file, "w") as f:
|
||||
|
||||
@@ -13,25 +13,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset import visualize_dataset
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
|
||||
def test_visualize_local_dataset(tmpdir, repo_id, root):
|
||||
@pytest.mark.skip("TODO: add dummy videos")
|
||||
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
|
||||
root = tmp_path / "dataset"
|
||||
output_dir = tmp_path / "outputs"
|
||||
dataset = lerobot_dataset_factory(root=root)
|
||||
rrd_path = visualize_dataset(
|
||||
repo_id,
|
||||
dataset,
|
||||
episode_index=0,
|
||||
batch_size=32,
|
||||
save=True,
|
||||
output_dir=tmpdir,
|
||||
root=root,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
assert rrd_path.exists()
|
||||
|
||||
@@ -14,23 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
["lerobot/pusht"],
|
||||
)
|
||||
def test_visualize_dataset_html(tmpdir, repo_id):
|
||||
tmpdir = Path(tmpdir)
|
||||
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
|
||||
root = tmp_path / "dataset"
|
||||
output_dir = tmp_path / "outputs"
|
||||
dataset = lerobot_dataset_factory(root=root)
|
||||
visualize_dataset_html(
|
||||
repo_id,
|
||||
dataset,
|
||||
episodes=[0],
|
||||
output_dir=tmpdir,
|
||||
output_dir=output_dir,
|
||||
serve=False,
|
||||
)
|
||||
assert (tmpdir / "static" / "episode_0.csv").exists()
|
||||
assert (output_dir / "static" / "episode_0.csv").exists()
|
||||
|
||||