Compare commits
13 Commits
user/rcade
...
temp_branc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
def42ff487 | ||
|
|
c9af8e36a7 | ||
|
|
ed66c92383 | ||
|
|
668d493bf9 | ||
|
|
67f4d7ea7a | ||
|
|
4b0c88ff8e | ||
|
|
b19fef9d18 | ||
|
|
1612e00e63 | ||
|
|
c3bc136420 | ||
|
|
1020bc3108 | ||
|
|
7fcf638c0d | ||
|
|
e35546f58e | ||
|
|
1aa8d4ac91 |
4
.github/workflows/quality.yml
vendored
4
.github/workflows/quality.yml
vendored
@@ -50,7 +50,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Poetry check
|
||||
run: poetry check
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Install poetry
|
||||
run: pipx install "poetry<2.0.0"
|
||||
run: pipx install poetry
|
||||
|
||||
- name: Install poetry-relax
|
||||
run: poetry self add poetry-relax
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
|
||||
@@ -21,7 +21,7 @@ How to decode videos?
|
||||
|
||||
## Variables
|
||||
**Image content & size**
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
|
||||
For these reasons, we run this benchmark on four representative datasets:
|
||||
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
|
||||
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
|
||||
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie
|
||||
|
||||
Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
|
||||
|
||||
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
|
||||
Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
|
||||
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),
|
||||
|
||||
However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
|
||||
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
|
||||
**Average Structural Similarity Index Measure (higher is better)**
|
||||
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
|
||||
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
|
||||
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
|
||||
- `yuv420p` is more widely supported across various platforms, including web browsers.
|
||||
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
|
||||
|
||||
@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
|
||||
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
|
||||
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
|
||||
|
||||
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.
|
||||
See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
|
||||
|
||||
Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
|
||||
- `torchaudio`
|
||||
|
||||
@@ -13,7 +13,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
sed gawk grep curl wget zip unzip \
|
||||
tcpdump sysstat screen tmux \
|
||||
libglib2.0-0 libgl1-mesa-glx libegl1-mesa \
|
||||
speech-dispatcher portaudio19-dev \
|
||||
speech-dispatcher \
|
||||
python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \
|
||||
&& apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
@@ -58,7 +58,7 @@ RUN (type -p wget >/dev/null || (apt update && apt-get install wget -y)) \
|
||||
RUN ln -s /usr/bin/python3 /usr/bin/python
|
||||
|
||||
# Install poetry
|
||||
RUN curl -sSL https://install.python-poetry.org | python - --version 1.8.5
|
||||
RUN curl -sSL https://install.python-poetry.org | python -
|
||||
ENV PATH="/root/.local/bin:$PATH"
|
||||
RUN echo 'if [ "$HOME" != "/root" ]; then ln -sf /root/.local/bin/poetry $HOME/.local/bin/poetry; fi' >> /root/.bashrc
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
@@ -1,31 +1,25 @@
|
||||
# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot
|
||||
This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot.
|
||||
|
||||
|
||||
## A. Source the parts
|
||||
## Source the parts
|
||||
|
||||
Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already.
|
||||
|
||||
**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly.
|
||||
|
||||
## B. Install LeRobot
|
||||
## Install LeRobot
|
||||
|
||||
On your computer:
|
||||
|
||||
1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install):
|
||||
```bash
|
||||
mkdir -p ~/miniconda3
|
||||
# Linux:
|
||||
wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
|
||||
# Mac M-series:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
|
||||
# Mac Intel:
|
||||
# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh
|
||||
bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
|
||||
rm ~/miniconda3/miniconda.sh
|
||||
~/miniconda3/bin/conda init bash
|
||||
```
|
||||
|
||||
2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell
|
||||
2. Restart shell or `source ~/.bashrc`
|
||||
|
||||
3. Create and activate a fresh conda environment for lerobot
|
||||
```bash
|
||||
@@ -42,30 +36,23 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot
|
||||
cd ~/lerobot && pip install -e ".[feetech]"
|
||||
```
|
||||
|
||||
*For Linux only (not Mac)*: install extra dependencies for recording datasets:
|
||||
For Linux only (not Mac), install extra dependencies for recording datasets:
|
||||
```bash
|
||||
conda install -y -c conda-forge ffmpeg
|
||||
pip uninstall -y opencv-python
|
||||
conda install -y -c conda-forge "opencv>=4.10.0"
|
||||
```
|
||||
|
||||
## C. Configure the motors
|
||||
## Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated to each arm
|
||||
Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below.
|
||||
|
||||
Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm.
|
||||
|
||||
#### a. Run the script to find ports
|
||||
|
||||
Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below.
|
||||
|
||||
To find the port for each bus servo adapter, run the utility script:
|
||||
**Find USB ports associated to your arms**
|
||||
To find the correct ports for each arm, run the utility script twice:
|
||||
```bash
|
||||
python lerobot/scripts/find_motors_bus_port.py
|
||||
```
|
||||
|
||||
#### b. Example outputs
|
||||
|
||||
Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
@@ -77,6 +64,7 @@ Remove the usb cable from your DynamixelMotorsBus and press Enter when done.
|
||||
The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux):
|
||||
```
|
||||
Finding all available ports for the MotorBus.
|
||||
@@ -89,20 +77,13 @@ The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081
|
||||
Reconnect the usb cable.
|
||||
```
|
||||
|
||||
#### c. Troubleshooting
|
||||
On Linux, you might need to give access to the USB ports by running:
|
||||
Troubleshooting: On Linux, you might need to give access to the USB ports by running:
|
||||
```bash
|
||||
sudo chmod 666 /dev/ttyACM0
|
||||
sudo chmod 666 /dev/ttyACM1
|
||||
```
|
||||
|
||||
#### d. Update YAML file
|
||||
|
||||
Now that you have the ports, modify the *port* sections in `so100.yaml`
|
||||
|
||||
### 2. Configure the motors
|
||||
|
||||
#### a. Set IDs for all 12 motors
|
||||
**Configure your motors**
|
||||
Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate:
|
||||
```bash
|
||||
python lerobot/scripts/configure_motor.py \
|
||||
@@ -113,7 +94,7 @@ python lerobot/scripts/configure_motor.py \
|
||||
--ID 1
|
||||
```
|
||||
|
||||
*Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).*
|
||||
Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).
|
||||
|
||||
Then unplug your motor and plug the second motor and set its ID to 2.
|
||||
```bash
|
||||
@@ -127,25 +108,23 @@ python lerobot/scripts/configure_motor.py \
|
||||
|
||||
Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
|
||||
|
||||
**Remove the gears of the 6 leader motors**
|
||||
Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### b. Remove the gears of the 6 leader motors
|
||||
|
||||
Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm.
|
||||
|
||||
#### c. Add motor horn to all 12 motors
|
||||
Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
**Add motor horn to the motors**
|
||||
Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30.
|
||||
Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated.
|
||||
|
||||
## D. Assemble the arms
|
||||
## Assemble the arms
|
||||
|
||||
Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm.
|
||||
|
||||
## E. Calibrate
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another.
|
||||
|
||||
#### a. Manual calibration of follower arm
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
**Manual calibration of follower arm**
|
||||
/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now.
|
||||
|
||||
You will need to move the follower arm to these positions sequentially:
|
||||
|
||||
@@ -160,8 +139,8 @@ python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-overrides '~cameras' --arms main_follower
|
||||
```
|
||||
|
||||
#### b. Manual calibration of leader arm
|
||||
Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
**Manual calibration of leader arm**
|
||||
Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially:
|
||||
|
||||
| 1. Zero position | 2. Rotated position | 3. Rest position |
|
||||
|---|---|---|
|
||||
@@ -174,7 +153,7 @@ python lerobot/scripts/control_robot.py calibrate \
|
||||
--robot-overrides '~cameras' --arms main_leader
|
||||
```
|
||||
|
||||
## F. Teleoperate
|
||||
## Teleoperate
|
||||
|
||||
**Simple teleop**
|
||||
Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras):
|
||||
@@ -186,14 +165,14 @@ python lerobot/scripts/control_robot.py teleoperate \
|
||||
```
|
||||
|
||||
|
||||
#### a. Teleop with displaying cameras
|
||||
**Teleop with displaying cameras**
|
||||
Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset.
|
||||
```bash
|
||||
python lerobot/scripts/control_robot.py teleoperate \
|
||||
--robot-path lerobot/configs/robot/so100.yaml
|
||||
```
|
||||
|
||||
## G. Record a dataset
|
||||
## Record a dataset
|
||||
|
||||
Once you're familiar with teleoperation, you can record your first dataset with SO-100.
|
||||
|
||||
@@ -222,7 +201,7 @@ python lerobot/scripts/control_robot.py record \
|
||||
--push-to-hub 1
|
||||
```
|
||||
|
||||
## H. Visualize a dataset
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by:
|
||||
```bash
|
||||
@@ -235,7 +214,7 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
--repo-id ${HF_USER}/so100_test
|
||||
```
|
||||
|
||||
## I. Replay an episode
|
||||
## Replay an episode
|
||||
|
||||
Now try to replay the first episode on your robot:
|
||||
```bash
|
||||
@@ -246,7 +225,7 @@ python lerobot/scripts/control_robot.py replay \
|
||||
--episode 0
|
||||
```
|
||||
|
||||
## J. Train a policy
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
```bash
|
||||
@@ -269,7 +248,7 @@ Let's explain it:
|
||||
|
||||
Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`.
|
||||
|
||||
## K. Evaluate your policy
|
||||
## Evaluate your policy
|
||||
|
||||
You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes:
|
||||
```bash
|
||||
@@ -289,7 +268,7 @@ As you can see, it's almost the same command as previously used to record your t
|
||||
1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`).
|
||||
2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`).
|
||||
|
||||
## L. More Information
|
||||
## More
|
||||
|
||||
Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot.
|
||||
|
||||
|
||||
83
examples/12_train_hilserl_classifier.md
Normal file
83
examples/12_train_hilserl_classifier.md
Normal file
@@ -0,0 +1,83 @@
|
||||
# Training a HIL-SERL Reward Classifier with LeRobot
|
||||
|
||||
This tutorial provides step-by-step instructions for training a reward classifier using LeRobot.
|
||||
|
||||
---
|
||||
|
||||
## Training Script Overview
|
||||
|
||||
LeRobot includes a ready-to-use training script located at [`lerobot/scripts/train_hilserl_classifier.py`](../../lerobot/scripts/train_hilserl_classifier.py). Here's an outline of its workflow:
|
||||
|
||||
1. **Configuration Loading**
|
||||
The script uses Hydra to load a configuration file for subsequent steps. (Details on Hydra follow below.)
|
||||
|
||||
2. **Dataset Initialization**
|
||||
It loads a `LeRobotDataset` containing images and rewards. To optimize performance, a weighted random sampler is used to balance class sampling.
|
||||
|
||||
3. **Classifier Initialization**
|
||||
A lightweight classification head is built on top of a frozen, pretrained image encoder from HuggingFace. The classifier outputs either:
|
||||
- A single probability (binary classification), or
|
||||
- Logits (multi-class classification).
|
||||
|
||||
4. **Training Loop Execution**
|
||||
The script performs:
|
||||
- Forward and backward passes,
|
||||
- Optimization steps,
|
||||
- Periodic logging, evaluation, and checkpoint saving.
|
||||
|
||||
---
|
||||
|
||||
## Configuring with Hydra
|
||||
|
||||
For detailed information about Hydra usage, refer to [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md). However, note that training the reward classifier differs slightly and requires a separate configuration file.
|
||||
|
||||
### Config File Setup
|
||||
|
||||
The default `default.yaml` cannot launch the reward classifier training directly. Instead, you need a configuration file like [`lerobot/configs/policy/hilserl_classifier.yaml`](../../lerobot/configs/policy/hilserl_classifier.yaml), with the following adjustment:
|
||||
|
||||
Replace the `dataset_repo_id` field with the identifier for your dataset, which contains images and sparse rewards:
|
||||
|
||||
```yaml
|
||||
# Example: lerobot/configs/policy/reward_classifier.yaml
|
||||
dataset_repo_id: "my_dataset_repo_id"
|
||||
## Typical logs and metrics
|
||||
```
|
||||
When you start the training process, you will first see your full configuration being printed in the terminal. You can check it to make sure that you config it correctly and your config is not overrided by other files. The final configuration will also be saved with the checkpoint.
|
||||
|
||||
After that, you will see training log like this one:
|
||||
|
||||
```
|
||||
[2024-11-29 18:26:36,999][root][INFO] -
|
||||
Epoch 5/5
|
||||
Training: 82%|██████████████████████████████████████████████████████████████████████████████▋ | 91/111 [00:50<00:09, 2.04it/s, loss=0.2999, acc=69.99%]
|
||||
```
|
||||
|
||||
or evaluation log like:
|
||||
|
||||
```
|
||||
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:20<00:00, 1.37it/s]
|
||||
```
|
||||
|
||||
### Metrics Tracking with Weights & Biases (WandB)
|
||||
|
||||
If `wandb.enable` is set to `true`, the training and evaluation logs will also be saved in WandB. This allows you to track key metrics in real-time, including:
|
||||
|
||||
- **Training Metrics**:
|
||||
- `train/accuracy`
|
||||
- `train/loss`
|
||||
- `train/dataloading_s`
|
||||
- **Evaluation Metrics**:
|
||||
- `eval/accuracy`
|
||||
- `eval/loss`
|
||||
- `eval/eval_s`
|
||||
|
||||
#### Additional Features
|
||||
|
||||
You can also log sample predictions during evaluation. Each logged sample will include:
|
||||
|
||||
- The **input image**.
|
||||
- The **predicted label**.
|
||||
- The **true label**.
|
||||
- The **classifier's "confidence" (logits/probability)**.
|
||||
|
||||
These logs can be useful for diagnosing and debugging performance issues.
|
||||
@@ -1,213 +0,0 @@
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
|
||||
|
||||
def create_empty_dataset(dataset_name, robot_type, mode="video", has_velocity=False, has_effort=False):
|
||||
motors = [
|
||||
# TODO(rcadene): verify
|
||||
"right_waist",
|
||||
"right_shoulder",
|
||||
"right_elbow",
|
||||
"right_forearm_roll",
|
||||
"right_wrist_angle",
|
||||
"right_wrist_rotate",
|
||||
"right_gripper",
|
||||
"left_waist",
|
||||
"left_shoulder",
|
||||
"left_elbow",
|
||||
"left_forearm_roll",
|
||||
"left_wrist_angle",
|
||||
"left_wrist_rotate",
|
||||
"left_gripper",
|
||||
]
|
||||
cameras = [
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
]
|
||||
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
if has_velocity:
|
||||
features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
if has_velocity:
|
||||
features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
for cam in cameras:
|
||||
features[f"observation.images.{cam}"] = {
|
||||
"dtype": mode,
|
||||
"shape": (3, 480, 640),
|
||||
"names": [
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=f"cadene/{dataset_name}_v2",
|
||||
fps=50,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_cameras(hdf5_files):
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
# ignore depth channel, not currently handled
|
||||
# TODO(rcadene): add depth
|
||||
rgb_cameras = [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
return rgb_cameras
|
||||
|
||||
|
||||
def has_velocity(hdf5_files):
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/qvel" in ep
|
||||
|
||||
|
||||
def has_effort(hdf5_files):
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/effort" in ep
|
||||
|
||||
|
||||
def load_raw_images_per_camera(ep, cameras):
|
||||
imgs_per_cam = {}
|
||||
for camera in cameras:
|
||||
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
||||
|
||||
if uncompressed:
|
||||
# load all images in RAM
|
||||
imgs_array = ep[f"/observations/images/{camera}"][:]
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# load one compressed image after the other in RAM and uncompress
|
||||
imgs_array = []
|
||||
for data in ep[f"/observations/images/{camera}"]:
|
||||
imgs_array.append(cv2.imdecode(data, 1))
|
||||
imgs_array = np.array(imgs_array)
|
||||
|
||||
imgs_per_cam[camera] = imgs_array
|
||||
return imgs_per_cam
|
||||
|
||||
|
||||
def load_raw_episode_data(ep_path):
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
|
||||
effort = None
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
imgs_per_cam = load_raw_images_per_camera(ep)
|
||||
|
||||
return imgs_per_cam, state, action, velocity, effort
|
||||
|
||||
|
||||
def populate_dataset(dataset, hdf5_files, task, episodes=None):
|
||||
if episodes is None:
|
||||
episodes = range(len(hdf5_files))
|
||||
|
||||
for ep_idx in tqdm.tqdm(episodes):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
|
||||
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
|
||||
num_frames = state.shape[0]
|
||||
|
||||
for i in range(num_frames):
|
||||
frame = {
|
||||
"observation.state": state[i],
|
||||
"action": action[i],
|
||||
}
|
||||
|
||||
for camera, img_array in imgs_per_cam.items():
|
||||
frame[f"observation.images.{camera}"] = img_array[i]
|
||||
|
||||
if velocity is not None:
|
||||
frame["observation.velocity"] = velocity[i]
|
||||
if effort is not None:
|
||||
frame["observation.effort"] = effort[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=task)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def port_aloha(raw_dir, raw_repo_id, repo_id, episodes: list[int] | None = None, push_to_hub=True):
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
raw_dir = Path(raw_dir)
|
||||
if not raw_dir.exists():
|
||||
download_raw(raw_dir, repo_id=raw_repo_id)
|
||||
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
|
||||
dataset_name = repo_id.split("/")[1]
|
||||
dataset = create_empty_dataset(
|
||||
repo_id,
|
||||
robot_type="mobile_aloha" if "mobile" in dataset_name else "aloha",
|
||||
has_effort=has_effort(hdf5_files),
|
||||
has_velocity=has_velocity(hdf5_files),
|
||||
)
|
||||
dataset = populate_dataset(
|
||||
dataset,
|
||||
hdf5_files,
|
||||
task="DEBUG",
|
||||
episodes=episodes,
|
||||
)
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raw_repo_id = "lerobot-raw/aloha_sim_insertion_human_raw"
|
||||
repo_id = "cadene/aloha_sim_insertion_human_v2"
|
||||
port_aloha(f"data/{raw_repo_id}", raw_repo_id, repo_id, episodes=[0, 1], push_to_hub=False)
|
||||
@@ -291,7 +291,7 @@ class LeRobotDatasetMetadata:
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
if robot is not None:
|
||||
features = get_features_from_robot(robot, use_videos)
|
||||
features = {**(features or {}), **get_features_from_robot(robot)}
|
||||
robot_type = robot.robot_type
|
||||
if not all(cam.fps == fps for cam in robot.cameras.values()):
|
||||
logging.warning(
|
||||
|
||||
@@ -17,11 +17,9 @@ import importlib.resources
|
||||
import json
|
||||
import logging
|
||||
import textwrap
|
||||
from collections.abc import Iterator
|
||||
from itertools import accumulate
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
import datasets
|
||||
@@ -479,6 +477,7 @@ def create_lerobot_dataset_card(
|
||||
Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses.
|
||||
"""
|
||||
card_tags = ["LeRobot"]
|
||||
card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md")
|
||||
|
||||
if tags:
|
||||
card_tags += tags
|
||||
@@ -498,65 +497,8 @@ def create_lerobot_dataset_card(
|
||||
],
|
||||
)
|
||||
|
||||
card_template = (importlib.resources.files("lerobot.common.datasets") / "card_template.md").read_text()
|
||||
|
||||
return DatasetCard.from_template(
|
||||
card_data=card_data,
|
||||
template_str=card_template,
|
||||
template_path=str(card_template_path),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class IterableNamespace(SimpleNamespace):
|
||||
"""
|
||||
A namespace object that supports both dictionary-like iteration and dot notation access.
|
||||
Automatically converts nested dictionaries into IterableNamespaces.
|
||||
|
||||
This class extends SimpleNamespace to provide:
|
||||
- Dictionary-style iteration over keys
|
||||
- Access to items via both dot notation (obj.key) and brackets (obj["key"])
|
||||
- Dictionary-like methods: items(), keys(), values()
|
||||
- Recursive conversion of nested dictionaries
|
||||
|
||||
Args:
|
||||
dictionary: Optional dictionary to initialize the namespace
|
||||
**kwargs: Additional keyword arguments passed to SimpleNamespace
|
||||
|
||||
Examples:
|
||||
>>> data = {"name": "Alice", "details": {"age": 25}}
|
||||
>>> ns = IterableNamespace(data)
|
||||
>>> ns.name
|
||||
'Alice'
|
||||
>>> ns.details.age
|
||||
25
|
||||
>>> list(ns.keys())
|
||||
['name', 'details']
|
||||
>>> for key, value in ns.items():
|
||||
... print(f"{key}: {value}")
|
||||
name: Alice
|
||||
details: IterableNamespace(age=25)
|
||||
"""
|
||||
|
||||
def __init__(self, dictionary: dict[str, Any] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if dictionary is not None:
|
||||
for key, value in dictionary.items():
|
||||
if isinstance(value, dict):
|
||||
setattr(self, key, IterableNamespace(value))
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(vars(self))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return vars(self)[key]
|
||||
|
||||
def items(self):
|
||||
return vars(self).items()
|
||||
|
||||
def values(self):
|
||||
return vars(self).values()
|
||||
|
||||
def keys(self):
|
||||
return vars(self).keys()
|
||||
|
||||
@@ -159,11 +159,11 @@ DATASETS = {
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup": {
|
||||
"single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.",
|
||||
"single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_vinh_cup_left": {
|
||||
"single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.",
|
||||
"single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.",
|
||||
**ALOHA_STATIC_INFO,
|
||||
},
|
||||
"aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO},
|
||||
|
||||
@@ -25,6 +25,7 @@ from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import wandb
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
@@ -107,8 +108,6 @@ class Logger:
|
||||
self._wandb = None
|
||||
else:
|
||||
os.environ["WANDB_SILENT"] = "true"
|
||||
import wandb
|
||||
|
||||
wandb_run_id = None
|
||||
if cfg.resume:
|
||||
wandb_run_id = get_wandb_run_id_from_filesystem(self.checkpoints_dir)
|
||||
@@ -232,7 +231,7 @@ class Logger:
|
||||
# TODO(alexander-soare): Add local text log.
|
||||
if self._wandb is not None:
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
if not isinstance(v, (int, float, str, wandb.Table)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.'
|
||||
)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassifierConfig:
|
||||
"""Configuration for the Classifier model."""
|
||||
|
||||
num_classes: int = 2
|
||||
hidden_dim: int = 256
|
||||
dropout_rate: float = 0.1
|
||||
model_name: str = "microsoft/resnet-50"
|
||||
device: str = "cuda" if torch.cuda.is_available() else "mps"
|
||||
model_type: str = "cnn" # "transformer" or "cnn"
|
||||
|
||||
def save_pretrained(self, save_dir):
|
||||
"""Save config to json file."""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Convert to dict and save as JSON
|
||||
config_dict = asdict(self)
|
||||
with open(os.path.join(save_dir, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path):
|
||||
"""Load config from json file."""
|
||||
config_file = os.path.join(pretrained_model_name_or_path, "config.json")
|
||||
|
||||
with open(config_file) as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
return cls(**config_dict)
|
||||
@@ -0,0 +1,134 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from torch import Tensor, nn
|
||||
from transformers import AutoImageProcessor, AutoModel
|
||||
|
||||
from .configuration_classifier import ClassifierConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClassifierOutput:
|
||||
"""Wrapper for classifier outputs with additional metadata."""
|
||||
|
||||
def __init__(
|
||||
self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None
|
||||
):
|
||||
self.logits = logits
|
||||
self.probabilities = probabilities
|
||||
self.hidden_states = hidden_states
|
||||
|
||||
|
||||
class Classifier(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
# Add Hub metadata
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "vision-classifier"],
|
||||
):
|
||||
"""Image classifier built on top of a pre-trained encoder."""
|
||||
|
||||
# Add name attribute for factory
|
||||
name = "classifier"
|
||||
|
||||
def __init__(self, config: ClassifierConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True)
|
||||
# Extract vision model if we're given a multimodal model
|
||||
if hasattr(encoder, "vision_model"):
|
||||
logging.info("Multimodal model detected - using vision encoder only")
|
||||
self.encoder = encoder.vision_model
|
||||
self.vision_config = encoder.config.vision_config
|
||||
else:
|
||||
self.encoder = encoder
|
||||
self.vision_config = getattr(encoder, "config", None)
|
||||
|
||||
# Model type from config
|
||||
self.is_cnn = self.config.model_type == "cnn"
|
||||
|
||||
# For CNNs, initialize backbone
|
||||
if self.is_cnn:
|
||||
self._setup_cnn_backbone()
|
||||
|
||||
self._freeze_encoder()
|
||||
self._build_classifier_head()
|
||||
|
||||
def _setup_cnn_backbone(self):
|
||||
"""Set up CNN encoder"""
|
||||
if hasattr(self.encoder, "fc"):
|
||||
self.feature_dim = self.encoder.fc.in_features
|
||||
self.encoder = nn.Sequential(*list(self.encoder.children())[:-1])
|
||||
elif hasattr(self.encoder.config, "hidden_sizes"):
|
||||
self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension
|
||||
else:
|
||||
raise ValueError("Unsupported CNN architecture")
|
||||
|
||||
def _freeze_encoder(self) -> None:
|
||||
"""Freeze the encoder parameters."""
|
||||
for param in self.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _build_classifier_head(self) -> None:
|
||||
"""Initialize the classifier head architecture."""
|
||||
# Get input dimension based on model type
|
||||
if self.is_cnn:
|
||||
input_dim = self.feature_dim
|
||||
else: # Transformer models
|
||||
if hasattr(self.encoder.config, "hidden_size"):
|
||||
input_dim = self.encoder.config.hidden_size
|
||||
else:
|
||||
raise ValueError("Unsupported transformer architecture since hidden_size is not found")
|
||||
|
||||
self.classifier_head = nn.Sequential(
|
||||
nn.Linear(input_dim, self.config.hidden_dim),
|
||||
nn.Dropout(self.config.dropout_rate),
|
||||
nn.LayerNorm(self.config.hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes),
|
||||
)
|
||||
|
||||
def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Extract the appropriate output from the encoder."""
|
||||
# Process images with the processor (handles resizing and normalization)
|
||||
processed = self.processor(
|
||||
images=x, # LeRobotDataset already provides proper tensor format
|
||||
return_tensors="pt",
|
||||
)
|
||||
processed = processed["pixel_values"].to(x.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.is_cnn:
|
||||
# The HF ResNet applies pooling internally
|
||||
outputs = self.encoder(processed)
|
||||
# Get pooled output directly
|
||||
features = outputs.pooler_output
|
||||
|
||||
if features.dim() > 2:
|
||||
features = features.squeeze(-1).squeeze(-1)
|
||||
return features
|
||||
else: # Transformer models
|
||||
outputs = self.encoder(processed)
|
||||
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
|
||||
return outputs.pooler_output
|
||||
return outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> ClassifierOutput:
|
||||
"""Forward pass of the classifier."""
|
||||
# For training, we expect input to be a tensor directly from LeRobotDataset
|
||||
encoder_output = self._get_encoder_output(x)
|
||||
logits = self.classifier_head(encoder_output)
|
||||
|
||||
if self.config.num_classes == 2:
|
||||
logits = logits.squeeze(-1)
|
||||
probabilities = torch.sigmoid(logits)
|
||||
else:
|
||||
probabilities = torch.softmax(logits, dim=-1)
|
||||
|
||||
return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_output)
|
||||
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
23
lerobot/common/policies/hilserl/configuration_hilserl.py
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HILSerlConfig:
|
||||
pass
|
||||
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
29
lerobot/common/policies/hilserl/modeling_hilserl.py
Normal file
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class HILSerlPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "hilserl"],
|
||||
):
|
||||
pass
|
||||
39
lerobot/common/policies/sac/configuration_sac.py
Normal file
39
lerobot/common/policies/sac/configuration_sac.py
Normal file
@@ -0,0 +1,39 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SACConfig:
|
||||
discount = 0.99
|
||||
temperature_init = 1.0
|
||||
num_critics = 2
|
||||
critic_lr = 3e-4
|
||||
actor_lr = 3e-4
|
||||
critic_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
actor_network_kwargs = {
|
||||
"hidden_dims": [256, 256],
|
||||
"activate_final": True,
|
||||
}
|
||||
policy_kwargs = {
|
||||
"tanh_squash_distribution": True,
|
||||
"std_parameterization": "uniform",
|
||||
}
|
||||
683
lerobot/common/policies/sac/modeling_sac.py
Normal file
683
lerobot/common/policies/sac/modeling_sac.py
Normal file
@@ -0,0 +1,683 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: (1) better device management
|
||||
|
||||
from collections import deque
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import einops
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.sac.configuration_sac import SACConfig
|
||||
import numpy as np
|
||||
from typing import Callable, Optional, Tuple, Sequence
|
||||
|
||||
|
||||
|
||||
class SACPolicy(
|
||||
nn.Module,
|
||||
PyTorchModelHubMixin,
|
||||
library_name="lerobot",
|
||||
repo_url="https://github.com/huggingface/lerobot",
|
||||
tags=["robotics", "RL", "SAC"],
|
||||
):
|
||||
|
||||
def __init__(
|
||||
self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = SACConfig()
|
||||
self.config = config
|
||||
|
||||
if config.input_normalization_modes is not None:
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_shapes, config.input_normalization_modes, dataset_stats
|
||||
)
|
||||
else:
|
||||
self.normalize_inputs = nn.Identity()
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
self.unnormalize_outputs = Unnormalize(
|
||||
config.output_shapes, config.output_normalization_modes, dataset_stats
|
||||
)
|
||||
encoder = SACObservationEncoder(config)
|
||||
# Define networks
|
||||
critic_nets = []
|
||||
for _ in range(config.num_critics):
|
||||
critic_net = Critic(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.critic_network_kwargs)
|
||||
)
|
||||
critic_nets.append(critic_net)
|
||||
|
||||
self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics)
|
||||
self.critic_target = deepcopy(self.critic_ensemble)
|
||||
|
||||
self.actor_network = Policy(
|
||||
encoder=encoder,
|
||||
network=MLP(**config.actor_network_kwargs),
|
||||
action_dim=config.output_shapes["action"][0],
|
||||
**config.policy_kwargs
|
||||
)
|
||||
|
||||
self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Clear observation and action queues. Should be called on `env.reset()`
|
||||
queues are populated during rollout of the policy, they contain the n latest observations and actions
|
||||
"""
|
||||
|
||||
self._queues = {
|
||||
"observation.state": deque(maxlen=1),
|
||||
"action": deque(maxlen=1),
|
||||
}
|
||||
if self._use_image:
|
||||
self._queues["observation.image"] = deque(maxlen=1)
|
||||
if self._use_env_state:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=1)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
actions, _ = self.actor_network(batch['observations'])###
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
|
||||
"""Run the batch through the model and compute the loss.
|
||||
|
||||
Returns a dictionary with loss as a tensor, and other information as native floats.
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
# batch shape is (b, 2, ...) where index 1 returns the current observation and
|
||||
# the next observation for caluculating the right td index.
|
||||
actions = batch["action"][:, 0]
|
||||
rewards = batch["next.reward"][:, 0]
|
||||
observations = {}
|
||||
next_observations = {}
|
||||
for k in batch:
|
||||
if k.startswith("observation."):
|
||||
observations[k] = batch[k][:, 0]
|
||||
next_observations[k] = batch[k][:, 1]
|
||||
|
||||
# perform image augmentation
|
||||
|
||||
# reward bias
|
||||
# from HIL-SERL code base
|
||||
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
|
||||
|
||||
|
||||
# calculate critics loss
|
||||
# 1- compute actions from policy
|
||||
action_preds, log_probs = self.actor_network(observations)
|
||||
# 2- compute q targets
|
||||
q_targets = self.target_qs(next_observations, action_preds)
|
||||
|
||||
# critics subsample size
|
||||
min_q = q_targets.min(dim=0)
|
||||
|
||||
# backup entropy
|
||||
td_target = rewards + self.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
q_preds = self.critic_ensemble(observations, actions)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
q_preds,
|
||||
einops.repeat(td_target, "t b -> e t b", e=q_preds.shape[0]),
|
||||
reduction="none",
|
||||
).sum(0) # sum over ensemble
|
||||
# `q_preds_ensemble` depends on the first observation and the actions.
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
# q_targets depends on the reward and the next observations.
|
||||
* ~batch["next.reward_is_pad"]
|
||||
* ~batch["observation.state_is_pad"][1:]
|
||||
).sum(0).mean()
|
||||
|
||||
# calculate actors loss
|
||||
# 1- temperature
|
||||
temperature = self.temperature()
|
||||
|
||||
# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
|
||||
actions, log_probs = self.actor_network(observations) \
|
||||
|
||||
# 3- get q-value predictions
|
||||
with torch.no_grad():
|
||||
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
|
||||
actor_loss = (
|
||||
-(q_preds - temperature * log_probs).mean()
|
||||
* ~batch["observation.state_is_pad"][0]
|
||||
* ~batch["action_is_pad"]
|
||||
).mean()
|
||||
|
||||
|
||||
# calculate temperature loss
|
||||
# 1- calculate entropy
|
||||
entropy = -log_probs.mean()
|
||||
temperature_loss = temperature * (entropy - self.target_entropy).mean()
|
||||
|
||||
loss = critics_loss + actor_loss + temperature_loss
|
||||
|
||||
return {
|
||||
"critics_loss": critics_loss.item(),
|
||||
"actor_loss": actor_loss.item(),
|
||||
"temperature_loss": temperature_loss.item(),
|
||||
"temperature": temperature.item(),
|
||||
"entropy": entropy.item(),
|
||||
"loss": loss,
|
||||
|
||||
}
|
||||
|
||||
def update(self):
|
||||
self.critic_target.lerp_(self.critic_ensemble, self.config.critic_target_update_weight)
|
||||
#for target_param, param in zip(self.critic_target.parameters(), self.critic_ensemble.parameters()):
|
||||
# target_param.data.copy_(target_param.data * (1.0 - self.config.critic_target_update_weight) + param.data * self.critic_target_update_weight)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: SACConfig,
|
||||
activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(),
|
||||
activate_final: bool = False,
|
||||
dropout_rate: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.activate_final = config.activate_final
|
||||
layers = []
|
||||
|
||||
for i, size in enumerate(config.network_hidden_dims):
|
||||
layers.append(nn.Linear(config.network_hidden_dims[i-1] if i > 0 else config.network_hidden_dims[0], size))
|
||||
|
||||
if i + 1 < len(config.network_hidden_dims) or activate_final:
|
||||
if dropout_rate is not None and dropout_rate > 0:
|
||||
layers.append(nn.Dropout(p=dropout_rate))
|
||||
layers.append(nn.LayerNorm(size))
|
||||
layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)())
|
||||
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: torch.Tensor, train: bool = False) -> torch.Tensor:
|
||||
# in training mode or not. TODO: find better way to do this
|
||||
self.train(train)
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class Critic(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.init_final = init_final
|
||||
self.activate_final = activate_final
|
||||
|
||||
# Output layer
|
||||
if init_final is not None:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
nn.init.uniform_(self.output_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.output_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.output_layer = nn.Linear(network.net[-3].out_features, 1)
|
||||
else:
|
||||
self.output_layer = nn.Linear(network.net[-2].out_features, 1)
|
||||
orthogonal_init()(self.output_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
self.train(train)
|
||||
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if self.encoder is not None:
|
||||
obs_enc = self.encoder(observations)
|
||||
else:
|
||||
obs_enc = observations
|
||||
|
||||
inputs = torch.cat([obs_enc, actions], dim=-1)
|
||||
x = self.network(inputs)
|
||||
value = self.output_layer(x)
|
||||
return value.squeeze(-1)
|
||||
|
||||
def q_value_ensemble(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
actions: torch.Tensor,
|
||||
train: bool = False
|
||||
) -> torch.Tensor:
|
||||
observations = observations.to(self.device)
|
||||
actions = actions.to(self.device)
|
||||
|
||||
if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
|
||||
batch_size, num_actions = actions.shape[:2]
|
||||
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
|
||||
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
|
||||
actions_flat = actions.reshape(-1, actions.shape[-1])
|
||||
q_values = self(obs_flat, actions_flat, train)
|
||||
return q_values.reshape(batch_size, num_actions)
|
||||
else:
|
||||
return self(observations, actions, train)
|
||||
|
||||
|
||||
class Policy(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder: Optional[nn.Module],
|
||||
network: nn.Module,
|
||||
action_dim: int,
|
||||
std_parameterization: str = "exp",
|
||||
std_min: float = 1e-5,
|
||||
std_max: float = 10.0,
|
||||
tanh_squash_distribution: bool = False,
|
||||
fixed_std: Optional[torch.Tensor] = None,
|
||||
init_final: Optional[float] = None,
|
||||
activate_final: bool = False,
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
self.encoder = encoder
|
||||
self.network = network
|
||||
self.action_dim = action_dim
|
||||
self.std_parameterization = std_parameterization
|
||||
self.std_min = std_min
|
||||
self.std_max = std_max
|
||||
self.tanh_squash_distribution = tanh_squash_distribution
|
||||
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
|
||||
self.activate_final = activate_final
|
||||
|
||||
# Mean layer
|
||||
if self.activate_final:
|
||||
self.mean_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.mean_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.mean_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.mean_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.mean_layer.weight)
|
||||
|
||||
# Standard deviation layer or parameter
|
||||
if fixed_std is None:
|
||||
if std_parameterization == "uniform":
|
||||
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
|
||||
else:
|
||||
if self.activate_final:
|
||||
self.std_layer = nn.Linear(network.net[-3].out_features, action_dim)
|
||||
else:
|
||||
self.std_layer = nn.Linear(network.net[-2].out_features, action_dim)
|
||||
if init_final is not None:
|
||||
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
|
||||
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
|
||||
else:
|
||||
orthogonal_init()(self.std_layer.weight)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
observations: torch.Tensor,
|
||||
temperature: float = 1.0,
|
||||
train: bool = False,
|
||||
non_squash_distribution: bool = False
|
||||
) -> torch.distributions.Distribution:
|
||||
self.train(train)
|
||||
|
||||
# Encode observations if encoder exists
|
||||
if self.encoder is not None:
|
||||
with torch.set_grad_enabled(train):
|
||||
obs_enc = self.encoder(observations, train=train)
|
||||
else:
|
||||
obs_enc = observations
|
||||
# Get network outputs
|
||||
outputs = self.network(obs_enc)
|
||||
means = self.mean_layer(outputs)
|
||||
|
||||
# Compute standard deviations
|
||||
if self.fixed_std is None:
|
||||
if self.std_parameterization == "exp":
|
||||
log_stds = self.std_layer(outputs)
|
||||
stds = torch.exp(log_stds)
|
||||
elif self.std_parameterization == "softplus":
|
||||
stds = torch.nn.functional.softplus(self.std_layer(outputs))
|
||||
elif self.std_parameterization == "uniform":
|
||||
stds = torch.exp(self.log_stds).expand_as(means)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid std_parameterization: {self.std_parameterization}"
|
||||
)
|
||||
else:
|
||||
assert self.std_parameterization == "fixed"
|
||||
stds = self.fixed_std.expand_as(means)
|
||||
|
||||
# Clip standard deviations and scale with temperature
|
||||
temperature = torch.tensor(temperature, device=self.device)
|
||||
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)
|
||||
|
||||
# Create distribution
|
||||
if self.tanh_squash_distribution and not non_squash_distribution:
|
||||
distribution = TanhMultivariateNormalDiag(
|
||||
loc=means,
|
||||
scale_diag=stds,
|
||||
)
|
||||
else:
|
||||
distribution = torch.distributions.Normal(
|
||||
loc=means,
|
||||
scale=stds,
|
||||
)
|
||||
|
||||
return distribution
|
||||
|
||||
def get_features(self, observations: torch.Tensor) -> torch.Tensor:
|
||||
"""Get encoded features from observations"""
|
||||
observations = observations.to(self.device)
|
||||
if self.encoder is not None:
|
||||
with torch.no_grad():
|
||||
return self.encoder(observations, train=False)
|
||||
return observations
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
"""Encode image and/or state vector observations.
|
||||
TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders.
|
||||
"""
|
||||
|
||||
def __init__(self, config: SACConfig):
|
||||
"""
|
||||
Creates encoders for pixel and/or state modalities.
|
||||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if "observation.image" in config.input_shapes:
|
||||
self.image_enc_layers = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2),
|
||||
nn.ReLU(),
|
||||
)
|
||||
dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"])
|
||||
with torch.inference_mode():
|
||||
out_shape = self.image_enc_layers(dummy_batch).shape[1:]
|
||||
self.image_enc_layers.extend(
|
||||
nn.Sequential(
|
||||
nn.Flatten(),
|
||||
nn.Linear(np.prod(out_shape), config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
)
|
||||
if "observation.state" in config.input_shapes:
|
||||
self.state_enc_layers = nn.Sequential(
|
||||
nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
if "observation.environment_state" in config.input_shapes:
|
||||
self.env_state_enc_layers = nn.Sequential(
|
||||
nn.Linear(
|
||||
config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim
|
||||
),
|
||||
nn.ELU(),
|
||||
nn.Linear(config.state_encoder_hidden_dim, config.latent_dim),
|
||||
nn.LayerNorm(config.latent_dim),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, obs_dict: dict[str, Tensor]) -> Tensor:
|
||||
"""Encode the image and/or state vector.
|
||||
|
||||
Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken
|
||||
over all features.
|
||||
"""
|
||||
feat = []
|
||||
# Concatenate all images along the channel dimension.
|
||||
image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")]
|
||||
for image_key in image_keys:
|
||||
feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key]))
|
||||
if "observation.environment_state" in self.config.input_shapes:
|
||||
feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"]))
|
||||
if "observation.state" in self.config.input_shapes:
|
||||
feat.append(self.state_enc_layers(obs_dict["observation.state"]))
|
||||
return torch.stack(feat, dim=0).mean(0)
|
||||
|
||||
|
||||
class LagrangeMultiplier(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
init_value: float = 1.0,
|
||||
constraint_shape: Sequence[int] = (),
|
||||
device: str = "cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.device = torch.device(device)
|
||||
init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1)
|
||||
|
||||
# Initialize the Lagrange multiplier as a parameter
|
||||
self.lagrange = nn.Parameter(
|
||||
torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device)
|
||||
)
|
||||
|
||||
self.to(self.device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
lhs: Optional[torch.Tensor] = None,
|
||||
rhs: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
# Get the multiplier value based on parameterization
|
||||
multiplier = torch.nn.functional.softplus(self.lagrange)
|
||||
|
||||
# Return the raw multiplier if no constraint values provided
|
||||
if lhs is None:
|
||||
return multiplier
|
||||
|
||||
# Move inputs to device
|
||||
lhs = lhs.to(self.device)
|
||||
if rhs is not None:
|
||||
rhs = rhs.to(self.device)
|
||||
|
||||
# Use the multiplier to compute the Lagrange penalty
|
||||
if rhs is None:
|
||||
rhs = torch.zeros_like(lhs, device=self.device)
|
||||
|
||||
diff = lhs - rhs
|
||||
|
||||
assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}"
|
||||
|
||||
return multiplier * diff
|
||||
|
||||
|
||||
# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
|
||||
# 1. The base distribution is a diagonal multivariate normal distribution
|
||||
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
|
||||
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
|
||||
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
|
||||
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
|
||||
def __init__(
|
||||
self,
|
||||
loc: torch.Tensor,
|
||||
scale_diag: torch.Tensor,
|
||||
low: Optional[torch.Tensor] = None,
|
||||
high: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Create base normal distribution
|
||||
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)
|
||||
|
||||
# Create list of transforms
|
||||
transforms = []
|
||||
|
||||
# Add tanh transform
|
||||
transforms.append(torch.distributions.transforms.TanhTransform())
|
||||
|
||||
# Add rescaling transform if bounds are provided
|
||||
if low is not None and high is not None:
|
||||
transforms.append(
|
||||
torch.distributions.transforms.AffineTransform(
|
||||
loc=(high + low) / 2,
|
||||
scale=(high - low) / 2
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize parent class
|
||||
super().__init__(
|
||||
base_distribution=base_distribution,
|
||||
transforms=transforms
|
||||
)
|
||||
|
||||
# Store parameters
|
||||
self.loc = loc
|
||||
self.scale_diag = scale_diag
|
||||
self.low = low
|
||||
self.high = high
|
||||
|
||||
def mode(self) -> torch.Tensor:
|
||||
"""Get the mode of the transformed distribution"""
|
||||
# The mode of a normal distribution is its mean
|
||||
mode = self.loc
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
mode = transform(mode)
|
||||
|
||||
return mode
|
||||
|
||||
def rsample(self, sample_shape=torch.Size()) -> torch.Tensor:
|
||||
"""
|
||||
Reparameterized sample from the distribution
|
||||
"""
|
||||
# Sample from base distribution
|
||||
x = self.base_dist.rsample(sample_shape)
|
||||
|
||||
# Apply transforms
|
||||
for transform in self.transforms:
|
||||
x = transform(x)
|
||||
|
||||
return x
|
||||
|
||||
def log_prob(self, value: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute log probability of a value
|
||||
Includes the log det jacobian for the transforms
|
||||
"""
|
||||
# Initialize log prob
|
||||
log_prob = torch.zeros_like(value[..., 0])
|
||||
|
||||
# Inverse transforms to get back to normal distribution
|
||||
q = value
|
||||
for transform in reversed(self.transforms):
|
||||
q = transform.inv(q)
|
||||
log_prob = log_prob - transform.log_abs_det_jacobian(q, transform(q))
|
||||
|
||||
# Add base distribution log prob
|
||||
log_prob = log_prob + self.base_dist.log_prob(q).sum(-1)
|
||||
|
||||
return log_prob
|
||||
|
||||
def sample_and_log_prob(self, sample_shape=torch.Size()) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Sample from the distribution and compute log probability
|
||||
"""
|
||||
x = self.rsample(sample_shape)
|
||||
log_prob = self.log_prob(x)
|
||||
return x, log_prob
|
||||
|
||||
def entropy(self) -> torch.Tensor:
|
||||
"""
|
||||
Compute entropy of the distribution
|
||||
"""
|
||||
# Start with base distribution entropy
|
||||
entropy = self.base_dist.entropy().sum(-1)
|
||||
|
||||
# Add log det jacobian for each transform
|
||||
x = self.rsample()
|
||||
for transform in self.transforms:
|
||||
entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
|
||||
x = transform(x)
|
||||
|
||||
return entropy
|
||||
|
||||
|
||||
def create_critic_ensemble(critic_class, num_critics: int, device: str = "cuda") -> nn.ModuleList:
|
||||
"""Creates an ensemble of critic networks"""
|
||||
critics = nn.ModuleList([critic_class() for _ in range(num_critics)])
|
||||
return critics.to(device)
|
||||
|
||||
|
||||
def orthogonal_init():
|
||||
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)
|
||||
|
||||
|
||||
# borrowed from tdmpc
|
||||
def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
|
||||
"""Helper to temporarily flatten extra dims at the start of the image tensor.
|
||||
|
||||
Args:
|
||||
fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
|
||||
(B, *), where * is any number of dimensions.
|
||||
image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
|
||||
can be more than 1 dimensions, generally different from *.
|
||||
Returns:
|
||||
A return value from the callable reshaped to (**, *).
|
||||
"""
|
||||
if image_tensor.ndim == 4:
|
||||
return fn(image_tensor)
|
||||
start_dims = image_tensor.shape[:-3]
|
||||
inp = torch.flatten(image_tensor, end_dim=-4)
|
||||
flat_out = fn(inp)
|
||||
return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
|
||||
|
||||
@@ -120,14 +120,22 @@ def predict_action(observation, policy, device, use_amp):
|
||||
return action
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
def init_keyboard_listener(assign_rewards=False):
|
||||
"""
|
||||
Initializes a keyboard listener to enable early termination of an episode
|
||||
or environment reset by pressing the right arrow key ('->'). This may require
|
||||
sudo permissions to allow the terminal to monitor keyboard events.
|
||||
|
||||
Args:
|
||||
assign_rewards (bool): If True, allows annotating the collected trajectory
|
||||
with a binary reward at the end of the episode to indicate success.
|
||||
"""
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["stop_recording"] = False
|
||||
if assign_rewards:
|
||||
events["next.reward"] = 0
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
@@ -152,6 +160,13 @@ def init_keyboard_listener():
|
||||
print("Escape key pressed. Stopping data recording...")
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
elif assign_rewards and key == keyboard.Key.space:
|
||||
events["next.reward"] = 1 if events["next.reward"] == 0 else 0
|
||||
print(
|
||||
"Space key pressed. Assigning new reward to the subsequent frames. New reward:",
|
||||
events["next.reward"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
@@ -272,6 +287,8 @@ def control_loop(
|
||||
|
||||
if dataset is not None:
|
||||
frame = {**observation, **action}
|
||||
if "next.reward" in events:
|
||||
frame["next.reward"] = events["next.reward"]
|
||||
dataset.add_frame(frame)
|
||||
|
||||
if display_cameras and not is_headless():
|
||||
@@ -301,6 +318,8 @@ def reset_environment(robot, events, reset_time_s):
|
||||
|
||||
timestamp = 0
|
||||
start_vencod_t = time.perf_counter()
|
||||
if "next.reward" in events:
|
||||
events["next.reward"] = 0
|
||||
|
||||
# Wait if necessary
|
||||
with tqdm.tqdm(total=reset_time_s, desc="Waiting") as pbar:
|
||||
|
||||
48
lerobot/configs/policy/hilserl_classifier.yaml
Normal file
48
lerobot/configs/policy/hilserl_classifier.yaml
Normal file
@@ -0,0 +1,48 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- _self_
|
||||
|
||||
seed: 13
|
||||
dataset_repo_id: "dataset_repo_id"
|
||||
train_split_proportion: 0.8
|
||||
|
||||
# Required by logger
|
||||
env:
|
||||
name: "classifier"
|
||||
task: "binary_classification"
|
||||
|
||||
|
||||
training:
|
||||
num_epochs: 5
|
||||
batch_size: 16
|
||||
learning_rate: 1e-4
|
||||
num_workers: 4
|
||||
grad_clip_norm: 10
|
||||
use_amp: true
|
||||
log_freq: 1
|
||||
eval_freq: 1 # How often to run validation (in epochs)
|
||||
save_freq: 1 # How often to save checkpoints (in epochs)
|
||||
save_checkpoint: true
|
||||
image_key: "observation.images.phone"
|
||||
label_key: "next.reward"
|
||||
|
||||
eval:
|
||||
batch_size: 16
|
||||
num_samples_to_log: 30 # Number of validation samples to log in the table
|
||||
|
||||
policy:
|
||||
name: "hilserl/classifier"
|
||||
model_name: "facebook/convnext-base-224"
|
||||
model_type: "cnn"
|
||||
|
||||
wandb:
|
||||
enable: false
|
||||
project: "classifier-training"
|
||||
entity: "wandb_entity"
|
||||
job_name: "classifier_training_0"
|
||||
disable_artifact: false
|
||||
|
||||
device: "mps"
|
||||
resume: false
|
||||
output_dir: "output"
|
||||
@@ -10,7 +10,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0031751
|
||||
port: /dev/tty.usbmodem58760430441
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl330-m077"]
|
||||
@@ -23,7 +23,7 @@ leader_arms:
|
||||
follower_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.dynamixel.DynamixelMotorsBus
|
||||
port: /dev/tty.usbmodem575E0032081
|
||||
port: /dev/tty.usbmodem585A0083391
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "xl430-w250"]
|
||||
|
||||
@@ -18,7 +18,7 @@ max_relative_target: null
|
||||
leader_arms:
|
||||
main:
|
||||
_target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus
|
||||
port: /dev/tty.usbmodem585A0077581
|
||||
port: /dev/tty.usbmodem58760433331
|
||||
motors:
|
||||
# name: (index, model)
|
||||
shoulder_pan: [1, "sts3215"]
|
||||
|
||||
@@ -191,6 +191,7 @@ def record(
|
||||
single_task: str,
|
||||
pretrained_policy_name_or_path: str | None = None,
|
||||
policy_overrides: List[str] | None = None,
|
||||
assign_rewards: bool = False,
|
||||
fps: int | None = None,
|
||||
warmup_time_s: int | float = 2,
|
||||
episode_time_s: int | float = 10,
|
||||
@@ -214,6 +215,9 @@ def record(
|
||||
policy = None
|
||||
device = None
|
||||
use_amp = None
|
||||
extra_features = (
|
||||
{"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None
|
||||
)
|
||||
|
||||
if single_task:
|
||||
task = single_task
|
||||
@@ -254,12 +258,12 @@ def record(
|
||||
use_videos=video,
|
||||
image_writer_processes=num_image_writer_processes,
|
||||
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
features=extra_features,
|
||||
)
|
||||
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
listener, events = init_keyboard_listener(assign_rewards=assign_rewards)
|
||||
|
||||
# Execute a few seconds without recording to:
|
||||
# 1. teleoperate the robot to move it in starting position if no policy provided,
|
||||
@@ -300,7 +304,7 @@ def record(
|
||||
# TODO(rcadene): add an option to enable teleoperation during reset
|
||||
# Skip reset for the last episode to be recorded
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
(dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment", play_sounds)
|
||||
reset_environment(robot, events, reset_time_s)
|
||||
@@ -469,12 +473,12 @@ if __name__ == "__main__":
|
||||
default=1,
|
||||
help="Upload dataset to Hugging Face hub.",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--tags",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Add tags to your dataset on the hub.",
|
||||
)
|
||||
# parser_record.add_argument(
|
||||
# "--tags",
|
||||
# type=str,
|
||||
# nargs="*",
|
||||
# help="Add tags to your dataset on the hub.",
|
||||
# )
|
||||
parser_record.add_argument(
|
||||
"--num-image-writer-processes",
|
||||
type=int,
|
||||
@@ -517,6 +521,12 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
parser_record.add_argument(
|
||||
"--assign-rewards",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Enables the assignation of rewards to frames (by default no assignation). When enabled, assign a 0 reward to frames until the space bar is pressed which assign a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.",
|
||||
)
|
||||
|
||||
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
|
||||
parser_replay.add_argument(
|
||||
|
||||
335
lerobot/scripts/eval_on_robot.py
Normal file
335
lerobot/scripts/eval_on_robot.py
Normal file
@@ -0,0 +1,335 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Evaluate a policy by running rollouts on the real robot and computing metrics.
|
||||
|
||||
Usage examples: evaluate a checkpoint from the LeRobot training script for 10 episodes.
|
||||
|
||||
```
|
||||
python lerobot/scripts/eval_on_robot.py \
|
||||
-p outputs/train/model/checkpoints/005000/pretrained_model \
|
||||
eval.n_episodes=10
|
||||
```
|
||||
|
||||
**NOTE** (michel-aractingi): This script is incomplete and it is being prepared
|
||||
for running training on the real robot.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import trange
|
||||
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
|
||||
from lerobot.common.robot_devices.robots.factory import Robot, make_robot
|
||||
from lerobot.common.utils.utils import (
|
||||
init_hydra_config,
|
||||
init_logging,
|
||||
log_say,
|
||||
)
|
||||
|
||||
|
||||
def rollout(robot: Robot, policy: Policy, fps: int, control_time_s: float = 20, use_amp: bool = True) -> dict:
|
||||
"""Run a batched policy rollout on the real robot.
|
||||
|
||||
The return dictionary contains:
|
||||
"robot": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation
|
||||
keys. NOTE the that this has an extra sequence element relative to the other keys in the
|
||||
dictionary. This is because an extra observation is included for after the environment is
|
||||
terminated or truncated.
|
||||
"action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not
|
||||
including the last observations).
|
||||
"reward": A (batch, sequence) tensor of rewards received for applying the actions.
|
||||
"success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon
|
||||
environment termination/truncation).
|
||||
"done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element,
|
||||
the first True is followed by True's all the way till the end. This can be used for masking
|
||||
extraneous elements from the sequences above.
|
||||
|
||||
Args:
|
||||
robot: The robot class that defines the interface with the real robot.
|
||||
policy: The policy. Must be a PyTorch nn module.
|
||||
|
||||
Returns:
|
||||
The dictionary described above.
|
||||
"""
|
||||
# assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module."
|
||||
# device = get_device_from_parameters(policy)
|
||||
|
||||
# define keyboard listener
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
# Reset the policy. TODO (michel-aractingi) add real policy evaluation once the code is ready.
|
||||
# policy.reset()
|
||||
|
||||
# Get observation from real robot
|
||||
observation = robot.capture_observation()
|
||||
|
||||
# Calculate reward. TODO (michel-aractingi)
|
||||
# in HIL-SERL it will be with a reward classifier
|
||||
reward = calculate_reward(observation)
|
||||
all_observations = []
|
||||
all_actions = []
|
||||
all_rewards = []
|
||||
all_successes = []
|
||||
|
||||
start_episode_t = time.perf_counter()
|
||||
timestamp = 0.0
|
||||
while timestamp < control_time_s:
|
||||
start_loop_t = time.perf_counter()
|
||||
|
||||
all_observations.append(deepcopy(observation))
|
||||
# observation = {key: observation[key].to(device, non_blocking=True) for key in observation}
|
||||
|
||||
# Apply the next action.
|
||||
while events["pause_policy"] and not events["human_intervention_step"]:
|
||||
busy_wait(0.5)
|
||||
|
||||
if events["human_intervention_step"]:
|
||||
# take over the robot's actions
|
||||
observation, action = robot.teleop_step(record_data=True)
|
||||
action = action["action"] # teleop step returns torch tensors but in a dict
|
||||
else:
|
||||
# explore with policy
|
||||
with torch.inference_mode():
|
||||
action = robot.follower_arms["main"].read("Present_Position")
|
||||
action = torch.from_numpy(action)
|
||||
robot.send_action(action)
|
||||
# action = predict_action(observation, policy, device, use_amp)
|
||||
|
||||
observation = robot.capture_observation()
|
||||
# Calculate reward
|
||||
# in HIL-SERL it will be with a reward classifier
|
||||
reward = calculate_reward(observation)
|
||||
|
||||
all_actions.append(action)
|
||||
all_rewards.append(torch.from_numpy(reward))
|
||||
all_successes.append(torch.tensor([False]))
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
if events["exit_early"]:
|
||||
events["exit_early"] = False
|
||||
events["human_intervention_step"] = False
|
||||
events["pause_policy"] = False
|
||||
break
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
dones = torch.tensor([False] * len(all_actions))
|
||||
dones[-1] = True
|
||||
# Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors.
|
||||
ret = {
|
||||
"action": torch.stack(all_actions, dim=1),
|
||||
"next.reward": torch.stack(all_rewards, dim=1),
|
||||
"next.success": torch.stack(all_successes, dim=1),
|
||||
"done": dones,
|
||||
}
|
||||
stacked_observations = {}
|
||||
for key in all_observations[0]:
|
||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||
ret["observation"] = stacked_observations
|
||||
|
||||
listener.stop()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def eval_policy(
|
||||
robot: Robot,
|
||||
policy: torch.nn.Module,
|
||||
fps: float,
|
||||
n_episodes: int,
|
||||
control_time_s: int = 20,
|
||||
use_amp: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Args:
|
||||
env: The batch of environments.
|
||||
policy: The policy.
|
||||
n_episodes: The number of episodes to evaluate.
|
||||
Returns:
|
||||
Dictionary with metrics and data regarding the rollouts.
|
||||
"""
|
||||
# TODO (michel-aractingi) comment this out for testing with a fixed policy
|
||||
# assert isinstance(policy, Policy)
|
||||
# policy.eval()
|
||||
|
||||
sum_rewards = []
|
||||
max_rewards = []
|
||||
successes = []
|
||||
rollouts = []
|
||||
|
||||
start_eval = time.perf_counter()
|
||||
progbar = trange(n_episodes, desc="Evaluating policy on real robot")
|
||||
for _batch_idx in progbar:
|
||||
rollout_data = rollout(robot, policy, fps, control_time_s, use_amp)
|
||||
|
||||
rollouts.append(rollout_data)
|
||||
sum_rewards.append(sum(rollout_data["next.reward"]))
|
||||
max_rewards.append(max(rollout_data["next.reward"]))
|
||||
successes.append(rollout_data["next.success"][-1])
|
||||
|
||||
info = {
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
"sum_reward": sum_reward,
|
||||
"max_reward": max_reward,
|
||||
"pc_success": success * 100,
|
||||
}
|
||||
for i, (sum_reward, max_reward, success) in enumerate(
|
||||
zip(
|
||||
sum_rewards[:n_episodes],
|
||||
max_rewards[:n_episodes],
|
||||
successes[:n_episodes],
|
||||
strict=False,
|
||||
)
|
||||
)
|
||||
],
|
||||
"aggregated": {
|
||||
"avg_sum_reward": float(np.nanmean(torch.cat(sum_rewards[:n_episodes]))),
|
||||
"avg_max_reward": float(np.nanmean(torch.cat(max_rewards[:n_episodes]))),
|
||||
"pc_success": float(np.nanmean(torch.cat(successes[:n_episodes])) * 100),
|
||||
"eval_s": time.time() - start_eval,
|
||||
"eval_ep_s": (time.time() - start_eval) / n_episodes,
|
||||
},
|
||||
}
|
||||
|
||||
if robot.is_connected:
|
||||
robot.disconnect()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def calculate_reward(observation):
|
||||
"""
|
||||
Method to calculate reward function in some way.
|
||||
In HIL-SERL this is done through defining a reward classifier
|
||||
"""
|
||||
# reward = reward_classifier(observation)
|
||||
return np.array([0.0])
|
||||
|
||||
|
||||
def init_keyboard_listener():
|
||||
# Allow to exit early while recording an episode or resetting the environment,
|
||||
# by tapping the right arrow key '->'. This might require a sudo permission
|
||||
# to allow your terminal to monitor keyboard events.
|
||||
events = {}
|
||||
events["exit_early"] = False
|
||||
events["rerecord_episode"] = False
|
||||
events["pause_policy"] = False
|
||||
events["human_intervention_step"] = False
|
||||
|
||||
if is_headless():
|
||||
logging.warning(
|
||||
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
|
||||
)
|
||||
listener = None
|
||||
return listener, events
|
||||
|
||||
# Only import pynput if not in a headless environment
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if key == keyboard.Key.right:
|
||||
print("Right arrow key pressed. Exiting loop...")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.space:
|
||||
# check if first space press then pause the policy for the user to get ready
|
||||
# if second space press then the user is ready to start intervention
|
||||
if not events["pause_policy"]:
|
||||
print(
|
||||
"Space key pressed. Human intervention required.\n"
|
||||
"Place the leader in similar pose to the follower and press space again."
|
||||
)
|
||||
events["pause_policy"] = True
|
||||
log_say("Human intervention stage. Get ready to take over.", play_sounds=True)
|
||||
else:
|
||||
events["human_intervention_step"] = True
|
||||
print("Space key pressed. Human intervention starting.")
|
||||
log_say("Starting human intervention.", play_sounds=True)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling key press: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
|
||||
return listener, events
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
)
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--robot-path",
|
||||
type=str,
|
||||
default="lerobot/configs/robot/koch.yaml",
|
||||
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--robot-overrides",
|
||||
type=str,
|
||||
nargs="*",
|
||||
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
|
||||
)
|
||||
group.add_argument(
|
||||
"-p",
|
||||
"--pretrained-policy-name-or-path",
|
||||
help=(
|
||||
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
|
||||
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
|
||||
"(useful for debugging). This argument is mutually exclusive with `--config`."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
help=(
|
||||
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
|
||||
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
|
||||
parser.add_argument(
|
||||
"--out-dir",
|
||||
help=(
|
||||
"Where to save the evaluation outputs. If not provided, outputs are saved in "
|
||||
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
|
||||
robot = make_robot(robot_cfg)
|
||||
if not robot.is_connected:
|
||||
robot.connect()
|
||||
|
||||
eval_policy(robot, None, fps=40, n_episodes=2, control_time_s=100)
|
||||
310
lerobot/scripts/train_hilserl_classifier.py
Normal file
310
lerobot/scripts/train_hilserl_classifier.py
Normal file
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
import hydra
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import wandb
|
||||
from deepdiff import DeepDiff
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from termcolor import colored
|
||||
from torch import optim
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.common.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
init_hydra_config,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
|
||||
def get_model(cfg, logger):
|
||||
classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
|
||||
model = Classifier(classifier_config)
|
||||
if cfg.resume:
|
||||
model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict())
|
||||
return model
|
||||
|
||||
|
||||
def create_balanced_sampler(dataset, cfg):
|
||||
# Creates a weighted sampler to handle class imbalance
|
||||
|
||||
labels = torch.tensor([item[cfg.training.label_key] for item in dataset])
|
||||
_, counts = torch.unique(labels, return_counts=True)
|
||||
class_weights = 1.0 / counts.float()
|
||||
sample_weights = class_weights[labels]
|
||||
|
||||
return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
|
||||
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg):
|
||||
# Single epoch training loop with AMP support and progress tracking
|
||||
model.train()
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
pbar = tqdm(train_loader, desc="Training")
|
||||
for batch_idx, batch in enumerate(pbar):
|
||||
start_time = time.perf_counter()
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
# Forward pass with optional AMP
|
||||
with torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Backward pass with gradient scaling if AMP enabled
|
||||
optimizer.zero_grad()
|
||||
if cfg.training.use_amp:
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
|
||||
current_acc = 100 * correct / total
|
||||
train_info = {
|
||||
"loss": loss.item(),
|
||||
"accuracy": current_acc,
|
||||
"dataloading_s": time.perf_counter() - start_time,
|
||||
}
|
||||
|
||||
logger.log_dict(train_info, step + batch_idx, mode="train")
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"})
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8):
|
||||
# Validation loop with metric tracking and sample logging
|
||||
model.eval()
|
||||
correct = 0
|
||||
total = 0
|
||||
batch_start_time = time.perf_counter()
|
||||
samples = []
|
||||
running_loss = 0
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.training.use_amp else nullcontext():
|
||||
for batch in tqdm(val_loader, desc="Validation"):
|
||||
images = batch[cfg.training.image_key].to(device)
|
||||
labels = batch[cfg.training.label_key].float().to(device)
|
||||
|
||||
outputs = model(images)
|
||||
loss = criterion(outputs.logits, labels)
|
||||
|
||||
# Track metrics
|
||||
if model.config.num_classes == 2:
|
||||
predictions = (torch.sigmoid(outputs.logits) > 0.5).float()
|
||||
else:
|
||||
predictions = torch.argmax(outputs.logits, dim=1)
|
||||
correct += (predictions == labels).sum().item()
|
||||
total += labels.size(0)
|
||||
running_loss += loss.item()
|
||||
|
||||
# Log sample predictions for visualization
|
||||
if len(samples) < num_samples_to_log:
|
||||
for i in range(min(num_samples_to_log - len(samples), len(images))):
|
||||
if model.config.num_classes == 2:
|
||||
confidence = round(outputs.probabilities[i].item(), 3)
|
||||
else:
|
||||
confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()]
|
||||
samples.append(
|
||||
{
|
||||
"image": wandb.Image(images[i].cpu()),
|
||||
"true_label": labels[i].item(),
|
||||
"predicted": predictions[i].item(),
|
||||
"confidence": confidence,
|
||||
}
|
||||
)
|
||||
|
||||
accuracy = 100 * correct / total
|
||||
avg_loss = running_loss / len(val_loader)
|
||||
|
||||
eval_info = {
|
||||
"loss": avg_loss,
|
||||
"accuracy": accuracy,
|
||||
"eval_s": time.perf_counter() - batch_start_time,
|
||||
"eval/prediction_samples": wandb.Table(
|
||||
data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples],
|
||||
columns=["Image", "True Label", "Predicted", "Confidence"],
|
||||
)
|
||||
if logger._cfg.wandb.enable
|
||||
else None,
|
||||
}
|
||||
|
||||
return accuracy, eval_info
|
||||
|
||||
|
||||
@hydra.main(version_base="1.2", config_path="../configs", config_name="hilserl_classifier")
|
||||
def train(cfg: DictConfig) -> None:
|
||||
# Main training pipeline with support for resuming training
|
||||
logging.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# Initialize training environment
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
set_global_seed(cfg.seed)
|
||||
|
||||
out_dir = Path(cfg.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
|
||||
|
||||
# Setup dataset and dataloaders
|
||||
dataset = LeRobotDataset(cfg.dataset_repo_id)
|
||||
logging.info(f"Dataset size: {len(dataset)}")
|
||||
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
val_size = len(dataset) - train_size
|
||||
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
|
||||
|
||||
sampler = create_balanced_sampler(train_dataset, cfg)
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=cfg.training.batch_size,
|
||||
num_workers=cfg.training.num_workers,
|
||||
sampler=sampler,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=cfg.eval.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=cfg.training.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
# Resume training if requested
|
||||
step = 0
|
||||
best_val_acc = 0
|
||||
|
||||
if cfg.resume:
|
||||
if not Logger.get_last_checkpoint_dir(out_dir).exists():
|
||||
raise RuntimeError(
|
||||
"You have set resume=True, but there is no model checkpoint in "
|
||||
f"{Logger.get_last_checkpoint_dir(out_dir)}"
|
||||
)
|
||||
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
|
||||
logging.info(
|
||||
colored(
|
||||
"You have set resume=True, indicating that you wish to resume a run",
|
||||
color="yellow",
|
||||
attrs=["bold"],
|
||||
)
|
||||
)
|
||||
# Load and validate checkpoint configuration
|
||||
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
|
||||
# Check for differences between the checkpoint configuration and provided configuration.
|
||||
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
|
||||
resolve_delta_timestamps(cfg)
|
||||
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
|
||||
# Ignore the `resume` and parameters.
|
||||
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
|
||||
del diff["values_changed"]["root['resume']"]
|
||||
if len(diff) > 0:
|
||||
logging.warning(
|
||||
"At least one difference was detected between the checkpoint configuration and "
|
||||
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
|
||||
"takes precedence.",
|
||||
)
|
||||
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
|
||||
cfg = checkpoint_cfg
|
||||
cfg.resume = True
|
||||
|
||||
# Initialize model and training components
|
||||
model = get_model(cfg=cfg, logger=logger).to(device)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
|
||||
# Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
|
||||
criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
|
||||
grad_scaler = GradScaler(enabled=cfg.training.use_amp)
|
||||
|
||||
# Log model parameters
|
||||
num_learnable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in model.parameters())
|
||||
logging.info(f"Learnable parameters: {format_big_number(num_learnable_params)}")
|
||||
logging.info(f"Total parameters: {format_big_number(num_total_params)}")
|
||||
|
||||
if cfg.resume:
|
||||
step = logger.load_last_training_state(optimizer, None)
|
||||
|
||||
# Training loop with validation and checkpointing
|
||||
for epoch in range(cfg.training.num_epochs):
|
||||
logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
|
||||
|
||||
train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
|
||||
|
||||
# Periodic validation
|
||||
if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
|
||||
val_acc, eval_info = validate(
|
||||
model,
|
||||
val_loader,
|
||||
criterion,
|
||||
device,
|
||||
logger,
|
||||
cfg,
|
||||
)
|
||||
logger.log_dict(eval_info, step + len(train_loader), mode="eval")
|
||||
|
||||
# Save best model
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier="best",
|
||||
)
|
||||
|
||||
# Periodic checkpointing
|
||||
if cfg.training.save_checkpoint and (epoch + 1) % cfg.training.save_freq == 0:
|
||||
logger.save_checkpoint(
|
||||
train_step=step + len(train_loader),
|
||||
policy=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=None,
|
||||
identifier=f"{epoch+1:06d}",
|
||||
)
|
||||
|
||||
step += len(train_loader)
|
||||
|
||||
logging.info("Training completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
@@ -53,29 +53,20 @@ python lerobot/scripts/visualize_dataset_html.py \
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import requests
|
||||
from flask import Flask, redirect, render_template, request, url_for
|
||||
import tqdm
|
||||
from flask import Flask, redirect, render_template, url_for
|
||||
|
||||
from lerobot import available_datasets
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import IterableNamespace
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
|
||||
def run_server(
|
||||
dataset: LeRobotDataset | IterableNamespace | None,
|
||||
episodes: list[int] | None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int],
|
||||
host: str,
|
||||
port: str,
|
||||
static_folder: Path,
|
||||
@@ -85,50 +76,10 @@ def run_server(
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0 # specifying not to cache
|
||||
|
||||
@app.route("/")
|
||||
def hommepage(dataset=dataset):
|
||||
if dataset:
|
||||
dataset_namespace, dataset_name = dataset.repo_id.split("/")
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=0,
|
||||
)
|
||||
)
|
||||
|
||||
dataset_param, episode_param = None, None
|
||||
all_params = request.args
|
||||
if "dataset" in all_params:
|
||||
dataset_param = all_params["dataset"]
|
||||
if "episode" in all_params:
|
||||
episode_param = int(all_params["episode"])
|
||||
|
||||
if dataset_param:
|
||||
dataset_namespace, dataset_name = dataset_param.split("/")
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
dataset_namespace=dataset_namespace,
|
||||
dataset_name=dataset_name,
|
||||
episode_id=episode_param if episode_param is not None else 0,
|
||||
)
|
||||
)
|
||||
|
||||
featured_datasets = [
|
||||
"lerobot/aloha_static_cups_open",
|
||||
"lerobot/columbia_cairlab_pusht_real",
|
||||
"lerobot/taco_play",
|
||||
]
|
||||
return render_template(
|
||||
"visualize_dataset_homepage.html",
|
||||
featured_datasets=featured_datasets,
|
||||
lerobot_datasets=available_datasets,
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>")
|
||||
def show_first_episode(dataset_namespace, dataset_name):
|
||||
first_episode_id = 0
|
||||
def index():
|
||||
# home page redirects to the first episode page
|
||||
[dataset_namespace, dataset_name] = dataset.repo_id.split("/")
|
||||
first_episode_id = episodes[0]
|
||||
return redirect(
|
||||
url_for(
|
||||
"show_episode",
|
||||
@@ -139,85 +90,30 @@ def run_server(
|
||||
)
|
||||
|
||||
@app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes):
|
||||
repo_id = f"{dataset_namespace}/{dataset_name}"
|
||||
try:
|
||||
if dataset is None:
|
||||
dataset = get_dataset_info(repo_id)
|
||||
except FileNotFoundError:
|
||||
return (
|
||||
"Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
|
||||
400,
|
||||
)
|
||||
dataset_version = (
|
||||
dataset.meta._version if isinstance(dataset, LeRobotDataset) else dataset.codebase_version
|
||||
)
|
||||
match = re.search(r"v(\d+)\.", dataset_version)
|
||||
if match:
|
||||
major_version = int(match.group(1))
|
||||
if major_version < 2:
|
||||
return "Make sure to convert your LeRobotDataset to v2 & above."
|
||||
|
||||
episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
|
||||
def show_episode(dataset_namespace, dataset_name, episode_id):
|
||||
dataset_info = {
|
||||
"repo_id": f"{dataset_namespace}/{dataset_name}",
|
||||
"num_samples": dataset.num_frames
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_frames,
|
||||
"num_episodes": dataset.num_episodes
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.total_episodes,
|
||||
"repo_id": dataset.repo_id,
|
||||
"num_samples": dataset.num_frames,
|
||||
"num_episodes": dataset.num_episodes,
|
||||
"fps": dataset.fps,
|
||||
}
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
video_paths = [
|
||||
dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys
|
||||
]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.parent.name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
else:
|
||||
video_keys = [key for key, ft in dataset.features.items() if ft["dtype"] == "video"]
|
||||
videos_info = [
|
||||
{
|
||||
"url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
|
||||
+ dataset.video_path.format(
|
||||
episode_chunk=int(episode_id) // dataset.chunks_size,
|
||||
video_key=video_key,
|
||||
episode_index=episode_id,
|
||||
),
|
||||
"filename": video_key,
|
||||
}
|
||||
for video_key in video_keys
|
||||
]
|
||||
|
||||
response = requests.get(
|
||||
f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Split into lines and parse each line as JSON
|
||||
tasks_jsonl = [json.loads(line) for line in response.text.splitlines() if line.strip()]
|
||||
|
||||
filtered_tasks_jsonl = [row for row in tasks_jsonl if row["episode_index"] == episode_id]
|
||||
tasks = filtered_tasks_jsonl[0]["tasks"]
|
||||
|
||||
video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
|
||||
tasks = dataset.meta.episodes[episode_id]["tasks"]
|
||||
videos_info = [
|
||||
{"url": url_for("static", filename=video_path), "filename": video_path.name}
|
||||
for video_path in video_paths
|
||||
]
|
||||
videos_info[0]["language_instruction"] = tasks
|
||||
|
||||
if episodes is None:
|
||||
episodes = list(
|
||||
range(dataset.num_episodes if isinstance(dataset, LeRobotDataset) else dataset.total_episodes)
|
||||
)
|
||||
|
||||
ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
|
||||
return render_template(
|
||||
"visualize_dataset_template.html",
|
||||
episode_id=episode_id,
|
||||
episodes=episodes,
|
||||
dataset_info=dataset_info,
|
||||
videos_info=videos_info,
|
||||
episode_data_csv_str=episode_data_csv_str,
|
||||
columns=columns,
|
||||
ep_csv_url=ep_csv_url,
|
||||
has_policy=False,
|
||||
)
|
||||
|
||||
app.run(host=host, port=port)
|
||||
@@ -228,69 +124,46 @@ def get_ep_csv_fname(episode_id: int):
|
||||
return ep_csv_fname
|
||||
|
||||
|
||||
def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
|
||||
"""Get a csv str containing timeseries data of an episode (e.g. state and action).
|
||||
def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
|
||||
"""Write a csv file containg timeseries data of an episode (e.g. state and action).
|
||||
This file will be loaded by Dygraph javascript to plot data in real time."""
|
||||
columns = []
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
|
||||
selected_columns = [col for col, ft in dataset.features.items() if ft["dtype"] == "float32"]
|
||||
selected_columns.remove("timestamp")
|
||||
has_state = "observation.state" in dataset.features
|
||||
has_action = "action" in dataset.features
|
||||
|
||||
# init header of csv with state and action names
|
||||
header = ["timestamp"]
|
||||
if has_state:
|
||||
dim_state = dataset.meta.shapes["observation.state"][0]
|
||||
header += [f"state_{i}" for i in range(dim_state)]
|
||||
if has_action:
|
||||
dim_action = dataset.meta.shapes["action"][0]
|
||||
header += [f"action_{i}" for i in range(dim_action)]
|
||||
|
||||
for column_name in selected_columns:
|
||||
dim_state = (
|
||||
dataset.meta.shapes[column_name][0]
|
||||
if isinstance(dataset, LeRobotDataset)
|
||||
else dataset.features[column_name].shape[0]
|
||||
)
|
||||
header += [f"{column_name}_{i}" for i in range(dim_state)]
|
||||
columns = ["timestamp"]
|
||||
if has_state:
|
||||
columns += ["observation.state"]
|
||||
if has_action:
|
||||
columns += ["action"]
|
||||
|
||||
if "names" in dataset.features[column_name] and dataset.features[column_name]["names"]:
|
||||
column_names = dataset.features[column_name]["names"]
|
||||
while not isinstance(column_names, list):
|
||||
column_names = list(column_names.values())[0]
|
||||
else:
|
||||
column_names = [f"motor_{i}" for i in range(dim_state)]
|
||||
columns.append({"key": column_name, "value": column_names})
|
||||
rows = []
|
||||
data = dataset.hf_dataset.select_columns(columns)
|
||||
for i in range(from_idx, to_idx):
|
||||
row = [data[i]["timestamp"].item()]
|
||||
if has_state:
|
||||
row += data[i]["observation.state"].tolist()
|
||||
if has_action:
|
||||
row += data[i]["action"].tolist()
|
||||
rows.append(row)
|
||||
|
||||
selected_columns.insert(0, "timestamp")
|
||||
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
from_idx = dataset.episode_data_index["from"][episode_index]
|
||||
to_idx = dataset.episode_data_index["to"][episode_index]
|
||||
data = (
|
||||
dataset.hf_dataset.select(range(from_idx, to_idx))
|
||||
.select_columns(selected_columns)
|
||||
.with_format("pandas")
|
||||
)
|
||||
else:
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
url = f"https://huggingface.co/datasets/{repo_id}/resolve/main/" + dataset.data_path.format(
|
||||
episode_chunk=int(episode_index) // dataset.chunks_size, episode_index=episode_index
|
||||
)
|
||||
df = pd.read_parquet(url)
|
||||
data = df[selected_columns] # Select specific columns
|
||||
|
||||
rows = np.hstack(
|
||||
(
|
||||
np.expand_dims(data["timestamp"], axis=1),
|
||||
*[np.vstack(data[col]) for col in selected_columns[1:]],
|
||||
)
|
||||
).tolist()
|
||||
|
||||
# Convert data to CSV string
|
||||
csv_buffer = StringIO()
|
||||
csv_writer = csv.writer(csv_buffer)
|
||||
# Write header
|
||||
csv_writer.writerow(header)
|
||||
# Write data rows
|
||||
csv_writer.writerows(rows)
|
||||
csv_string = csv_buffer.getvalue()
|
||||
|
||||
return csv_string, columns
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_dir / file_name, "w") as f:
|
||||
f.write(",".join(header) + "\n")
|
||||
for row in rows:
|
||||
row_str = [str(col) for col in row]
|
||||
f.write(",".join(row_str) + "\n")
|
||||
|
||||
|
||||
def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
@@ -302,31 +175,9 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]
|
||||
]
|
||||
|
||||
|
||||
def get_episode_language_instruction(dataset: LeRobotDataset, ep_index: int) -> list[str]:
|
||||
# check if the dataset has language instructions
|
||||
if "language_instruction" not in dataset.features:
|
||||
return None
|
||||
|
||||
# get first frame index
|
||||
first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
|
||||
|
||||
language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
|
||||
# TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
|
||||
# with the tf.tensor appearing in the string
|
||||
return language_instruction.removeprefix("tf.Tensor(b'").removesuffix("', shape=(), dtype=string)")
|
||||
|
||||
|
||||
def get_dataset_info(repo_id: str) -> IterableNamespace:
|
||||
response = requests.get(f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json")
|
||||
response.raise_for_status() # Raises an HTTPError for bad responses
|
||||
dataset_info = response.json()
|
||||
dataset_info["repo_id"] = repo_id
|
||||
return IterableNamespace(dataset_info)
|
||||
|
||||
|
||||
def visualize_dataset_html(
|
||||
dataset: LeRobotDataset | None,
|
||||
episodes: list[int] | None = None,
|
||||
dataset: LeRobotDataset,
|
||||
episodes: list[int] = None,
|
||||
output_dir: Path | None = None,
|
||||
serve: bool = True,
|
||||
host: str = "127.0.0.1",
|
||||
@@ -335,11 +186,11 @@ def visualize_dataset_html(
|
||||
) -> Path | None:
|
||||
init_logging()
|
||||
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
if len(dataset.meta.image_keys) > 0:
|
||||
raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
|
||||
|
||||
if output_dir is None:
|
||||
# Create a temporary directory that will be automatically cleaned up
|
||||
output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
|
||||
output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
if output_dir.exists():
|
||||
@@ -350,29 +201,28 @@ def visualize_dataset_html(
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
static_dir = output_dir / "static"
|
||||
static_dir.mkdir(parents=True, exist_ok=True)
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
|
||||
if dataset is None:
|
||||
if serve:
|
||||
run_server(
|
||||
dataset=None,
|
||||
episodes=None,
|
||||
host=host,
|
||||
port=port,
|
||||
static_folder=static_dir,
|
||||
template_folder=template_dir,
|
||||
)
|
||||
else:
|
||||
# Create a simlink from the dataset video folder containg mp4 files to the output directory
|
||||
# so that the http server can get access to the mp4 files.
|
||||
if isinstance(dataset, LeRobotDataset):
|
||||
ln_videos_dir = static_dir / "videos"
|
||||
if not ln_videos_dir.exists():
|
||||
ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
|
||||
template_dir = Path(__file__).resolve().parent.parent / "templates"
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
if episodes is None:
|
||||
episodes = list(range(dataset.num_episodes))
|
||||
|
||||
logging.info("Writing CSV files")
|
||||
for episode_index in tqdm.tqdm(episodes):
|
||||
# write states and actions in a csv (it can be slow for big datasets)
|
||||
ep_csv_fname = get_ep_csv_fname(episode_index)
|
||||
# TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
|
||||
write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
|
||||
|
||||
if serve:
|
||||
run_server(dataset, episodes, host, port, static_dir, template_dir)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -381,7 +231,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Name of hugging face repositery containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -396,12 +246,6 @@ def main():
|
||||
default=None,
|
||||
help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-from-hf-hub",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Load videos and parquet files from HF Hub rather than local system.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--episodes",
|
||||
type=int,
|
||||
@@ -443,19 +287,11 @@ def main():
|
||||
args = parser.parse_args()
|
||||
kwargs = vars(args)
|
||||
repo_id = kwargs.pop("repo_id")
|
||||
load_from_hf_hub = kwargs.pop("load_from_hf_hub")
|
||||
root = kwargs.pop("root")
|
||||
local_files_only = kwargs.pop("local_files_only")
|
||||
|
||||
dataset = None
|
||||
if repo_id:
|
||||
dataset = (
|
||||
LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
if not load_from_hf_hub
|
||||
else get_dataset_info(repo_id)
|
||||
)
|
||||
|
||||
visualize_dataset_html(dataset, **vars(args))
|
||||
dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
|
||||
visualize_dataset_html(dataset, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Interactive Video Background Page</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<script defer src="https://cdn.jsdelivr.net/npm/alpinejs@3.x.x/dist/cdn.min.js"></script>
|
||||
</head>
|
||||
<body class="h-screen overflow-hidden font-mono text-white" x-data="{
|
||||
inputValue: '',
|
||||
navigateToDataset() {
|
||||
const trimmedValue = this.inputValue.trim();
|
||||
if (trimmedValue) {
|
||||
window.location.href = `/${trimmedValue}`;
|
||||
}
|
||||
}
|
||||
}">
|
||||
<div class="fixed inset-0 w-full h-full overflow-hidden">
|
||||
<video class="absolute min-w-full min-h-full w-auto h-auto top-1/2 left-1/2 transform -translate-x-1/2 -translate-y-1/2" autoplay muted loop>
|
||||
<source src="https://huggingface.co/datasets/cadene/koch_bimanual_folding/resolve/v1.6/videos/observation.images.phone_episode_000037.mp4" type="video/mp4">
|
||||
Your browser does not support HTML5 video.
|
||||
</video>
|
||||
</div>
|
||||
<div class="fixed inset-0 bg-black bg-opacity-80"></div>
|
||||
<div class="relative z-10 flex flex-col items-center justify-center h-screen">
|
||||
<div class="text-center mb-8">
|
||||
<h1 class="text-4xl font-bold mb-4">LeRobot Dataset Visualizer</h1>
|
||||
|
||||
<a href="https://x.com/RemiCadene/status/1825455895561859185" target="_blank" rel="noopener noreferrer" class="underline">create & train your own robots</a>
|
||||
|
||||
<p class="text-xl mb-4"></p>
|
||||
<div class="text-left inline-block">
|
||||
<h3 class="font-semibold mb-2 mt-4">Example Datasets:</h3>
|
||||
<ul class="list-disc list-inside">
|
||||
{% for dataset in featured_datasets %}
|
||||
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
<div class="flex w-full max-w-lg px-4 mb-4">
|
||||
<input
|
||||
type="text"
|
||||
x-model="inputValue"
|
||||
@keyup.enter="navigateToDataset"
|
||||
placeholder="enter dataset id (ex: lerobot/droid_100)"
|
||||
class="flex-grow px-4 py-2 rounded-l bg-white bg-opacity-20 text-white placeholder-gray-300 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
<button
|
||||
@click="navigateToDataset"
|
||||
class="px-4 py-2 bg-blue-500 text-white rounded-r hover:bg-blue-600 focus:outline-none focus:ring-2 focus:ring-blue-300"
|
||||
>
|
||||
Go
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<details class="mt-4 max-w-full px-4">
|
||||
<summary>More example datasets</summary>
|
||||
<ul class="list-disc list-inside max-h-28 overflow-y-auto break-all">
|
||||
{% for dataset in lerobot_datasets %}
|
||||
<li><a href="/{{ dataset }}" class="text-blue-300 hover:text-blue-100 hover:underline">{{ dataset }}</a></li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</details>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -31,16 +31,11 @@
|
||||
}">
|
||||
<!-- Sidebar -->
|
||||
<div x-ref="sidebar" class="bg-slate-900 p-5 break-words overflow-y-auto shrink-0 md:shrink md:w-60 md:max-h-screen">
|
||||
<a href="https://github.com/huggingface/lerobot" target="_blank" class="hidden md:block">
|
||||
<img src="https://github.com/huggingface/lerobot/raw/main/media/lerobot-logo-thumbnail.png">
|
||||
</a>
|
||||
<a href="https://huggingface.co/datasets/{{ dataset_info.repo_id }}" target="_blank">
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
</a>
|
||||
<h1 class="mb-4 text-xl font-semibold">{{ dataset_info.repo_id }}</h1>
|
||||
|
||||
<ul>
|
||||
<li>
|
||||
Number of samples/frames: {{ dataset_info.num_samples }}
|
||||
Number of samples/frames: {{ dataset_info.num_frames }}
|
||||
</li>
|
||||
<li>
|
||||
Number of episodes: {{ dataset_info.num_episodes }}
|
||||
@@ -98,35 +93,10 @@
|
||||
</div>
|
||||
|
||||
<!-- Videos -->
|
||||
<div class="max-w-32 relative text-sm mb-4 select-none"
|
||||
@click.outside="isVideosDropdownOpen = false">
|
||||
<div
|
||||
@click="isVideosDropdownOpen = !isVideosDropdownOpen"
|
||||
class="p-2 border border-slate-500 rounded flex justify-between items-center cursor-pointer"
|
||||
>
|
||||
<span class="truncate">filter videos</span>
|
||||
<div class="transition-transform" :class="{ 'rotate-180': isVideosDropdownOpen }">🔽</div>
|
||||
</div>
|
||||
|
||||
<div x-show="isVideosDropdownOpen"
|
||||
class="absolute mt-1 border border-slate-500 rounded shadow-lg z-10">
|
||||
<div>
|
||||
<template x-for="option in videosKeys" :key="option">
|
||||
<div
|
||||
@click="videosKeysSelected = videosKeysSelected.includes(option) ? videosKeysSelected.filter(v => v !== option) : [...videosKeysSelected, option]"
|
||||
class="p-2 cursor-pointer bg-slate-900"
|
||||
:class="{ 'bg-slate-700': videosKeysSelected.includes(option) }"
|
||||
x-text="option"
|
||||
></div>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="flex flex-wrap gap-x-2 gap-y-6">
|
||||
<div class="flex flex-wrap gap-1">
|
||||
{% for video_info in videos_info %}
|
||||
<div x-show="!videoCodecError && videosKeysSelected.includes('{{ video_info.filename }}')" class="max-w-96 relative">
|
||||
<p class="absolute inset-x-0 -top-4 text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<div x-show="!videoCodecError" class="max-w-96">
|
||||
<p class="text-sm text-gray-300 bg-gray-800 px-2 rounded-t-xl truncate">{{ video_info.filename }}</p>
|
||||
<video muted loop type="video/mp4" class="object-contain w-full h-full" @canplaythrough="videoCanPlay" @timeupdate="() => {
|
||||
if (video.duration) {
|
||||
const time = video.currentTime;
|
||||
@@ -212,12 +182,12 @@
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<template x-for="(_, colIndex) in Array.from({length: columns.length}, (_, index) => index)">
|
||||
<template x-for="(_, colIndex) in Array.from({length: nColumns}, (_, index) => index)">
|
||||
<th class="border border-slate-700">
|
||||
<div class="flex gap-x-2 justify-between px-2">
|
||||
<input type="checkbox" :checked="isColumnChecked(colIndex)"
|
||||
@change="toggleColumn(colIndex)">
|
||||
<p x-text="`${columns[colIndex].key}`"></p>
|
||||
<p x-text="`${columnNames[colIndex]}`"></p>
|
||||
</div>
|
||||
</th>
|
||||
</template>
|
||||
@@ -227,10 +197,10 @@
|
||||
<template x-for="(row, rowIndex) in rows">
|
||||
<tr class="odd:bg-gray-800 even:bg-gray-900">
|
||||
<td class="border border-slate-700">
|
||||
<div class="flex gap-x-2 max-w-64 font-semibold px-1 break-all">
|
||||
<div class="flex gap-x-2 w-24 font-semibold px-1">
|
||||
<input type="checkbox" :checked="isRowChecked(rowIndex)"
|
||||
@change="toggleRow(rowIndex)">
|
||||
<p x-text="`${rowLabels[rowIndex]}`"></p>
|
||||
<p x-text="`Motor ${rowIndex}`"></p>
|
||||
</div>
|
||||
</td>
|
||||
<template x-for="(cell, colIndex) in row">
|
||||
@@ -252,20 +222,16 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const parentOrigin = "https://huggingface.co";
|
||||
const searchParams = new URLSearchParams();
|
||||
searchParams.set("dataset", "{{ dataset_info.repo_id }}");
|
||||
searchParams.set("episode", "{{ episode_id }}");
|
||||
window.parent.postMessage({ queryString: searchParams.toString() }, parentOrigin);
|
||||
</script>
|
||||
|
||||
<script>
|
||||
function createAlpineData() {
|
||||
return {
|
||||
// state
|
||||
dygraph: null,
|
||||
currentFrameData: null,
|
||||
columnNames: ["state", "action", "pred action"],
|
||||
nColumns: 2,
|
||||
nStates: 0,
|
||||
nActions: 0,
|
||||
checked: [],
|
||||
dygraphTime: 0.0,
|
||||
dygraphIndex: 0,
|
||||
@@ -275,11 +241,6 @@
|
||||
nVideos: {{ videos_info | length }},
|
||||
nVideoReadyToPlay: 0,
|
||||
videoCodecError: false,
|
||||
isVideosDropdownOpen: false,
|
||||
videosKeys: {{ videos_info | map(attribute='filename') | list | tojson }},
|
||||
videosKeysSelected: [],
|
||||
columns: {{ columns | tojson }},
|
||||
rowLabels: {{ columns | tojson }}.reduce((colA, colB) => colA.value.length > colB.value.length ? colA : colB).value,
|
||||
|
||||
// alpine initialization
|
||||
init() {
|
||||
@@ -289,19 +250,11 @@
|
||||
if(!canPlayVideos){
|
||||
this.videoCodecError = true;
|
||||
}
|
||||
this.videosKeysSelected = this.videosKeys.map(opt => opt)
|
||||
|
||||
// process CSV data
|
||||
const csvDataStr = {{ episode_data_csv_str|tojson|safe }};
|
||||
// Create a Blob with the CSV data
|
||||
const blob = new Blob([csvDataStr], { type: 'text/csv;charset=utf-8;' });
|
||||
// Create a URL for the Blob
|
||||
const csvUrl = URL.createObjectURL(blob);
|
||||
|
||||
// process CSV data
|
||||
this.videos = document.querySelectorAll('video');
|
||||
this.video = this.videos[0];
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), csvUrl, {
|
||||
this.dygraph = new Dygraph(document.getElementById("graph"), '{{ ep_csv_url }}', {
|
||||
pixelsPerPoint: 0.01,
|
||||
legend: 'always',
|
||||
labelsDiv: document.getElementById('labels'),
|
||||
@@ -322,17 +275,21 @@
|
||||
this.colors = this.dygraph.getColors();
|
||||
this.checked = Array(this.colors.length).fill(true);
|
||||
|
||||
const seriesNames = this.dygraph.getLabels().slice(1);
|
||||
this.nStates = seriesNames.findIndex(item => item.startsWith('action_'));
|
||||
this.nActions = seriesNames.length - this.nStates;
|
||||
const colors = [];
|
||||
let lightness = 30; // const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
for(const column of this.columns){
|
||||
const nValues = column.value.length;
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/nValues)) {
|
||||
const color = `hsl(${hue}, 100%, ${lightness}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
lightness += 35;
|
||||
const LIGHTNESS = [30, 65, 85]; // state_lightness, action_lightness, pred_action_lightness
|
||||
// colors for "state" lines
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/this.nStates)) {
|
||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[0]}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
// colors for "action" lines
|
||||
for (let hue = 0; hue < 360; hue += parseInt(360/this.nActions)) {
|
||||
const color = `hsl(${hue}, 100%, ${LIGHTNESS[1]}%)`;
|
||||
colors.push(color);
|
||||
}
|
||||
|
||||
this.dygraph.updateOptions({ colors });
|
||||
this.colors = colors;
|
||||
|
||||
@@ -359,19 +316,17 @@
|
||||
return [];
|
||||
}
|
||||
const rows = [];
|
||||
const nRows = Math.max(...this.columns.map(column => column.value.length));
|
||||
const nRows = Math.max(this.nStates, this.nActions);
|
||||
let rowIndex = 0;
|
||||
while(rowIndex < nRows){
|
||||
const row = [];
|
||||
// number of states may NOT match number of actions. In this case, we null-pad the 2D array to make a fully rectangular 2d array
|
||||
const nullCell = { isNull: true };
|
||||
const stateValueIdx = rowIndex;
|
||||
const actionValueIdx = stateValueIdx + this.nStates; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
|
||||
// row consists of [state value, action value]
|
||||
let idx = rowIndex;
|
||||
for(const column of this.columns){
|
||||
const nColumn = column.value.length;
|
||||
row.push(rowIndex < nColumn ? this.currentFrameData[idx] : nullCell);
|
||||
idx += nColumn; // because this.currentFrameData = [state0, state1, ..., stateN, action0, action1, ..., actionN]
|
||||
}
|
||||
row.push(rowIndex < this.nStates ? this.currentFrameData[stateValueIdx] : nullCell); // push "state value" to row
|
||||
row.push(rowIndex < this.nActions ? this.currentFrameData[actionValueIdx] : nullCell); // push "action value" to row
|
||||
rowIndex += 1;
|
||||
rows.push(row);
|
||||
}
|
||||
|
||||
28
poetry.lock
generated
28
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "absl-py"
|
||||
@@ -1294,10 +1294,6 @@ files = [
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:78656d3ae1282a142a5fed410ec3a6f725fdf8d9f9192ed673e336ea3b083e12"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:681e22c8ecb3b48d11cb9019f8a32d4ae1e353e20d4ce3a0f0eedd0ccbd95e5f"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4598572bab6f726ec41fabb43bf0f7e3cf8082ea0f6f8f4e57845a6c919f31b3"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:157fc1fed50946646f09df75c6d52198735a5973e53d252199bbb1c65e1594d2"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_armv7l.whl", hash = "sha256:7ae2724c181be10692c24fb8d9ce2a99a9afc57237332c3658e2ea6f4f33c091"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_i686.whl", hash = "sha256:3d324835f292edd81b962f8c0df44f7f47c0a6f8fe6f7d081951aeb1f5ba57d2"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:474c087b5e584293685a7d4837165b2ead96dc74fb435ae50d5fa0ac168a0de0"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:297350f05f5f87a0bf647a1e5b4446728e5f800788c6bb28b462bcd167f1de7f"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:b1870a8e30f0ac298d17fd546224348d13a648bcfa0cbc51dba7e5136c1af928"},
|
||||
{file = "dora_rs-0.3.6-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:182a189212d41be0c960fd3299bf6731af2e771f8858cfb1be7ebcc17d60a254"},
|
||||
@@ -4928,8 +4924,6 @@ files = [
|
||||
{file = "PyAudio-0.2.14-cp311-cp311-win_amd64.whl", hash = "sha256:bbeb01d36a2f472ae5ee5e1451cacc42112986abe622f735bb870a5db77cf903"},
|
||||
{file = "PyAudio-0.2.14-cp312-cp312-win32.whl", hash = "sha256:5fce4bcdd2e0e8c063d835dbe2860dac46437506af509353c7f8114d4bacbd5b"},
|
||||
{file = "PyAudio-0.2.14-cp312-cp312-win_amd64.whl", hash = "sha256:12f2f1ba04e06ff95d80700a78967897a489c05e093e3bffa05a84ed9c0a7fa3"},
|
||||
{file = "PyAudio-0.2.14-cp313-cp313-win32.whl", hash = "sha256:95328285b4dab57ea8c52a4a996cb52be6d629353315be5bfda403d15932a497"},
|
||||
{file = "PyAudio-0.2.14-cp313-cp313-win_amd64.whl", hash = "sha256:692d8c1446f52ed2662120bcd9ddcb5aa2b71f38bda31e58b19fb4672fffba69"},
|
||||
{file = "PyAudio-0.2.14-cp38-cp38-win32.whl", hash = "sha256:858caf35b05c26d8fc62f1efa2e8f53d5fa1a01164842bd622f70ddc41f55000"},
|
||||
{file = "PyAudio-0.2.14-cp38-cp38-win_amd64.whl", hash = "sha256:2dac0d6d675fe7e181ba88f2de88d321059b69abd52e3f4934a8878e03a7a074"},
|
||||
{file = "PyAudio-0.2.14-cp39-cp39-win32.whl", hash = "sha256:f745109634a7c19fa4d6b8b7d6967c3123d988c9ade0cd35d4295ee1acdb53e9"},
|
||||
@@ -5896,27 +5890,27 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "rerun-sdk"
|
||||
version = "0.21.0"
|
||||
version = "0.18.2"
|
||||
description = "The Rerun Logging SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = "<3.13,>=3.8"
|
||||
files = [
|
||||
{file = "rerun_sdk-0.21.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1e454ceea31c70ae9ec1bb26eaa82828661b7657ab4d2261ca0b94006d6a1975"},
|
||||
{file = "rerun_sdk-0.21.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:84ecb77b0b5bac71b53e849801ff073de89fcd2f1e0ca0da62fb18fcbeceadf0"},
|
||||
{file = "rerun_sdk-0.21.0-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:919d921165c3238490dbe5bf00a062c68fdd2c54dc14aac6a1914c82edb5d9c8"},
|
||||
{file = "rerun_sdk-0.21.0-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:897649aadcab7014b78096f93c84c61c00a227b80adaf0dec279924b5aab53d8"},
|
||||
{file = "rerun_sdk-0.21.0-cp38-abi3-win_amd64.whl", hash = "sha256:2060bdb536a198f0f04789ba5ba771e66587e7851d668b3dfab257a5efa16819"},
|
||||
{file = "rerun_sdk-0.18.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bc4e73275f428e4e9feb8e85f88db7a9fd18b997b1570de62f949a926978f1b2"},
|
||||
{file = "rerun_sdk-0.18.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:efbba40a59710ae83607cb0dc140398a35979c2d2acf5190c9def2ac4697f6a8"},
|
||||
{file = "rerun_sdk-0.18.2-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:2a5e3b618b6d1bfde09bd5614a898995f3c318cc69d8f6d569924a2cd41536ce"},
|
||||
{file = "rerun_sdk-0.18.2-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:8fdfc4c51ef2e75cb68d39e56f0d7c196eff250cb9a0260c07d5e2d6736e31b0"},
|
||||
{file = "rerun_sdk-0.18.2-cp38-abi3-win_amd64.whl", hash = "sha256:c929ade91d3be301b26671b25e70fb529524ced915523d266641c6fc667a1eb5"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
attrs = ">=23.1.0"
|
||||
numpy = ">=1.23"
|
||||
numpy = ">=1.23,<2"
|
||||
pillow = ">=8.0.0"
|
||||
pyarrow = ">=14.0.2"
|
||||
typing-extensions = ">=4.5"
|
||||
|
||||
[package.extras]
|
||||
notebook = ["rerun-notebook (==0.21.0)"]
|
||||
notebook = ["rerun-notebook (==0.18.2)"]
|
||||
tests = ["pytest (==7.1.2)"]
|
||||
|
||||
[[package]]
|
||||
@@ -7575,4 +7569,4 @@ xarm = ["gym-xarm"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "ee60d9251f6a6253d0c371707a72a500a6053d7925c6898e6663d9320ad11503"
|
||||
content-hash = "41344f0eb2d06d9a378abcd10df8205aa3926ff0a08ac5ab1a0b1bcae7440fd8"
|
||||
|
||||
@@ -57,7 +57,7 @@ pytest-cov = {version = ">=5.0.0", optional = true}
|
||||
datasets = ">=2.19.0"
|
||||
imagecodecs = { version = ">=2024.1.1", optional = true }
|
||||
pyav = ">=12.0.5"
|
||||
rerun-sdk = ">=0.21.0"
|
||||
rerun-sdk = ">=0.15.1"
|
||||
deepdiff = ">=7.0.1"
|
||||
flask = ">=3.0.3"
|
||||
pandas = {version = ">=2.2.2", optional = true}
|
||||
|
||||
251
tests/test_train_hilserl_classifier.py
Normal file
251
tests/test_train_hilserl_classifier.py
Normal file
@@ -0,0 +1,251 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig
|
||||
from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier
|
||||
from lerobot.scripts.train_hilserl_classifier import (
|
||||
create_balanced_sampler,
|
||||
train,
|
||||
train_epoch,
|
||||
validate,
|
||||
)
|
||||
|
||||
|
||||
class MockDataset(Dataset):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.meta = MagicMock()
|
||||
self.meta.stats = {}
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
|
||||
def make_dummy_model():
|
||||
model_config = ClassifierConfig(num_classes=2, model_name="hf-tiny-model-private/tiny-random-ResNetModel")
|
||||
model = Classifier(config=model_config)
|
||||
return model
|
||||
|
||||
|
||||
def test_create_balanced_sampler():
|
||||
# Mock dataset with imbalanced classes
|
||||
data = [
|
||||
{"label": 0},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 0},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
{"label": 1},
|
||||
]
|
||||
dataset = MockDataset(data)
|
||||
cfg = MagicMock()
|
||||
cfg.training.label_key = "label"
|
||||
|
||||
sampler = create_balanced_sampler(dataset, cfg)
|
||||
|
||||
# Get weights from the sampler
|
||||
weights = sampler.weights.float()
|
||||
|
||||
# Check that samples have appropriate weights
|
||||
labels = [item["label"] for item in data]
|
||||
class_counts = torch.tensor([labels.count(0), labels.count(1)], dtype=torch.float32)
|
||||
class_weights = 1.0 / class_counts
|
||||
expected_weights = torch.tensor([class_weights[label] for label in labels], dtype=torch.float32)
|
||||
|
||||
# Test that the weights are correct
|
||||
assert torch.allclose(weights, expected_weights)
|
||||
|
||||
|
||||
def test_train_epoch():
|
||||
model = make_dummy_model()
|
||||
# Mock components
|
||||
model.train = MagicMock()
|
||||
|
||||
train_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
optimizer = MagicMock()
|
||||
grad_scaler = MagicMock()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
step = 0
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call the function under test
|
||||
train_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
grad_scaler,
|
||||
device,
|
||||
logger,
|
||||
step,
|
||||
cfg,
|
||||
)
|
||||
|
||||
# Check that model.train() was called
|
||||
model.train.assert_called_once()
|
||||
|
||||
# Check that optimizer.zero_grad() was called
|
||||
optimizer.zero_grad.assert_called()
|
||||
|
||||
# Check that logger.log_dict was called
|
||||
logger.log_dict.assert_called()
|
||||
|
||||
|
||||
def test_validate():
|
||||
model = make_dummy_model()
|
||||
|
||||
# Mock components
|
||||
model.eval = MagicMock()
|
||||
val_loader = [
|
||||
{
|
||||
"image": torch.rand(2, 3, 224, 224),
|
||||
"label": torch.tensor([0.0, 1.0]),
|
||||
}
|
||||
]
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
device = torch.device("cpu")
|
||||
logger = MagicMock()
|
||||
cfg = MagicMock()
|
||||
cfg.training.image_key = "image"
|
||||
cfg.training.label_key = "label"
|
||||
cfg.training.use_amp = False
|
||||
|
||||
# Call validate
|
||||
accuracy, eval_info = validate(model, val_loader, criterion, device, logger, cfg)
|
||||
|
||||
# Check that model.eval() was called
|
||||
model.eval.assert_called_once()
|
||||
|
||||
# Check accuracy/eval_info are calculated and of the correct type
|
||||
assert isinstance(accuracy, float)
|
||||
assert isinstance(eval_info, dict)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("resume", [True, False])
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.init_hydra_config")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_checkpoint_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger.get_last_pretrained_model_dir")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.Logger")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.LeRobotDataset")
|
||||
@patch("lerobot.scripts.train_hilserl_classifier.make_policy")
|
||||
def test_resume_function(
|
||||
mock_make_policy,
|
||||
mock_dataset,
|
||||
mock_logger,
|
||||
mock_get_last_pretrained_model_dir,
|
||||
mock_get_last_checkpoint_dir,
|
||||
mock_init_hydra_config,
|
||||
resume,
|
||||
):
|
||||
# Initialize Hydra
|
||||
test_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
config_dir = os.path.abspath(os.path.join(test_file_dir, "..", "lerobot", "configs", "policy"))
|
||||
assert os.path.exists(config_dir), f"Config directory does not exist at {config_dir}"
|
||||
|
||||
with initialize_config_dir(config_dir=config_dir, job_name="test_app", version_base="1.2"):
|
||||
cfg = compose(
|
||||
config_name="reward_classifier",
|
||||
overrides=[
|
||||
"device=cpu",
|
||||
"seed=42",
|
||||
f"output_dir={tempfile.mkdtemp()}",
|
||||
"wandb.enable=False",
|
||||
f"resume={resume}",
|
||||
"dataset_repo_id=dataset_repo_id",
|
||||
"train_split_proportion=0.8",
|
||||
"training.num_workers=0",
|
||||
"training.batch_size=2",
|
||||
"training.image_key=image",
|
||||
"training.label_key=label",
|
||||
"training.use_amp=False",
|
||||
"training.num_epochs=1",
|
||||
"eval.batch_size=2",
|
||||
],
|
||||
)
|
||||
|
||||
# Mock the init_hydra_config function to return cfg
|
||||
mock_init_hydra_config.return_value = cfg
|
||||
|
||||
# Mock dataset
|
||||
dataset = MockDataset([{"image": torch.rand(3, 224, 224), "label": i % 2} for i in range(10)])
|
||||
mock_dataset.return_value = dataset
|
||||
|
||||
# Mock checkpoint handling
|
||||
mock_checkpoint_dir = MagicMock(spec=Path)
|
||||
mock_checkpoint_dir.exists.return_value = resume # Only exists if resuming
|
||||
mock_get_last_checkpoint_dir.return_value = mock_checkpoint_dir
|
||||
mock_get_last_pretrained_model_dir.return_value = Path(tempfile.mkdtemp())
|
||||
|
||||
# Mock logger
|
||||
logger = MagicMock()
|
||||
resumed_step = 1000
|
||||
if resume:
|
||||
logger.load_last_training_state.return_value = resumed_step
|
||||
else:
|
||||
logger.load_last_training_state.return_value = 0
|
||||
mock_logger.return_value = logger
|
||||
|
||||
# Instantiate the model and set make_policy to return it
|
||||
model = make_dummy_model()
|
||||
mock_make_policy.return_value = model
|
||||
|
||||
# Call train
|
||||
train(cfg)
|
||||
|
||||
# Check that checkpoint handling methods were called
|
||||
if resume:
|
||||
mock_get_last_checkpoint_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_get_last_pretrained_model_dir.assert_called_once_with(Path(cfg.output_dir))
|
||||
mock_checkpoint_dir.exists.assert_called_once()
|
||||
logger.load_last_training_state.assert_called_once()
|
||||
else:
|
||||
mock_get_last_checkpoint_dir.assert_not_called()
|
||||
mock_get_last_pretrained_model_dir.assert_not_called()
|
||||
mock_checkpoint_dir.exists.assert_not_called()
|
||||
logger.load_last_training_state.assert_not_called()
|
||||
|
||||
# Collect the steps from logger.log_dict calls
|
||||
train_log_calls = logger.log_dict.call_args_list
|
||||
|
||||
# Extract the steps used in the train logging
|
||||
steps = []
|
||||
for call in train_log_calls:
|
||||
mode = call.kwargs.get("mode", call.args[2] if len(call.args) > 2 else None)
|
||||
if mode == "train":
|
||||
step = call.kwargs.get("step", call.args[1] if len(call.args) > 1 else None)
|
||||
steps.append(step)
|
||||
|
||||
expected_start_step = resumed_step if resume else 0
|
||||
|
||||
# Calculate expected_steps
|
||||
train_size = int(cfg.train_split_proportion * len(dataset))
|
||||
batch_size = cfg.training.batch_size
|
||||
num_batches = (train_size + batch_size - 1) // batch_size
|
||||
|
||||
expected_steps = [expected_start_step + i for i in range(num_batches)]
|
||||
|
||||
assert steps == expected_steps, f"Expected steps {expected_steps}, got {steps}"
|
||||
@@ -14,25 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub import DatasetCard
|
||||
|
||||
from lerobot.common.datasets.utils import create_lerobot_dataset_card
|
||||
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
|
||||
|
||||
|
||||
def test_default_parameters():
|
||||
card = create_lerobot_dataset_card()
|
||||
assert isinstance(card, DatasetCard)
|
||||
assert card.data.tags == ["LeRobot"]
|
||||
assert card.data.task_categories == ["robotics"]
|
||||
assert card.data.configs == [
|
||||
{
|
||||
"config_name": "default",
|
||||
"data_files": "data/*/*.parquet",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_with_tags():
|
||||
tags = ["tag1", "tag2"]
|
||||
card = create_lerobot_dataset_card(tags=tags)
|
||||
assert card.data.tags == ["LeRobot", "tag1", "tag2"]
|
||||
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
|
||||
root = tmp_path / "dataset"
|
||||
output_dir = tmp_path / "outputs"
|
||||
dataset = lerobot_dataset_factory(root=root)
|
||||
visualize_dataset_html(
|
||||
dataset,
|
||||
episodes=[0],
|
||||
output_dir=output_dir,
|
||||
serve=False,
|
||||
)
|
||||
assert (output_dir / "static" / "episode_0.csv").exists()
|
||||
Reference in New Issue
Block a user