Compare commits

..

13 Commits

Author SHA1 Message Date
KeWang
def42ff487 Port SAC WIP (#581)
Co-authored-by: KeWang1017 <ke.wang@helloleap.ai>
2024-12-17 16:16:59 +01:00
Michel Aractingi
c9af8e36a7 completed losses 2024-12-17 16:16:36 +01:00
Michel Aractingi
ed66c92383 nit in control_robot.py 2024-12-17 11:04:56 +07:00
Michel Aractingi
668d493bf9 Update lerobot/scripts/train_hilserl_classifier.py
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2024-12-17 02:44:31 +07:00
Claudio Coppola
67f4d7ea7a LerobotDataset pushable to HF from any folder (#563) 2024-12-17 02:44:23 +07:00
berjaoui
4b0c88ff8e Update 7_get_started_with_real_robot.md (#559) 2024-12-17 02:44:11 +07:00
Michel Aractingi
b19fef9d18 Control simulated robot with real leader (#514)
Co-authored-by: Remi <remi.cadene@huggingface.co>
2024-12-17 02:44:03 +07:00
Remi
1612e00e63 Fix missing local_files_only in record/replay (#540)
Co-authored-by: Simon Alibert <alibert.sim@gmail.com>
2024-12-17 02:43:10 +07:00
Michel Aractingi
c3bc136420 Refactor OpenX (#505) 2024-12-17 02:42:59 +07:00
Eugene Mironov
1020bc3108 Fixup 2024-12-17 02:42:53 +07:00
Michel Aractingi
7fcf638c0d Add human intervention mechanism and eval_robot script to evaluate policy on the robot (#541)
Co-authored-by: Yoel <yoel.chornton@gmail.com>
2024-12-17 02:41:31 +07:00
Yoel
e35546f58e Reward classifier and training (#528)
Co-authored-by: Daniel Ritchie <daniel@brainwavecollective.ai>
Co-authored-by: resolver101757 <kelster101757@hotmail.com>
Co-authored-by: Jannik Grothusen <56967823+J4nn1K@users.noreply.github.com>
Co-authored-by: Remi <re.cadene@gmail.com>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2024-12-17 02:41:29 +07:00
Michel Aractingi
1aa8d4ac91 nit 2024-12-17 02:39:15 +07:00
31 changed files with 2200 additions and 784 deletions

View File

@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry
- name: Poetry check
run: poetry check
@@ -64,7 +64,7 @@ jobs:
uses: actions/checkout@v3
- name: Install poetry
run: pipx install "poetry<2.0.0"
run: pipx install poetry
- name: Install poetry-relax
run: poetry self add poetry-relax

View File

@@ -68,7 +68,7 @@
### Acknowledgment
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.

View File

@@ -21,7 +21,7 @@ How to decode videos?
## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`

View File

@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
sed gawk grep curl wget zip unzip \
tcpdump sysstat screen tmux \
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
speech-dispatcher portaudio19-dev \
speech-dispatcher \
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
@@ -58,7 +58,7 @@ RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
RUN ln -s /usr/bin/python3 /usr/bin/python
# Install poetry
RUN curl -sSL https://install.python-poetry.org | python - --version 1.8.5
RUN curl -sSL https://install.python-poetry.org | python -
ENV PATH="/root/.local/bin:$PATH"
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
RUN poetry config virtualenvs.create false

View File

@@ -1,31 +1,25 @@
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
## A. Source the parts
## Source the parts
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
## B. Install LeRobot
## Install LeRobot
On your computer:
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
```bash
mkdir -p ~/miniconda3
# Linux:
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
# Mac M-series:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
# Mac Intel:
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
rm ~/miniconda3/miniconda.sh
~/miniconda3/bin/conda init bash
```
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
2. Restart shell or `source ~/.bashrc`
3. Create and activate a fresh conda environment for lerobot
```bash
@@ -42,30 +36,23 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
cd ~/lerobot && pip install -e ".[feetech]"
```
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
For Linux only (not Mac), install extra dependencies for recording datasets:
```bash
conda install -y -c conda-forge ffmpeg
pip uninstall -y opencv-python
conda install -y -c conda-forge "opencv>=4.10.0"
```
## C. Configure the motors
## Configure the motors
### 1. Find the USB ports associated to each arm
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
#### a. Run the script to find ports
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
To find the port for each bus servo adapter, run the utility script:
**Find USB ports associated to your arms**
To find the correct ports for each arm, run the utility script twice:
```bash
python lerobot/scripts/find_motors_bus_port.py
```
#### b. Example outputs
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
```
Finding all available ports for the MotorBus.
@@ -77,6 +64,7 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
Reconnect the usb cable.
```
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
```
Finding all available ports for the MotorBus.
@@ -89,20 +77,13 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
Reconnect the usb cable.
```
#### c. Troubleshooting
On Linux, you might need to give access to the USB ports by running:
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
```bash
sudo chmod 666 /dev/ttyACM0
sudo chmod 666 /dev/ttyACM1
```
#### d. Update YAML file
Now that you have the ports, modify the *port* sections in `so100.yaml`
### 2. Configure the motors
#### a. Set IDs for all 12 motors
**Configure your motors**
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
```bash
python lerobot/scripts/configure_motor.py \
@@ -113,7 +94,7 @@ python lerobot/scripts/configure_motor.py \
--ID 1
```
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
Then unplug your motor and plug the second motor and set its ID to 2.
```bash
@@ -127,25 +108,23 @@ python lerobot/scripts/configure_motor.py \
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
**Remove the gears of the 6 leader motors**
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
#### b. Remove the gears of the 6 leader motors
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
#### c. Add motor horn to all 12 motors
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
**Add motor horn to the motors**
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
## D. Assemble the arms
## Assemble the arms
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
## E. Calibrate
## Calibrate
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
#### a. Manual calibration of follower arm
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
**Manual calibration of follower arm**
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
You will need to move the follower arm to these positions sequentially:
@@ -160,8 +139,8 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_follower
```
#### b. Manual calibration of leader arm
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
**Manual calibration of leader arm**
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
| 1. Zero position | 2. Rotated position | 3. Rest position |
|---|---|---|
@@ -174,7 +153,7 @@ python lerobot/scripts/control_robot.py calibrate \
--robot-overrides '~cameras' --arms main_leader
```
## F. Teleoperate
## Teleoperate
**Simple teleop**
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
@@ -186,14 +165,14 @@ python lerobot/scripts/control_robot.py teleoperate \
```
#### a. Teleop with displaying cameras
**Teleop with displaying cameras**
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
```bash
python lerobot/scripts/control_robot.py teleoperate \
--robot-path lerobot/configs/robot/so100.yaml
```
## G. Record a dataset
## Record a dataset
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
@@ -222,7 +201,7 @@ python lerobot/scripts/control_robot.py record \
--push-to-hub 1
```
## H. Visualize a dataset
## Visualize a dataset
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
```bash
@@ -235,7 +214,7 @@ python lerobot/scripts/visualize_dataset_html.py \
--repo-id ${HF_USER}/so100_test
```
## I. Replay an episode
## Replay an episode
Now try to replay the first episode on your robot:
```bash
@@ -246,7 +225,7 @@ python lerobot/scripts/control_robot.py replay \
--episode 0
```
## J. Train a policy
## Train a policy
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
```bash
@@ -269,7 +248,7 @@ Let's explain it:
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
## K. Evaluate your policy
## Evaluate your policy
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
```bash
@@ -289,7 +268,7 @@ As you can see, it's almost the same command as previously used to record your t
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
## L. More Information
## More
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.

View File

@@ -0,0 +1,83 @@
# Training a HIL-SERL Reward Classifier with LeRobot
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
---
## Training Script Overview
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
1. **Configuration Loading**
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
2. **Dataset Initialization**
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
3. **Classifier Initialization**
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
- A single probability (binary classification), or
- Logits (multi-class classification).
4. **Training Loop Execution**
The script performs:
- Forward and backward passes,
- Optimization steps,
- Periodic logging, evaluation, and checkpoint saving.
---
## Configuring with Hydra
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
### Config File Setup
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
```yaml
# Example: lerobot/configs/policy/reward_classifier.yaml
dataset_repo_id: "my_dataset_repo_id"
## Typical logs and metrics
```
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
After that, you will see training log like this one:
```
[2024-11-29 18:26:36,999][root][INFO] -
Epoch 5/5
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
```
or evaluation log like:
```
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
```
### Metrics Tracking with Weights & Biases (WandB)
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
- **Training Metrics**:
- `train/accuracy`
- `train/loss`
- `train/dataloading_s`
- **Evaluation Metrics**:
- `eval/accuracy`
- `eval/loss`
- `eval/eval_s`
#### Additional Features
You can also log sample predictions during evaluation. Each logged sample will include:
- The **input image**.
- The **predicted label**.
- The **true label**.
- The **classifier's "confidence" (logits/probability)**.
These logs can be useful for diagnosing and debugging performance issues.

View File

@@ -1,213 +0,0 @@
import shutil
from pathlib import Path
import h5py
import numpy as np
import torch
import tqdm
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
def create_empty_dataset(dataset_name, robot_type, mode="video", has_velocity=False, has_effort=False):
motors = [
# TODO(rcadene): verify
"right_waist",
"right_shoulder",
"right_elbow",
"right_forearm_roll",
"right_wrist_angle",
"right_wrist_rotate",
"right_gripper",
"left_waist",
"left_shoulder",
"left_elbow",
"left_forearm_roll",
"left_wrist_angle",
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
}
if has_velocity:
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
if has_velocity:
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
for cam in cameras:
features[f"observation.images.{cam}"] = {
"dtype": mode,
"shape": (3, 480, 640),
"names": [
"channels",
"height",
"width",
],
}
dataset = LeRobotDataset.create(
repo_id=f"cadene/{dataset_name}_v2",
fps=50,
robot_type=robot_type,
features=features,
)
return dataset
def get_cameras(hdf5_files):
with h5py.File(hdf5_files[0], "r") as ep:
# ignore depth channel, not currently handled
# TODO(rcadene): add depth
rgb_cameras = [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
return rgb_cameras
def has_velocity(hdf5_files):
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/qvel" in ep
def has_effort(hdf5_files):
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/effort" in ep
def load_raw_images_per_camera(ep, cameras):
imgs_per_cam = {}
for camera in cameras:
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
if uncompressed:
# load all images in RAM
imgs_array = ep[f"/observations/images/{camera}"][:]
else:
import cv2
# load one compressed image after the other in RAM and uncompress
imgs_array = []
for data in ep[f"/observations/images/{camera}"]:
imgs_array.append(cv2.imdecode(data, 1))
imgs_array = np.array(imgs_array)
imgs_per_cam[camera] = imgs_array
return imgs_per_cam
def load_raw_episode_data(ep_path):
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
effort = None
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
imgs_per_cam = load_raw_images_per_camera(ep)
return imgs_per_cam, state, action, velocity, effort
def populate_dataset(dataset, hdf5_files, task, episodes=None):
if episodes is None:
episodes = range(len(hdf5_files))
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
num_frames = state.shape[0]
for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
}
for camera, img_array in imgs_per_cam.items():
frame[f"observation.images.{camera}"] = img_array[i]
if velocity is not None:
frame["observation.velocity"] = velocity[i]
if effort is not None:
frame["observation.effort"] = effort[i]
dataset.add_frame(frame)
dataset.save_episode(task=task)
return dataset
def port_aloha(raw_dir, raw_repo_id, repo_id, episodes: list[int] | None = None, push_to_hub=True):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
raw_dir = Path(raw_dir)
if not raw_dir.exists():
download_raw(raw_dir, repo_id=raw_repo_id)
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
dataset_name = repo_id.split("/")[1]
dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if "mobile" in dataset_name else "aloha",
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
)
dataset = populate_dataset(
dataset,
hdf5_files,
task="DEBUG",
episodes=episodes,
)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
raw_repo_id = "lerobot-raw/aloha_sim_insertion_human_raw"
repo_id = "cadene/aloha_sim_insertion_human_v2"
port_aloha(f"data/{raw_repo_id}", raw_repo_id, repo_id, episodes=[0, 1], push_to_hub=False)

View File

@@ -291,7 +291,7 @@ class LeRobotDatasetMetadata:
obj.root.mkdir(parents=True, exist_ok=False)
if robot is not None:
features = get_features_from_robot(robot, use_videos)
features = {**(features or {}), **get_features_from_robot(robot)}
robot_type = robot.robot_type
if not all(cam.fps == fps for cam in robot.cameras.values()):
logging.warning(

View File

@@ -17,11 +17,9 @@ import importlib.resources
import json
import logging
import textwrap
from collections.abc import Iterator
from itertools import accumulate
from pathlib import Path
from pprint import pformat
from types import SimpleNamespace
from typing import Any
import datasets
@@ -479,6 +477,7 @@ def create_lerobot_dataset_card(
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
"""
card_tags = ["LeRobot"]
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
if tags:
card_tags += tags
@@ -498,65 +497,8 @@ def create_lerobot_dataset_card(
],
)
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
return DatasetCard.from_template(
card_data=card_data,
template_str=card_template,
template_path=str(card_template_path),
**kwargs,
)
class IterableNamespace(SimpleNamespace):
"""
A namespace object that supports both dictionary-like iteration and dot notation access.
Automatically converts nested dictionaries into IterableNamespaces.
This class extends SimpleNamespace to provide:
- Dictionary-style iteration over keys
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
- Dictionary-like methods: items(), keys(), values()
- Recursive conversion of nested dictionaries
Args:
dictionary: Optional dictionary to initialize the namespace
**kwargs: Additional keyword arguments passed to SimpleNamespace
Examples:
>>> data = {"name": "Alice", "details": {"age": 25}}
>>> ns = IterableNamespace(data)
>>> ns.name
'Alice'
>>> ns.details.age
25
>>> list(ns.keys())
['name', 'details']
>>> for key, value in ns.items():
... print(f"{key}: {value}")
name: Alice
details: IterableNamespace(age=25)
"""
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
super().__init__(**kwargs)
if dictionary is not None:
for key, value in dictionary.items():
if isinstance(value, dict):
setattr(self, key, IterableNamespace(value))
else:
setattr(self, key, value)
def __iter__(self) -> Iterator[str]:
return iter(vars(self))
def __getitem__(self, key: str) -> Any:
return vars(self)[key]
def items(self):
return vars(self).items()
def values(self):
return vars(self).values()
def keys(self):
return vars(self).keys()

View File

@@ -159,11 +159,11 @@ DATASETS = {
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup": {
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_vinh_cup_left": {
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
**ALOHA_STATIC_INFO,
},
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},

View File

@@ -25,6 +25,7 @@ from glob import glob
from pathlib import Path
import torch
import wandb
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
@@ -107,8 +108,6 @@ class Logger:
self._wandb = None
else:
os.environ["WANDB_SILENT"] = "true"
import wandb
wandb_run_id = None
if cfg.resume:
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
@@ -232,7 +231,7 @@ class Logger:
# TODO(alexander-soare): Add local text log.
if self._wandb is not None:
for k, v in d.items():
if not isinstance(v, (int, float, str)):
if not isinstance(v, (int, float, str, wandb.Table)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
)

View File

@@ -0,0 +1,36 @@
import json
import os
from dataclasses import asdict, dataclass
import torch
@dataclass
class ClassifierConfig:
"""Configuration for the Classifier model."""
num_classes: int = 2
hidden_dim: int = 256
dropout_rate: float = 0.1
model_name: str = "microsoft/resnet-50"
device: str = "cuda" if torch.cuda.is_available() else "mps"
model_type: str = "cnn" # "transformer" or "cnn"
def save_pretrained(self, save_dir):
"""Save config to json file."""
os.makedirs(save_dir, exist_ok=True)
# Convert to dict and save as JSON
config_dict = asdict(self)
with open(os.path.join(save_dir, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path):
"""Load config from json file."""
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
with open(config_file) as f:
config_dict = json.load(f)
return cls(**config_dict)

View File

@@ -0,0 +1,134 @@
import logging
from typing import Optional
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import Tensor, nn
from transformers import AutoImageProcessor, AutoModel
from .configuration_classifier import ClassifierConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
class ClassifierOutput:
"""Wrapper for classifier outputs with additional metadata."""
def __init__(
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
):
self.logits = logits
self.probabilities = probabilities
self.hidden_states = hidden_states
class Classifier(
nn.Module,
PyTorchModelHubMixin,
# Add Hub metadata
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "vision-classifier"],
):
"""Image classifier built on top of a pre-trained encoder."""
# Add name attribute for factory
name = "classifier"
def __init__(self, config: ClassifierConfig):
super().__init__()
self.config = config
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
# Extract vision model if we're given a multimodal model
if hasattr(encoder, "vision_model"):
logging.info("Multimodal model detected - using vision encoder only")
self.encoder = encoder.vision_model
self.vision_config = encoder.config.vision_config
else:
self.encoder = encoder
self.vision_config = getattr(encoder, "config", None)
# Model type from config
self.is_cnn = self.config.model_type == "cnn"
# For CNNs, initialize backbone
if self.is_cnn:
self._setup_cnn_backbone()
self._freeze_encoder()
self._build_classifier_head()
def _setup_cnn_backbone(self):
"""Set up CNN encoder"""
if hasattr(self.encoder, "fc"):
self.feature_dim = self.encoder.fc.in_features
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
elif hasattr(self.encoder.config, "hidden_sizes"):
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
else:
raise ValueError("Unsupported CNN architecture")
def _freeze_encoder(self) -> None:
"""Freeze the encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = False
def _build_classifier_head(self) -> None:
"""Initialize the classifier head architecture."""
# Get input dimension based on model type
if self.is_cnn:
input_dim = self.feature_dim
else: # Transformer models
if hasattr(self.encoder.config, "hidden_size"):
input_dim = self.encoder.config.hidden_size
else:
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
self.classifier_head = nn.Sequential(
nn.Linear(input_dim, self.config.hidden_dim),
nn.Dropout(self.config.dropout_rate),
nn.LayerNorm(self.config.hidden_dim),
nn.ReLU(),
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
)
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
"""Extract the appropriate output from the encoder."""
# Process images with the processor (handles resizing and normalization)
processed = self.processor(
images=x, # LeRobotDataset already provides proper tensor format
return_tensors="pt",
)
processed = processed["pixel_values"].to(x.device)
with torch.no_grad():
if self.is_cnn:
# The HF ResNet applies pooling internally
outputs = self.encoder(processed)
# Get pooled output directly
features = outputs.pooler_output
if features.dim() > 2:
features = features.squeeze(-1).squeeze(-1)
return features
else: # Transformer models
outputs = self.encoder(processed)
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
return outputs.pooler_output
return outputs.last_hidden_state[:, 0, :]
def forward(self, x: torch.Tensor) -> ClassifierOutput:
"""Forward pass of the classifier."""
# For training, we expect input to be a tensor directly from LeRobotDataset
encoder_output = self._get_encoder_output(x)
logits = self.classifier_head(encoder_output)
if self.config.num_classes == 2:
logits = logits.squeeze(-1)
probabilities = torch.sigmoid(logits)
else:
probabilities = torch.softmax(logits, dim=-1)
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)

View File

@@ -0,0 +1,23 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class HILSerlConfig:
pass

View File

@@ -0,0 +1,29 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
class HILSerlPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "hilserl"],
):
pass

View File

@@ -0,0 +1,39 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
@dataclass
class SACConfig:
discount = 0.99
temperature_init = 1.0
num_critics = 2
critic_lr = 3e-4
actor_lr = 3e-4
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "uniform",
}

View File

@@ -0,0 +1,683 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: (1) better device management
from collections import deque
from copy import deepcopy
from functools import partial
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from huggingface_hub import PyTorchModelHubMixin
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.sac.configuration_sac import SACConfig
import numpy as np
from typing import Callable, Optional, Tuple, Sequence
class SACPolicy(
nn.Module,
PyTorchModelHubMixin,
library_name="lerobot",
repo_url="https://github.com/huggingface/lerobot",
tags=["robotics", "RL", "SAC"],
):
def __init__(
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
):
super().__init__()
if config is None:
config = SACConfig()
self.config = config
if config.input_normalization_modes is not None:
self.normalize_inputs = Normalize(
config.input_shapes, config.input_normalization_modes, dataset_stats
)
else:
self.normalize_inputs = nn.Identity()
self.normalize_targets = Normalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
self.unnormalize_outputs = Unnormalize(
config.output_shapes, config.output_normalization_modes, dataset_stats
)
encoder = SACObservationEncoder(config)
# Define networks
critic_nets = []
for _ in range(config.num_critics):
critic_net = Critic(
encoder=encoder,
network=MLP(**config.critic_network_kwargs)
)
critic_nets.append(critic_net)
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
self.critic_target = deepcopy(self.critic_ensemble)
self.actor_network = Policy(
encoder=encoder,
network=MLP(**config.actor_network_kwargs),
action_dim=config.output_shapes["action"][0],
**config.policy_kwargs
)
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
def reset(self):
"""
Clear observation and action queues. Should be called on `env.reset()`
queues are populated during rollout of the policy, they contain the n latest observations and actions
"""
self._queues = {
"observation.state": deque(maxlen=1),
"action": deque(maxlen=1),
}
if self._use_image:
self._queues["observation.image"] = deque(maxlen=1)
if self._use_env_state:
self._queues["observation.environment_state"] = deque(maxlen=1)
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
actions, _ = self.actor_network(batch['observations'])###
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
"""Run the batch through the model and compute the loss.
Returns a dictionary with loss as a tensor, and other information as native floats.
"""
batch = self.normalize_inputs(batch)
# batch shape is (b, 2, ...) where index 1 returns the current observation and
# the next observation for caluculating the right td index.
actions = batch["action"][:, 0]
rewards = batch["next.reward"][:, 0]
observations = {}
next_observations = {}
for k in batch:
if k.startswith("observation."):
observations[k] = batch[k][:, 0]
next_observations[k] = batch[k][:, 1]
# perform image augmentation
# reward bias
# from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
# calculate critics loss
# 1- compute actions from policy
action_preds, log_probs = self.actor_network(observations)
# 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds)
# critics subsample size
min_q = q_targets.min(dim=0)
# backup entropy
td_target = rewards + self.discount * min_q
# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
critics_loss = (
F.mse_loss(
q_preds,
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
reduction="none",
).sum(0) # sum over ensemble
# `q_preds_ensemble` depends on the first observation and the actions.
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
# q_targets depends on the reward and the next observations.
* ~batch["next.reward_is_pad"]
* ~batch["observation.state_is_pad"][1:]
).sum(0).mean()
# calculate actors loss
# 1- temperature
temperature = self.temperature()
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
actions, log_probs = self.actor_network(observations) \
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
actor_loss = (
-(q_preds - temperature * log_probs).mean()
* ~batch["observation.state_is_pad"][0]
* ~batch["action_is_pad"]
).mean()
# calculate temperature loss
# 1- calculate entropy
entropy = -log_probs.mean()
temperature_loss = temperature * (entropy - self.target_entropy).mean()
loss = critics_loss + actor_loss + temperature_loss
return {
"critics_loss": critics_loss.item(),
"actor_loss": actor_loss.item(),
"temperature_loss": temperature_loss.item(),
"temperature": temperature.item(),
"entropy": entropy.item(),
"loss": loss,
}
def update(self):
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
class MLP(nn.Module):
def __init__(
self,
config: SACConfig,
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
activate_final: bool = False,
dropout_rate: Optional[float] = None,
):
super().__init__()
self.activate_final = config.activate_final
layers = []
for i, size in enumerate(config.network_hidden_dims):
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
if i + 1 < len(config.network_hidden_dims) or activate_final:
if dropout_rate is not None and dropout_rate > 0:
layers.append(nn.Dropout(p=dropout_rate))
layers.append(nn.LayerNorm(size))
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
# in training mode or not. TODO: find better way to do this
self.train(train)
return self.net(x)
class Critic(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
self.activate_final = activate_final
# Output layer
if init_final is not None:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
else:
if self.activate_final:
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
else:
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
orthogonal_init()(self.output_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
self.train(train)
observations = observations.to(self.device)
actions = actions.to(self.device)
if self.encoder is not None:
obs_enc = self.encoder(observations)
else:
obs_enc = observations
inputs = torch.cat([obs_enc, actions], dim=-1)
x = self.network(inputs)
value = self.output_layer(x)
return value.squeeze(-1)
def q_value_ensemble(
self,
observations: torch.Tensor,
actions: torch.Tensor,
train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
batch_size, num_actions = actions.shape[:2]
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
actions_flat = actions.reshape(-1, actions.shape[-1])
q_values = self(obs_flat, actions_flat, train)
return q_values.reshape(batch_size, num_actions)
else:
return self(observations, actions, train)
class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
std_parameterization: str = "exp",
std_min: float = 1e-5,
std_max: float = 10.0,
tanh_squash_distribution: bool = False,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.std_parameterization = std_parameterization
self.std_min = std_min
self.std_max = std_max
self.tanh_squash_distribution = tanh_squash_distribution
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
# Mean layer
if self.activate_final:
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.mean_layer.weight)
# Standard deviation layer or parameter
if fixed_std is None:
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
else:
if self.activate_final:
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
else:
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)
self.to(self.device)
def forward(
self,
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False
) -> torch.distributions.Distribution:
self.train(train)
# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
obs_enc = self.encoder(observations, train=train)
else:
obs_enc = observations
# Get network outputs
outputs = self.network(obs_enc)
means = self.mean_layer(outputs)
# Compute standard deviations
if self.fixed_std is None:
if self.std_parameterization == "exp":
log_stds = self.std_layer(outputs)
stds = torch.exp(log_stds)
elif self.std_parameterization == "softplus":
stds = torch.nn.functional.softplus(self.std_layer(outputs))
elif self.std_parameterization == "uniform":
stds = torch.exp(self.log_stds).expand_as(means)
else:
raise ValueError(
f"Invalid std_parameterization: {self.std_parameterization}"
)
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)
# Clip standard deviations and scale with temperature
temperature = torch.tensor(temperature, device=self.device)
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
# Create distribution
if self.tanh_squash_distribution and not non_squash_distribution:
distribution = TanhMultivariateNormalDiag(
loc=means,
scale_diag=stds,
)
else:
distribution = torch.distributions.Normal(
loc=means,
scale=stds,
)
return distribution
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
if self.encoder is not None:
with torch.no_grad():
return self.encoder(observations, train=False)
return observations
class SACObservationEncoder(nn.Module):
"""Encode image and/or state vector observations.
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
"""
def __init__(self, config: SACConfig):
"""
Creates encoders for pixel and/or state modalities.
"""
super().__init__()
self.config = config
if "observation.image" in config.input_shapes:
self.image_enc_layers = nn.Sequential(
nn.Conv2d(
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
nn.ReLU(),
)
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
with torch.inference_mode():
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
self.image_enc_layers.extend(
nn.Sequential(
nn.Flatten(),
nn.Linear(np.prod(out_shape), config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
)
if "observation.state" in config.input_shapes:
self.state_enc_layers = nn.Sequential(
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
if "observation.environment_state" in config.input_shapes:
self.env_state_enc_layers = nn.Sequential(
nn.Linear(
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
),
nn.ELU(),
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
nn.LayerNorm(config.latent_dim),
nn.Tanh(),
)
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
"""Encode the image and/or state vector.
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
over all features.
"""
feat = []
# Concatenate all images along the channel dimension.
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
for image_key in image_keys:
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
if "observation.environment_state" in self.config.input_shapes:
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
if "observation.state" in self.config.input_shapes:
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
return torch.stack(feat, dim=0).mean(0)
class LagrangeMultiplier(nn.Module):
def __init__(
self,
init_value: float = 1.0,
constraint_shape: Sequence[int] = (),
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
# Initialize the Lagrange multiplier as a parameter
self.lagrange = nn.Parameter(
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
)
self.to(self.device)
def forward(
self,
lhs: Optional[torch.Tensor] = None,
rhs: Optional[torch.Tensor] = None
) -> torch.Tensor:
# Get the multiplier value based on parameterization
multiplier = torch.nn.functional.softplus(self.lagrange)
# Return the raw multiplier if no constraint values provided
if lhs is None:
return multiplier
# Move inputs to device
lhs = lhs.to(self.device)
if rhs is not None:
rhs = rhs.to(self.device)
# Use the multiplier to compute the Lagrange penalty
if rhs is None:
rhs = torch.zeros_like(lhs, device=self.device)
diff = lhs - rhs
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
return multiplier * diff
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
# 1. The base distribution is a diagonal multivariate normal distribution
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
def __init__(
self,
loc: torch.Tensor,
scale_diag: torch.Tensor,
low: Optional[torch.Tensor] = None,
high: Optional[torch.Tensor] = None,
):
# Create base normal distribution
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
# Create list of transforms
transforms = []
# Add tanh transform
transforms.append(torch.distributions.transforms.TanhTransform())
# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(
loc=(high + low) / 2,
scale=(high - low) / 2
)
)
# Initialize parent class
super().__init__(
base_distribution=base_distribution,
transforms=transforms
)
# Store parameters
self.loc = loc
self.scale_diag = scale_diag
self.low = low
self.high = high
def mode(self) -> torch.Tensor:
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)
return mode
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
"""
Reparameterized sample from the distribution
"""
# Sample from base distribution
x = self.base_dist.rsample(sample_shape)
# Apply transforms
for transform in self.transforms:
x = transform(x)
return x
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Compute log probability of a value
Includes the log det jacobian for the transforms
"""
# Initialize log prob
log_prob = torch.zeros_like(value[..., 0])
# Inverse transforms to get back to normal distribution
q = value
for transform in reversed(self.transforms):
q = transform.inv(q)
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
# Add base distribution log prob
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
return log_prob
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample from the distribution and compute log probability
"""
x = self.rsample(sample_shape)
log_prob = self.log_prob(x)
return x, log_prob
def entropy(self) -> torch.Tensor:
"""
Compute entropy of the distribution
"""
# Start with base distribution entropy
entropy = self.base_dist.entropy().sum(-1)
# Add log det jacobian for each transform
x = self.rsample()
for transform in self.transforms:
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
x = transform(x)
return entropy
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
"""Creates an ensemble of critic networks"""
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
return critics.to(device)
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
# borrowed from tdmpc
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
"""Helper to temporarily flatten extra dims at the start of the image tensor.
Args:
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
(B, *), where * is any number of dimensions.
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
can be more than 1 dimensions, generally different from *.
Returns:
A return value from the callable reshaped to (**, *).
"""
if image_tensor.ndim == 4:
return fn(image_tensor)
start_dims = image_tensor.shape[:-3]
inp = torch.flatten(image_tensor, end_dim=-4)
flat_out = fn(inp)
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))

View File

@@ -120,14 +120,22 @@ def predict_action(observation, policy, device, use_amp):
return action
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
def init_keyboard_listener(assign_rewards=False):
"""
Initializes a keyboard listener to enable early termination of an episode
or environment reset by pressing the right arrow key ('->'). This may require
sudo permissions to allow the terminal to monitor keyboard events.
Args:
assign_rewards (bool): If True, allows annotating the collected trajectory
with a binary reward at the end of the episode to indicate success.
"""
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if assign_rewards:
events["next.reward"] = 0
if is_headless():
logging.warning(
@@ -152,6 +160,13 @@ def init_keyboard_listener():
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif assign_rewards and key == keyboard.Key.space:
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
print(
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
events["next.reward"],
)
except Exception as e:
print(f"Error handling key press: {e}")
@@ -272,6 +287,8 @@ def control_loop(
if dataset is not None:
frame = {**observation, **action}
if "next.reward" in events:
frame["next.reward"] = events["next.reward"]
dataset.add_frame(frame)
if display_cameras and not is_headless():
@@ -301,6 +318,8 @@ def reset_environment(robot, events, reset_time_s):
timestamp = 0
start_vencod_t = time.perf_counter()
if "next.reward" in events:
events["next.reward"] = 0
# Wait if necessary
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:

View File

@@ -0,0 +1,48 @@
# @package _global_
defaults:
- _self_
seed: 13
dataset_repo_id: "dataset_repo_id"
train_split_proportion: 0.8
# Required by logger
env:
name: "classifier"
task: "binary_classification"
training:
num_epochs: 5
batch_size: 16
learning_rate: 1e-4
num_workers: 4
grad_clip_norm: 10
use_amp: true
log_freq: 1
eval_freq: 1 # How often to run validation (in epochs)
save_freq: 1 # How often to save checkpoints (in epochs)
save_checkpoint: true
image_key: "observation.images.phone"
label_key: "next.reward"
eval:
batch_size: 16
num_samples_to_log: 30 # Number of validation samples to log in the table
policy:
name: "hilserl/classifier"
model_name: "facebook/convnext-base-224"
model_type: "cnn"
wandb:
enable: false
project: "classifier-training"
entity: "wandb_entity"
job_name: "classifier_training_0"
disable_artifact: false
device: "mps"
resume: false
output_dir: "output"

View File

@@ -10,7 +10,7 @@ max_relative_target: null
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0031751
port: /dev/tty.usbmodem58760430441
motors:
# name: (index, model)
shoulder_pan: [1, "xl330-m077"]
@@ -23,7 +23,7 @@ leader_arms:
follower_arms:
main:
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
port: /dev/tty.usbmodem575E0032081
port: /dev/tty.usbmodem585A0083391
motors:
# name: (index, model)
shoulder_pan: [1, "xl430-w250"]

View File

@@ -18,7 +18,7 @@ max_relative_target: null
leader_arms:
main:
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
port: /dev/tty.usbmodem585A0077581
port: /dev/tty.usbmodem58760433331
motors:
# name: (index, model)
shoulder_pan: [1, "sts3215"]

View File

@@ -191,6 +191,7 @@ def record(
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
assign_rewards: bool = False,
fps: int | None = None,
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
@@ -214,6 +215,9 @@ def record(
policy = None
device = None
use_amp = None
extra_features = (
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
)
if single_task:
task = single_task
@@ -254,12 +258,12 @@ def record(
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
features=extra_features,
)
if not robot.is_connected:
robot.connect()
listener, events = init_keyboard_listener()
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
@@ -300,7 +304,7 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
@@ -469,12 +473,12 @@ if __name__ == "__main__":
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--tags",
type=str,
nargs="*",
help="Add tags to your dataset on the hub.",
)
# parser_record.add_argument(
# "--tags",
# type=str,
# nargs="*",
# help="Add tags to your dataset on the hub.",
# )
parser_record.add_argument(
"--num-image-writer-processes",
type=int,
@@ -517,6 +521,12 @@ if __name__ == "__main__":
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_record.add_argument(
"--assign-rewards",
type=int,
default=0,
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(

View File

@@ -0,0 +1,335 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
```
python lerobot/scripts/eval_on_robot.py \
-p outputs/train/model/checkpoints/005000/pretrained_model \
eval.n_episodes=10
```
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
for running training on the real robot.
"""
import argparse
import logging
import time
from copy import deepcopy
import numpy as np
import torch
from tqdm import trange
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
from lerobot.common.utils.utils import (
init_hydra_config,
init_logging,
log_say,
)
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
"""Run a batched policy rollout on the real robot.
The return dictionary contains:
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
keys. NOTE the that this has an extra sequence element relative to the other keys in the
dictionary. This is because an extra observation is included for after the environment is
terminated or truncated.
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
including the last observations).
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
environment termination/truncation).
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
the first True is followed by True's all the way till the end. This can be used for masking
extraneous elements from the sequences above.
Args:
robot: The robot class that defines the interface with the real robot.
policy: The policy. Must be a PyTorch nn module.
Returns:
The dictionary described above.
"""
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
# device = get_device_from_parameters(policy)
# define keyboard listener
listener, events = init_keyboard_listener()
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
# policy.reset()
# Get observation from real robot
observation = robot.capture_observation()
# Calculate reward. TODO (michel-aractingi)
# in HIL-SERL it will be with a reward classifier
reward = calculate_reward(observation)
all_observations = []
all_actions = []
all_rewards = []
all_successes = []
start_episode_t = time.perf_counter()
timestamp = 0.0
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
all_observations.append(deepcopy(observation))
# observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
# Apply the next action.
while events["pause_policy"] and not events["human_intervention_step"]:
busy_wait(0.5)
if events["human_intervention_step"]:
# take over the robot's actions
observation, action = robot.teleop_step(record_data=True)
action = action["action"] # teleop step returns torch tensors but in a dict
else:
# explore with policy
with torch.inference_mode():
action = robot.follower_arms["main"].read("Present_Position")
action = torch.from_numpy(action)
robot.send_action(action)
# action = predict_action(observation, policy, device, use_amp)
observation = robot.capture_observation()
# Calculate reward
# in HIL-SERL it will be with a reward classifier
reward = calculate_reward(observation)
all_actions.append(action)
all_rewards.append(torch.from_numpy(reward))
all_successes.append(torch.tensor([False]))
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
events["human_intervention_step"] = False
events["pause_policy"] = False
break
all_observations.append(deepcopy(observation))
dones = torch.tensor([False] * len(all_actions))
dones[-1] = True
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
ret = {
"action": torch.stack(all_actions, dim=1),
"next.reward": torch.stack(all_rewards, dim=1),
"next.success": torch.stack(all_successes, dim=1),
"done": dones,
}
stacked_observations = {}
for key in all_observations[0]:
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret["observation"] = stacked_observations
listener.stop()
return ret
def eval_policy(
robot: Robot,
policy: torch.nn.Module,
fps: float,
n_episodes: int,
control_time_s: int = 20,
use_amp: bool = True,
) -> dict:
"""
Args:
env: The batch of environments.
policy: The policy.
n_episodes: The number of episodes to evaluate.
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
# TODO (michel-aractingi) comment this out for testing with a fixed policy
# assert isinstance(policy, Policy)
# policy.eval()
sum_rewards = []
max_rewards = []
successes = []
rollouts = []
start_eval = time.perf_counter()
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
for _batch_idx in progbar:
rollout_data = rollout(robot, policy, fps, control_time_s, use_amp)
rollouts.append(rollout_data)
sum_rewards.append(sum(rollout_data["next.reward"]))
max_rewards.append(max(rollout_data["next.reward"]))
successes.append(rollout_data["next.success"][-1])
info = {
"per_episode": [
{
"episode_ix": i,
"sum_reward": sum_reward,
"max_reward": max_reward,
"pc_success": success * 100,
}
for i, (sum_reward, max_reward, success) in enumerate(
zip(
sum_rewards[:n_episodes],
max_rewards[:n_episodes],
successes[:n_episodes],
strict=False,
)
)
],
"aggregated": {
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
"eval_s": time.time() - start_eval,
"eval_ep_s": (time.time() - start_eval) / n_episodes,
},
}
if robot.is_connected:
robot.disconnect()
return info
def calculate_reward(observation):
"""
Method to calculate reward function in some way.
In HIL-SERL this is done through defining a reward classifier
"""
# reward = reward_classifier(observation)
return np.array([0.0])
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["rerecord_episode"] = False
events["pause_policy"] = False
events["human_intervention_step"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.space:
# check if first space press then pause the policy for the user to get ready
# if second space press then the user is ready to start intervention
if not events["pause_policy"]:
print(
"Space key pressed. Human intervention required.\n"
"Place the leader in similar pose to the follower and press space again."
)
events["pause_policy"] = True
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
else:
events["human_intervention_step"] = True
print("Space key pressed. Human intervention starting.")
log_say("Starting human intervention.", play_sounds=True)
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
group.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--out-dir",
help=(
"Where to save the evaluation outputs. If not provided, outputs are saved in "
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
),
)
args = parser.parse_args()
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
robot = make_robot(robot_cfg)
if not robot.is_connected:
robot.connect()
eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100)

View File

@@ -0,0 +1,310 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from contextlib import nullcontext
from pathlib import Path
from pprint import pformat
import hydra
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from tqdm import tqdm
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
set_global_seed,
)
def get_model(cfg, logger):
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
model = Classifier(classifier_config)
if cfg.resume:
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
return model
def create_balanced_sampler(dataset, cfg):
# Creates a weighted sampler to handle class imbalance
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
_, counts = torch.unique(labels, return_counts=True)
class_weights = 1.0 / counts.float()
sample_weights = class_weights[labels]
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
# Single epoch training loop with AMP support and progress tracking
model.train()
correct = 0
total = 0
pbar = tqdm(train_loader, desc="Training")
for batch_idx, batch in enumerate(pbar):
start_time = time.perf_counter()
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
# Forward pass with optional AMP
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
outputs = model(images)
loss = criterion(outputs.logits, labels)
# Backward pass with gradient scaling if AMP enabled
optimizer.zero_grad()
if cfg.training.use_amp:
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
else:
loss.backward()
optimizer.step()
# Track metrics
if model.config.num_classes == 2:
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
predictions = torch.argmax(outputs.logits, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
current_acc = 100 * correct / total
train_info = {
"loss": loss.item(),
"accuracy": current_acc,
"dataloading_s": time.perf_counter() - start_time,
}
logger.log_dict(train_info, step + batch_idx, mode="train")
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
# Validation loop with metric tracking and sample logging
model.eval()
correct = 0
total = 0
batch_start_time = time.perf_counter()
samples = []
running_loss = 0
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
for batch in tqdm(val_loader, desc="Validation"):
images = batch[cfg.training.image_key].to(device)
labels = batch[cfg.training.label_key].float().to(device)
outputs = model(images)
loss = criterion(outputs.logits, labels)
# Track metrics
if model.config.num_classes == 2:
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
else:
predictions = torch.argmax(outputs.logits, dim=1)
correct += (predictions == labels).sum().item()
total += labels.size(0)
running_loss += loss.item()
# Log sample predictions for visualization
if len(samples) < num_samples_to_log:
for i in range(min(num_samples_to_log - len(samples), len(images))):
if model.config.num_classes == 2:
confidence = round(outputs.probabilities[i].item(), 3)
else:
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
samples.append(
{
"image": wandb.Image(images[i].cpu()),
"true_label": labels[i].item(),
"predicted": predictions[i].item(),
"confidence": confidence,
}
)
accuracy = 100 * correct / total
avg_loss = running_loss / len(val_loader)
eval_info = {
"loss": avg_loss,
"accuracy": accuracy,
"eval_s": time.perf_counter() - batch_start_time,
"eval/prediction_samples": wandb.Table(
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
columns=["Image", "True Label", "Predicted", "Confidence"],
)
if logger._cfg.wandb.enable
else None,
}
return accuracy, eval_info
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
logging.info(OmegaConf.to_yaml(cfg))
# Initialize training environment
device = get_safe_torch_device(cfg.device, log=True)
set_global_seed(cfg.seed)
out_dir = Path(cfg.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
# Setup dataset and dataloaders
dataset = LeRobotDataset(cfg.dataset_repo_id)
logging.info(f"Dataset size: {len(dataset)}")
train_size = int(cfg.train_split_proportion * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
sampler = create_balanced_sampler(train_dataset, cfg)
train_loader = DataLoader(
train_dataset,
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
sampler=sampler,
pin_memory=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=cfg.eval.batch_size,
shuffle=False,
num_workers=cfg.training.num_workers,
pin_memory=True,
)
# Resume training if requested
step = 0
best_val_acc = 0
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Load and validate checkpoint configuration
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` and parameters.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
# Initialize model and training components
model = get_model(cfg=cfg, logger=logger).to(device)
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
# Log model parameters
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in model.parameters())
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
if cfg.resume:
step = logger.load_last_training_state(optimizer, None)
# Training loop with validation and checkpointing
for epoch in range(cfg.training.num_epochs):
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
# Periodic validation
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
val_acc, eval_info = validate(
model,
val_loader,
criterion,
device,
logger,
cfg,
)
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
# Save best model
if val_acc > best_val_acc:
best_val_acc = val_acc
logger.save_checkpoint(
train_step=step + len(train_loader),
policy=model,
optimizer=optimizer,
scheduler=None,
identifier="best",
)
# Periodic checkpointing
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
logger.save_checkpoint(
train_step=step + len(train_loader),
policy=model,
optimizer=optimizer,
scheduler=None,
identifier=f"{epoch+1:06d}",
)
step += len(train_loader)
logging.info("Training completed")
if __name__ == "__main__":
train()

View File

@@ -53,29 +53,20 @@ python lerobot/scripts/visualize_dataset_html.py \
"""
import argparse
import csv
import json
import logging
import re
import shutil
import tempfile
from io import StringIO
from pathlib import Path
import numpy as np
import pandas as pd
import requests
from flask import Flask, redirect, render_template, request, url_for
import tqdm
from flask import Flask, redirect, render_template, url_for
from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import IterableNamespace
from lerobot.common.utils.utils import init_logging
def run_server(
dataset: LeRobotDataset | IterableNamespace | None,
episodes: list[int] | None,
dataset: LeRobotDataset,
episodes: list[int],
host: str,
port: str,
static_folder: Path,
@@ -85,50 +76,10 @@ def run_server(
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
@app.route("/")
def hommepage(dataset=dataset):
if dataset:
dataset_namespace, dataset_name = dataset.repo_id.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=0,
)
)
dataset_param, episode_param = None, None
all_params = request.args
if "dataset" in all_params:
dataset_param = all_params["dataset"]
if "episode" in all_params:
episode_param = int(all_params["episode"])
if dataset_param:
dataset_namespace, dataset_name = dataset_param.split("/")
return redirect(
url_for(
"show_episode",
dataset_namespace=dataset_namespace,
dataset_name=dataset_name,
episode_id=episode_param if episode_param is not None else 0,
)
)
featured_datasets = [
"lerobot/aloha_static_cups_open",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/taco_play",
]
return render_template(
"visualize_dataset_homepage.html",
featured_datasets=featured_datasets,
lerobot_datasets=available_datasets,
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
def show_first_episode(dataset_namespace, dataset_name):
first_episode_id = 0
def index():
# home page redirects to the first episode page
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
first_episode_id = episodes[0]
return redirect(
url_for(
"show_episode",
@@ -139,85 +90,30 @@ def run_server(
)
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
repo_id = f"{dataset_namespace}/{dataset_name}"
try:
if dataset is None:
dataset = get_dataset_info(repo_id)
except FileNotFoundError:
return (
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
400,
)
dataset_version = (
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
)
match = re.search(r"v(\d+)\.", dataset_version)
if match:
major_version = int(match.group(1))
if major_version < 2:
return "Make sure to convert your LeRobotDataset to v2 & above."
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
def show_episode(dataset_namespace, dataset_name, episode_id):
dataset_info = {
"repo_id": f"{dataset_namespace}/{dataset_name}",
"num_samples": dataset.num_frames
if isinstance(dataset, LeRobotDataset)
else dataset.total_frames,
"num_episodes": dataset.num_episodes
if isinstance(dataset, LeRobotDataset)
else dataset.total_episodes,
"repo_id": dataset.repo_id,
"num_samples": dataset.num_frames,
"num_episodes": dataset.num_episodes,
"fps": dataset.fps,
}
if isinstance(dataset, LeRobotDataset):
video_paths = [
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
for video_path in video_paths
]
tasks = dataset.meta.episodes[episode_id]["tasks"]
else:
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
videos_info = [
{
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+ dataset.video_path.format(
episode_chunk=int(episode_id) // dataset.chunks_size,
video_key=video_key,
episode_index=episode_id,
),
"filename": video_key,
}
for video_key in video_keys
]
response = requests.get(
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
)
response.raise_for_status()
# Split into lines and parse each line as JSON
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
tasks = filtered_tasks_jsonl[0]["tasks"]
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
tasks = dataset.meta.episodes[episode_id]["tasks"]
videos_info = [
{"url": url_for("static", filename=video_path), "filename": video_path.name}
for video_path in video_paths
]
videos_info[0]["language_instruction"] = tasks
if episodes is None:
episodes = list(
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
)
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
return render_template(
"visualize_dataset_template.html",
episode_id=episode_id,
episodes=episodes,
dataset_info=dataset_info,
videos_info=videos_info,
episode_data_csv_str=episode_data_csv_str,
columns=columns,
ep_csv_url=ep_csv_url,
has_policy=False,
)
app.run(host=host, port=port)
@@ -228,69 +124,46 @@ def get_ep_csv_fname(episode_id: int):
return ep_csv_fname
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
This file will be loaded by Dygraph javascript to plot data in real time."""
columns = []
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
selected_columns.remove("timestamp")
has_state = "observation.state" in dataset.features
has_action = "action" in dataset.features
# init header of csv with state and action names
header = ["timestamp"]
if has_state:
dim_state = dataset.meta.shapes["observation.state"][0]
header += [f"state_{i}" for i in range(dim_state)]
if has_action:
dim_action = dataset.meta.shapes["action"][0]
header += [f"action_{i}" for i in range(dim_action)]
for column_name in selected_columns:
dim_state = (
dataset.meta.shapes[column_name][0]
if isinstance(dataset, LeRobotDataset)
else dataset.features[column_name].shape[0]
)
header += [f"{column_name}_{i}" for i in range(dim_state)]
columns = ["timestamp"]
if has_state:
columns += ["observation.state"]
if has_action:
columns += ["action"]
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
column_names = dataset.features[column_name]["names"]
while not isinstance(column_names, list):
column_names = list(column_names.values())[0]
else:
column_names = [f"motor_{i}" for i in range(dim_state)]
columns.append({"key": column_name, "value": column_names})
rows = []
data = dataset.hf_dataset.select_columns(columns)
for i in range(from_idx, to_idx):
row = [data[i]["timestamp"].item()]
if has_state:
row += data[i]["observation.state"].tolist()
if has_action:
row += data[i]["action"].tolist()
rows.append(row)
selected_columns.insert(0, "timestamp")
if isinstance(dataset, LeRobotDataset):
from_idx = dataset.episode_data_index["from"][episode_index]
to_idx = dataset.episode_data_index["to"][episode_index]
data = (
dataset.hf_dataset.select(range(from_idx, to_idx))
.select_columns(selected_columns)
.with_format("pandas")
)
else:
repo_id = dataset.repo_id
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
)
df = pd.read_parquet(url)
data = df[selected_columns] # Select specific columns
rows = np.hstack(
(
np.expand_dims(data["timestamp"], axis=1),
*[np.vstack(data[col]) for col in selected_columns[1:]],
)
).tolist()
# Convert data to CSV string
csv_buffer = StringIO()
csv_writer = csv.writer(csv_buffer)
# Write header
csv_writer.writerow(header)
# Write data rows
csv_writer.writerows(rows)
csv_string = csv_buffer.getvalue()
return csv_string, columns
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir / file_name, "w") as f:
f.write(",".join(header) + "\n")
for row in rows:
row_str = [str(col) for col in row]
f.write(",".join(row_str) + "\n")
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -302,31 +175,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
]
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
# check if the dataset has language instructions
if "language_instruction" not in dataset.features:
return None
# get first frame index
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
# with the tf.tensor appearing in the string
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
def get_dataset_info(repo_id: str) -> IterableNamespace:
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
response.raise_for_status() # Raises an HTTPError for bad responses
dataset_info = response.json()
dataset_info["repo_id"] = repo_id
return IterableNamespace(dataset_info)
def visualize_dataset_html(
dataset: LeRobotDataset | None,
episodes: list[int] | None = None,
dataset: LeRobotDataset,
episodes: list[int] = None,
output_dir: Path | None = None,
serve: bool = True,
host: str = "127.0.0.1",
@@ -335,11 +186,11 @@ def visualize_dataset_html(
) -> Path | None:
init_logging()
template_dir = Path(__file__).resolve().parent.parent / "templates"
if len(dataset.meta.image_keys) > 0:
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
if output_dir is None:
# Create a temporary directory that will be automatically cleaned up
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
output_dir = Path(output_dir)
if output_dir.exists():
@@ -350,29 +201,28 @@ def visualize_dataset_html(
output_dir.mkdir(parents=True, exist_ok=True)
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
static_dir = output_dir / "static"
static_dir.mkdir(parents=True, exist_ok=True)
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
if dataset is None:
if serve:
run_server(
dataset=None,
episodes=None,
host=host,
port=port,
static_folder=static_dir,
template_folder=template_dir,
)
else:
# Create a simlink from the dataset video folder containg mp4 files to the output directory
# so that the http server can get access to the mp4 files.
if isinstance(dataset, LeRobotDataset):
ln_videos_dir = static_dir / "videos"
if not ln_videos_dir.exists():
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
template_dir = Path(__file__).resolve().parent.parent / "templates"
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
if episodes is None:
episodes = list(range(dataset.num_episodes))
logging.info("Writing CSV files")
for episode_index in tqdm.tqdm(episodes):
# write states and actions in a csv (it can be slow for big datasets)
ep_csv_fname = get_ep_csv_fname(episode_index)
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
if serve:
run_server(dataset, episodes, host, port, static_dir, template_dir)
def main():
@@ -381,7 +231,7 @@ def main():
parser.add_argument(
"--repo-id",
type=str,
default=None,
required=True,
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
)
parser.add_argument(
@@ -396,12 +246,6 @@ def main():
default=None,
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
)
parser.add_argument(
"--load-from-hf-hub",
type=int,
default=0,
help="Load videos and parquet files from HF Hub rather than local system.",
)
parser.add_argument(
"--episodes",
type=int,
@@ -443,19 +287,11 @@ def main():
args = parser.parse_args()
kwargs = vars(args)
repo_id = kwargs.pop("repo_id")
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
root = kwargs.pop("root")
local_files_only = kwargs.pop("local_files_only")
dataset = None
if repo_id:
dataset = (
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
if not load_from_hf_hub
else get_dataset_info(repo_id)
)
visualize_dataset_html(dataset, **vars(args))
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
visualize_dataset_html(dataset, **kwargs)
if __name__ == "__main__":

View File

@@ -1,68 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Interactive Video Background Page</title>
<script src="https://cdn.tailwindcss.com"></script>
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
</head>
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
inputValue: '',
navigateToDataset() {
const trimmedValue = this.inputValue.trim();
if (trimmedValue) {
window.location.href = `/${trimmedValue}`;
}
}
}">
<div class="fixed inset-0 w-full h-full overflow-hidden">
<video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
<source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
Your browser does not support HTML5 video.
</video>
</div>
<div class="fixed inset-0 bg-black bg-opacity-80"></div>
<div class="relative z-10 flex flex-col items-center justify-center h-screen">
<div class="text-center mb-8">
<h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>
<a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>
<p class="text-xl mb-4"></p>
<div class="text-left inline-block">
<h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
<ul class="list-disc list-inside">
{% for dataset in featured_datasets %}
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
{% endfor %}
</ul>
</div>
</div>
<div class="flex w-full max-w-lg px-4 mb-4">
<input
type="text"
x-model="inputValue"
@keyup.enter="navigateToDataset"
placeholder="enter dataset id (ex: lerobot/droid_100)"
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
>
<button
@click="navigateToDataset"
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
>
Go
</button>
</div>
<details class="mt-4 max-w-full px-4">
<summary>More example datasets</summary>
<ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
{% for dataset in lerobot_datasets %}
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
{% endfor %}
</ul>
</details>
</div>
</body>
</html>

View File

@@ -31,16 +31,11 @@
}">
<!-- Sidebar -->
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
<img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
</a>
<a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
</a>
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
<ul>
<li>
Number of samples/frames: {{ dataset_info.num_samples }}
Number of samples/frames: {{ dataset_info.num_frames }}
</li>
<li>
Number of episodes: {{ dataset_info.num_episodes }}
@@ -98,35 +93,10 @@
</div>
<!-- Videos -->
<div class="max-w-32 relative text-sm mb-4 select-none"
@click.outside="isVideosDropdownOpen = false">
<div
@click="isVideosDropdownOpen = !isVideosDropdownOpen"
class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
>
<span class="truncate">filter videos</span>
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
</div>
<div x-show="isVideosDropdownOpen"
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
<div>
<template x-for="option in videosKeys" :key="option">
<div
@click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
class="p-2 cursor-pointer bg-slate-900"
:class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
x-text="option"
></div>
</template>
</div>
</div>
</div>
<div class="flex flex-wrap gap-x-2 gap-y-6">
<div class="flex flex-wrap gap-1">
{% for video_info in videos_info %}
<div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<div x-show="!videoCodecError" class="max-w-96">
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
if (video.duration) {
const time = video.currentTime;
@@ -212,12 +182,12 @@
<thead>
<tr>
<th></th>
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
<th class="border border-slate-700">
<div class="flex gap-x-2 justify-between px-2">
<input type="checkbox" :checked="isColumnChecked(colIndex)"
@change="toggleColumn(colIndex)">
<p x-text="`${columns[colIndex].key}`"></p>
<p x-text="`${columnNames[colIndex]}`"></p>
</div>
</th>
</template>
@@ -227,10 +197,10 @@
<template x-for="(row, rowIndex) in rows">
<tr class="odd:bg-gray-800 even:bg-gray-900">
<td class="border border-slate-700">
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
<div class="flex gap-x-2 w-24 font-semibold px-1">
<input type="checkbox" :checked="isRowChecked(rowIndex)"
@change="toggleRow(rowIndex)">
<p x-text="`${rowLabels[rowIndex]}`"></p>
<p x-text="`Motor ${rowIndex}`"></p>
</div>
</td>
<template x-for="(cell, colIndex) in row">
@@ -252,20 +222,16 @@
</div>
</div>
<script>
const parentOrigin = "https://huggingface.co";
const searchParams = new URLSearchParams();
searchParams.set("dataset", "{{ dataset_info.repo_id }}");
searchParams.set("episode", "{{ episode_id }}");
window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
</script>
<script>
function createAlpineData() {
return {
// state
dygraph: null,
currentFrameData: null,
columnNames: ["state", "action", "pred action"],
nColumns: 2,
nStates: 0,
nActions: 0,
checked: [],
dygraphTime: 0.0,
dygraphIndex: 0,
@@ -275,11 +241,6 @@
nVideos: {{ videos_info | length }},
nVideoReadyToPlay: 0,
videoCodecError: false,
isVideosDropdownOpen: false,
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
videosKeysSelected: [],
columns: {{ columns | tojson }},
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
// alpine initialization
init() {
@@ -289,19 +250,11 @@
if(!canPlayVideos){
this.videoCodecError = true;
}
this.videosKeysSelected = this.videosKeys.map(opt => opt)
// process CSV data
const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
// Create a Blob with the CSV data
const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
// Create a URL for the Blob
const csvUrl = URL.createObjectURL(blob);
// process CSV data
this.videos = document.querySelectorAll('video');
this.video = this.videos[0];
this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
pixelsPerPoint: 0.01,
legend: 'always',
labelsDiv: document.getElementById('labels'),
@@ -322,17 +275,21 @@
this.colors = this.dygraph.getColors();
this.checked = Array(this.colors.length).fill(true);
const seriesNames = this.dygraph.getLabels().slice(1);
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
this.nActions = seriesNames.length - this.nStates;
const colors = [];
let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
for(const column of this.columns){
const nValues = column.value.length;
for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
const color = `hsl(${hue}, 100%, ${lightness}%)`;
colors.push(color);
}
lightness += 35;
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
// colors for "state" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
colors.push(color);
}
// colors for "action" lines
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
colors.push(color);
}
this.dygraph.updateOptions({ colors });
this.colors = colors;
@@ -359,19 +316,17 @@
return [];
}
const rows = [];
const nRows = Math.max(...this.columns.map(column => column.value.length));
const nRows = Math.max(this.nStates, this.nActions);
let rowIndex = 0;
while(rowIndex < nRows){
const row = [];
// number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
const nullCell = { isNull: true };
const stateValueIdx = rowIndex;
const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
// row consists of [state value, action value]
let idx = rowIndex;
for(const column of this.columns){
const nColumn = column.value.length;
row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
}
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
rowIndex += 1;
rows.push(row);
}

28
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -1294,10 +1294,6 @@ files = [
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:78656d3ae1282a142a5fed410ec3a6f725fdf8d9f9192ed673e336ea3b083e12"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:681e22c8ecb3b48d11cb9019f8a32d4ae1e353e20d4ce3a0f0eedd0ccbd95e5f"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4598572bab6f726ec41fabb43bf0f7e3cf8082ea0f6f8f4e57845a6c919f31b3"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:157fc1fed50946646f09df75c6d52198735a5973e53d252199bbb1c65e1594d2"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_armv7l.whl", hash = "sha256:7ae2724c181be10692c24fb8d9ce2a99a9afc57237332c3658e2ea6f4f33c091"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_i686.whl", hash = "sha256:3d324835f292edd81b962f8c0df44f7f47c0a6f8fe6f7d081951aeb1f5ba57d2"},
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:474c087b5e584293685a7d4837165b2ead96dc74fb435ae50d5fa0ac168a0de0"},
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:297350f05f5f87a0bf647a1e5b4446728e5f800788c6bb28b462bcd167f1de7f"},
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:b1870a8e30f0ac298d17fd546224348d13a648bcfa0cbc51dba7e5136c1af928"},
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:182a189212d41be0c960fd3299bf6731af2e771f8858cfb1be7ebcc17d60a254"},
@@ -4928,8 +4924,6 @@ files = [
{file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"},
{file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"},
{file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"},
{file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"},
{file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"},
{file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"},
{file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"},
{file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"},
@@ -5896,27 +5890,27 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "rerun-sdk"
version = "0.21.0"
version = "0.18.2"
description = "The Rerun Logging SDK"
optional = false
python-versions = ">=3.8"
python-versions = "<3.13,>=3.8"
files = [
{file = "rerun_sdk-0.21.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1e454ceea31c70ae9ec1bb26eaa82828661b7657ab4d2261ca0b94006d6a1975"},
{file = "rerun_sdk-0.21.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:84ecb77b0b5bac71b53e849801ff073de89fcd2f1e0ca0da62fb18fcbeceadf0"},
{file = "rerun_sdk-0.21.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:919d921165c3238490dbe5bf00a062c68fdd2c54dc14aac6a1914c82edb5d9c8"},
{file = "rerun_sdk-0.21.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:897649aadcab7014b78096f93c84c61c00a227b80adaf0dec279924b5aab53d8"},
{file = "rerun_sdk-0.21.0-cp38-abi3-win_amd64.whl", hash = "sha256:2060bdb536a198f0f04789ba5ba771e66587e7851d668b3dfab257a5efa16819"},
{file = "rerun_sdk-0.18.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bc4e73275f428e4e9feb8e85f88db7a9fd18b997b1570de62f949a926978f1b2"},
{file = "rerun_sdk-0.18.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:efbba40a59710ae83607cb0dc140398a35979c2d2acf5190c9def2ac4697f6a8"},
{file = "rerun_sdk-0.18.2-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:2a5e3b618b6d1bfde09bd5614a898995f3c318cc69d8f6d569924a2cd41536ce"},
{file = "rerun_sdk-0.18.2-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:8fdfc4c51ef2e75cb68d39e56f0d7c196eff250cb9a0260c07d5e2d6736e31b0"},
{file = "rerun_sdk-0.18.2-cp38-abi3-win_amd64.whl", hash = "sha256:c929ade91d3be301b26671b25e70fb529524ced915523d266641c6fc667a1eb5"},
]
[package.dependencies]
attrs = ">=23.1.0"
numpy = ">=1.23"
numpy = ">=1.23,<2"
pillow = ">=8.0.0"
pyarrow = ">=14.0.2"
typing-extensions = ">=4.5"
[package.extras]
notebook = ["rerun-notebook (==0.21.0)"]
notebook = ["rerun-notebook (==0.18.2)"]
tests = ["pytest (==7.1.2)"]
[[package]]
@@ -7575,4 +7569,4 @@ xarm = ["gym-xarm"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "ee60d9251f6a6253d0c371707a72a500a6053d7925c6898e6663d9320ad11503"
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"

View File

@@ -57,7 +57,7 @@ pytest-cov = {version = ">=5.0.0", optional = true}
datasets = ">=2.19.0"
imagecodecs = { version = ">=2024.1.1", optional = true }
pyav = ">=12.0.5"
rerun-sdk = ">=0.21.0"
rerun-sdk = ">=0.15.1"
deepdiff = ">=7.0.1"
flask = ">=3.0.3"
pandas = {version = ">=2.2.2", optional = true}

View File

@@ -0,0 +1,251 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import torch
from hydra import compose, initialize_config_dir
from torch import nn
from torch.utils.data import Dataset
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
from lerobot.scripts.train_hilserl_classifier import (
create_balanced_sampler,
train,
train_epoch,
validate,
)
class MockDataset(Dataset):
def __init__(self, data):
self.data = data
self.meta = MagicMock()
self.meta.stats = {}
def __getitem__(self, idx):
return self.data[idx]
def __len__(self):
return len(self.data)
def make_dummy_model():
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
model = Classifier(config=model_config)
return model
def test_create_balanced_sampler():
# Mock dataset with imbalanced classes
data = [
{"label": 0},
{"label": 0},
{"label": 1},
{"label": 0},
{"label": 1},
{"label": 1},
{"label": 1},
{"label": 1},
]
dataset = MockDataset(data)
cfg = MagicMock()
cfg.training.label_key = "label"
sampler = create_balanced_sampler(dataset, cfg)
# Get weights from the sampler
weights = sampler.weights.float()
# Check that samples have appropriate weights
labels = [item["label"] for item in data]
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
class_weights = 1.0 / class_counts
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
# Test that the weights are correct
assert torch.allclose(weights, expected_weights)
def test_train_epoch():
model = make_dummy_model()
# Mock components
model.train = MagicMock()
train_loader = [
{
"image": torch.rand(2, 3, 224, 224),
"label": torch.tensor([0.0, 1.0]),
}
]
criterion = nn.BCEWithLogitsLoss()
optimizer = MagicMock()
grad_scaler = MagicMock()
device = torch.device("cpu")
logger = MagicMock()
step = 0
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.label_key = "label"
cfg.training.use_amp = False
# Call the function under test
train_epoch(
model,
train_loader,
criterion,
optimizer,
grad_scaler,
device,
logger,
step,
cfg,
)
# Check that model.train() was called
model.train.assert_called_once()
# Check that optimizer.zero_grad() was called
optimizer.zero_grad.assert_called()
# Check that logger.log_dict was called
logger.log_dict.assert_called()
def test_validate():
model = make_dummy_model()
# Mock components
model.eval = MagicMock()
val_loader = [
{
"image": torch.rand(2, 3, 224, 224),
"label": torch.tensor([0.0, 1.0]),
}
]
criterion = nn.BCEWithLogitsLoss()
device = torch.device("cpu")
logger = MagicMock()
cfg = MagicMock()
cfg.training.image_key = "image"
cfg.training.label_key = "label"
cfg.training.use_amp = False
# Call validate
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
# Check that model.eval() was called
model.eval.assert_called_once()
# Check accuracy/eval_info are calculated and of the correct type
assert isinstance(accuracy, float)
assert isinstance(eval_info, dict)
@pytest.mark.parametrize("resume", [True, False])
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
def test_resume_function(
mock_make_policy,
mock_dataset,
mock_logger,
mock_get_last_pretrained_model_dir,
mock_get_last_checkpoint_dir,
mock_init_hydra_config,
resume,
):
# Initialize Hydra
test_file_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
cfg = compose(
config_name="reward_classifier",
overrides=[
"device=cpu",
"seed=42",
f"output_dir={tempfile.mkdtemp()}",
"wandb.enable=False",
f"resume={resume}",
"dataset_repo_id=dataset_repo_id",
"train_split_proportion=0.8",
"training.num_workers=0",
"training.batch_size=2",
"training.image_key=image",
"training.label_key=label",
"training.use_amp=False",
"training.num_epochs=1",
"eval.batch_size=2",
],
)
# Mock the init_hydra_config function to return cfg
mock_init_hydra_config.return_value = cfg
# Mock dataset
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
mock_dataset.return_value = dataset
# Mock checkpoint handling
mock_checkpoint_dir = MagicMock(spec=Path)
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
# Mock logger
logger = MagicMock()
resumed_step = 1000
if resume:
logger.load_last_training_state.return_value = resumed_step
else:
logger.load_last_training_state.return_value = 0
mock_logger.return_value = logger
# Instantiate the model and set make_policy to return it
model = make_dummy_model()
mock_make_policy.return_value = model
# Call train
train(cfg)
# Check that checkpoint handling methods were called
if resume:
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
mock_checkpoint_dir.exists.assert_called_once()
logger.load_last_training_state.assert_called_once()
else:
mock_get_last_checkpoint_dir.assert_not_called()
mock_get_last_pretrained_model_dir.assert_not_called()
mock_checkpoint_dir.exists.assert_not_called()
logger.load_last_training_state.assert_not_called()
# Collect the steps from logger.log_dict calls
train_log_calls = logger.log_dict.call_args_list
# Extract the steps used in the train logging
steps = []
for call in train_log_calls:
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
if mode == "train":
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
steps.append(step)
expected_start_step = resumed_step if resume else 0
# Calculate expected_steps
train_size = int(cfg.train_split_proportion * len(dataset))
batch_size = cfg.training.batch_size
num_batches = (train_size + batch_size - 1) // batch_size
expected_steps = [expected_start_step + i for i in range(num_batches)]
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"

View File

@@ -14,25 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import DatasetCard
from lerobot.common.datasets.utils import create_lerobot_dataset_card
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
def test_default_parameters():
card = create_lerobot_dataset_card()
assert isinstance(card, DatasetCard)
assert card.data.tags == ["LeRobot"]
assert card.data.task_categories == ["robotics"]
assert card.data.configs == [
{
"config_name": "default",
"data_files": "data/*/*.parquet",
}
]
def test_with_tags():
tags = ["tag1", "tag2"]
card = create_lerobot_dataset_card(tags=tags)
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
visualize_dataset_html(
dataset,
episodes=[0],
output_dir=output_dir,
serve=False,
)
assert (output_dir / "static" / "episode_0.csv").exists()