51 Commits

Author SHA1 Message Date
Karl Pertsch
b84cc75031 add binning jointpos 2025-04-25 05:28:23 +00:00
Karl Pertsch
c23bc86a0a load droid sim eval policies without credentials (#440)
small change to enable loading from the openpi sim eval bucket without credentials (for joint pos policies)
2025-04-17 15:39:53 -04:00
Arhan Jain
fe5d5580a4 load droid sim eval policies without credentials 2025-04-17 12:26:06 -07:00
Karl Pertsch
650b02e4ca add diffusion jointpos policy 2025-04-17 13:19:48 +00:00
Karl Pertsch
e43516e719 add diffusion droid policy 2025-04-14 20:15:23 +00:00
Karl Pertsch
20d63d47b7 additional policy 2025-04-14 19:18:09 +00:00
Karl Pertsch
1ce9ffe134 add DROID policies 2025-04-14 18:42:57 +00:00
Jimmy Tanner
36dc3c037e Update README.md to include install (#377)
The README was missing the install command, which took me a while to figure out. Run sync, then install into the conda environment.
2025-03-20 09:49:40 -07:00
Rhythm Syed
bb7a3b4a3e Update README.md 2025-03-14 11:19:11 -04:00
Rhythm Syed
f37e6dd7fb Update README.md 2025-03-14 02:41:27 -04:00
Rhythm Syed
eb28153241 Update README.md 2025-03-14 02:22:12 -04:00
Karl Pertsch
16affa3bee Fix typo in DROID README.md 2025-03-10 13:36:29 -04:00
Niccolo Fusai
581e07d73a tweaking comment (#357) 2025-03-01 22:29:49 -08:00
niccolofusai13
6c514a6d8a tweaking comment 2025-03-02 05:50:52 +00:00
Karl Pertsch
92b1082442 revert libero delta action change (#351) 2025-02-28 14:26:01 -05:00
Jimmy Tanner
f1b9f4ab71 Fix docker context (#330)
### Description

This pull request fixes a build failure caused by an incorrect context
in a Docker Compose configuration file.
Additionally, it slightly expands the Docker instructions.

### Changes

- Corrected the context in a Docker Compose file to fix the build issue.
- Expanded the Docker instructions.
2025-02-28 08:38:57 -08:00
Karl Pertsch
a4b1bf92f1 revert libero delta action change 2025-02-28 15:24:32 +00:00
Karl Pertsch
31289dbd72 add UR5 example (#346) 2025-02-28 10:10:31 -05:00
Karl Pertsch
cd0e9a2e0e add UR5 example 2025-02-28 14:35:55 +00:00
Karl Pertsch
620a56a399 add norm stats reloading documentation (#345) 2025-02-28 09:35:16 -05:00
Karl Pertsch
42e4838aca add norm stats reloading documentation 2025-02-28 13:54:22 +00:00
Davide De Benedittis
3409be890e Update Docker instructions 2025-02-28 11:15:42 +01:00
Davide De Benedittis
d139c700e4 Fix Docker context 2025-02-28 11:15:42 +01:00
Karl Pertsch
d0b6231bd3 more documentation for libero examples (#344) 2025-02-27 19:34:39 -05:00
Karl Pertsch
4a10482dfb more documentation for libero examples 2025-02-27 23:24:34 +00:00
uzhilinsky
bf25a4d9c4 Update .gitmodules (#316)
Change the connection to https to facilitate use by people without git
repo permissions
2025-02-20 08:07:05 -08:00
LMCallMe
7dccd73b37 Update .gitmodules 2025-02-19 19:05:15 +08:00
uzhilinsky
29068dd274 Invalidate the pi0_aloha_pen_uncap checkpoint (#308) 2025-02-17 12:13:37 -08:00
Ury Zhilinsky
ba68b3d97b Invalidate the pi0_aloha_pen_uncap checkpoint 2025-02-17 11:41:14 -08:00
Michael Equi
cd82848a99 Update convert aloha to save images in RGB (#304) 2025-02-14 11:15:58 -08:00
Michael Equi
8d288e4b41 Update convert aloha to save images in RGB 2025-02-14 18:28:41 +00:00
uzhilinsky
90b87cc42c Add a few extra connections to prevent exceeding the pool size (#300)
This takes care of the following warning:
```
WARNING:urllib3.connectionpool:Connection pool is full, discarding connection: openpi-assets.s3.us-west-1.amazonaws.com. Connection pool size: 16
```
2025-02-13 20:36:22 -08:00
Ury Zhilinsky
0a67d46b0d Add a few extra connections to prevent exceeding the pool size 2025-02-13 19:12:23 -08:00
Karl Pertsch
80d346ea0d cap DROID execution frequency (#282) 2025-02-08 19:19:57 -05:00
Karl Pertsch
16788f847e cap DROID execution frequency 2025-02-08 23:33:52 +00:00
Jimmy Tanner
f00207b91c More detailed installation instructions for docker (#277)
Adds some important docker installation details that were undocumented, because we've gotten accustomed to our convenience scripts.

This should resolve #270.
2025-02-08 09:52:55 -08:00
Haohuan Wang
b7c8bf24d4 remove threefry setting in jax 0.5.0 (#279)
continuation conversation from: https://app.graphite.dev/github/pr/Physical-Intelligence/monopi/6672/upgrade-jax-to-0-5-0?utm_source=gt-slack-notif&panel=timeline#comment-PRRC_kwDOLnRTkc50DYoL
2025-02-08 09:52:28 -08:00
Haohuan Wang
ed05e55074 remove threefry setting in jax 0.5.0 2025-02-07 21:27:02 +00:00
Misha Lvovsky
007e2b91ed added lora fast model support (#274)
* added lora fast model support

* small config mistake

* change to ConfigDict types instead of Any as suggested in PR #274 discussion https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945632119

* Simplify get_freeze_filter as per comment on PR #274
https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* actually pass the configs haha https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* update test to check if lora params are present https://github.com/Physical-Intelligence/openpi/pull/274#discussion_r1945722808

* Fixed test to use nnx filters so that it is more clean

* run formatter
2025-02-07 11:58:01 -08:00
Jimmy Tanner
fa5cf91df1 More detailed installation instructions for docker 2025-02-07 11:24:17 -08:00
uzhilinsky
2a13ed7eff Document fix (#276) 2025-02-06 17:09:29 -08:00
Ikko Eltociear Ashimine
9c1376bcc1 docs: update simple_client/README.md (#272)
specifiy -> specify
2025-02-07 00:53:50 +00:00
Ury Zhilinsky
9675e12c4e Document fix 2025-02-06 16:27:00 -08:00
uzhilinsky
bf30fa3d4c Force pi0_libero to be re-downloaded (#275) 2025-02-06 16:14:18 -08:00
uzhilinsky
f543cb1d87 Use data parallel sharding by default (#267)
Our model expects that and so this is a reasonable default to use out of the box.
2025-02-05 23:15:24 -08:00
uzhilinsky
6104624aca Use us-west-1 by default (#266) 2025-02-05 16:02:24 -08:00
Ury Zhilinsky
06cdf3a27f Use us-west-1 by default 2025-02-05 15:43:15 -08:00
Jimmy Tanner
f8ce5c9479 Use new runner (#265) 2025-02-05 15:11:03 -08:00
uzhilinsky
153e34cefe Add link to public aloha data in README.md (#259)
This PR adds a missing link to the public ALOHA data of the pen uncap
task etc.
2025-02-05 14:58:52 -08:00
Jimmy Tanner
f61cd24a15 Use new runner 2025-02-05 14:50:49 -08:00
Oier Mees
ed11c29742 Add link to public aloha data in README.md 2025-02-04 16:01:50 -08:00
24 changed files with 1405 additions and 106 deletions


@@ -7,7 +7,7 @@ on:
jobs:
  run_tests:
    name: Run Tests
-    runs-on: verylarge
+    runs-on: openpi-verylarge
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:

.gitmodules

@@ -1,6 +1,6 @@
[submodule "third_party/aloha"]
	path = third_party/aloha
-	url = git@github.com:Physical-Intelligence/aloha.git
+	url = https://github.com/Physical-Intelligence/aloha.git
[submodule "third_party/libero"]
	path = third_party/libero
-	url = git@github.com:Lifelong-Robot-Learning/LIBERO.git
+	url = https://github.com/Lifelong-Robot-Learning/LIBERO.git


@@ -38,6 +38,7 @@ We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [
```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
```
NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
@@ -66,7 +67,7 @@ We also provide "expert" checkpoints for various robot platforms and tasks. Thes
| $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `s3://openpi-assets/checkpoints/pi0_droid` |
| $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `s3://openpi-assets/checkpoints/pi0_aloha_towel` |
| $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `s3://openpi-assets/checkpoints/pi0_aloha_tupperware` |
-| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](XXX), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
+| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
By default, checkpoints are automatically downloaded from `s3://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
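For example, the cache location can be redirected before running any script (the path below is hypothetical):

```shell
# Redirect the openpi checkpoint cache to a larger disk (hypothetical path).
export OPENPI_DATA_HOME=/mnt/storage/openpi_cache
echo "Checkpoints will be cached under: $OPENPI_DATA_HOME"
```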
@@ -148,6 +149,8 @@ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp
The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this allows JAX to use up to 90% of the GPU memory (vs. the default of 75%).
**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
### 3. Spinning up a policy server and running inference
Once training is complete, we can run inference by spinning up a policy server and then querying it from a Libero evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):
@@ -158,12 +161,16 @@ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero
This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. For instructions how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md).
If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
### More Examples
We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
- [ALOHA Simulator](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)
- [UR5](examples/ur5)


@@ -2,6 +2,24 @@
All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS.
-Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
+- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
+- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
+- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
+- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
+
+If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
Build the Docker image and start the container with the following command:
```bash
docker compose -f scripts/docker/compose.yml up --build
```
To build and run the Docker image for a specific example, use the following command:
```bash
docker compose -f examples/<example_name>/compose.yml up --build
```
where `<example_name>` is the name of the example you want to run.
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.

docs/norm_stats.md

@@ -0,0 +1,69 @@
# Normalization statistics
Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.
## Reloading normalization statistics
When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better for you depends on the similarity of your robot and task to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.
**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:
```python
TrainConfig(
    ...
    data=LeRobotAlohaDataConfig(
        ...
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
    ),
)
```
For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
**Note:** To successfully reload normalization statistics, it's important that your robot and dataset follow the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.
**Note #2:** Whether reloading normalization statistics is beneficial depends on how similar your robot and task are to the robot and task distribution in the pre-training dataset. We recommend trying both options, reloading the pre-training statistics and training with a fresh set of statistics computed on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.
## Provided Pre-training Normalization Statistics
Below is a list of all the pre-training normalization statistics we provide, for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set the `assets_dir` to `s3://openpi-assets/checkpoints/pi0_base/assets`; for `pi0_fast_base`, set it to `s3://openpi-assets/checkpoints/pi0_fast_base/assets`.
| Robot | Description | Asset ID |
|-------|-------------|----------|
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |
## Pi0 Model Action Space Definitions
Out of the box, both the `pi0_base` and `pi0_fast_base` models use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):
```
"dim_0:dim_5": "left arm joint angles",
"dim_6": "left arm gripper position",
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
"dim_13": "right arm gripper position (for bi-manual only)",
# For mobile robots:
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
```
The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.
For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.
General info for Pi robots:
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
- Control frequencies are 20 Hz for UR5e and Franka arms, and 50 Hz for ARX and Trossen (ALOHA) arms.
For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions and gripper actions in the 8th dimension + a control frequency of 15 Hz.
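As an illustration of the layout above, a bi-manual action vector can be assembled like this (a standalone sketch, not part of openpi's API):

```python
def assemble_bimanual_action(left_joints, left_grip, right_joints, right_grip):
    """Assemble an action vector following the documented layout.

    dims 0-5: left arm joint angles, dim 6: left gripper position,
    dims 7-12: right arm joint angles, dim 13: right gripper position.
    Illustrative sketch only -- not part of openpi.
    """
    assert len(left_joints) == 6 and len(right_joints) == 6
    # Grippers are in [0.0, 1.0]: 0.0 = fully open, 1.0 = fully closed.
    assert 0.0 <= left_grip <= 1.0 and 0.0 <= right_grip <= 1.0
    return list(left_joints) + [left_grip] + list(right_joints) + [right_grip]
```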


@@ -33,10 +33,39 @@ pip install -e .
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
```python
from openpi_client import image_tools
from openpi_client import websocket_client_policy

-policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
-action_chunk = policy_client.infer(example)["actions"]

# Outside of episode loop, initialize the policy client.
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

for step in range(num_steps):
    # Inside the episode loop, construct the observation.
    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
    # We provide utilities for resizing images + uint8 conversion so you match the training routines.
    # The typical resize_size for pre-trained pi0 models is 224.
    # Note that the proprioceptive `state` can be passed unnormalized; normalization is handled on the server side.
    observation = {
        "observation/image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(img, 224, 224)
        ),
        "observation/wrist_image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(wrist_img, 224, 224)
        ),
        "observation/state": state,
        "prompt": task_instruction,
    }

    # Call the policy server with the current observation.
    # This returns an action chunk of shape (action_horizon, action_dim).
    # Note that you typically only need to call the policy every N steps and execute steps
    # from the predicted action chunk open-loop in the remaining steps.
    action_chunk = client.infer(observation)["actions"]

    # Execute the actions in the environment.
    ...
```
-Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
+Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
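The query-every-N-steps pattern mentioned in the code comments can be sketched end-to-end with a stand-in client (`DummyClient`, the 10-step chunk, and the 8-dim actions are illustrative assumptions, not openpi APIs):

```python
class DummyClient:
    """Stand-in for WebsocketClientPolicy: returns a fixed 10-step chunk of 8-dim actions."""

    def __init__(self):
        self.infer_calls = 0

    def infer(self, observation):
        self.infer_calls += 1
        return {"actions": [[0.0] * 8 for _ in range(10)]}


def run_open_loop(client, num_steps, replan_every=10):
    """Query the policy every `replan_every` steps; replay the cached chunk in between."""
    executed = []
    chunk = None
    for step in range(num_steps):
        if step % replan_every == 0:
            chunk = client.infer({"observation/state": [0.0] * 8})["actions"]
        executed.append(chunk[step % replan_every])
    return executed


client = DummyClient()
actions = run_open_loop(client, num_steps=30, replan_every=10)
```

With 30 steps and a replan interval of 10, the server is queried only three times while an action is still executed at every step.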


@@ -1,6 +1,6 @@
# Run Aloha (Real Robot)
-This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
+This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
## Prerequisites
@@ -28,13 +28,13 @@ uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
-python examples/aloha_real/main.py
+python -m examples.aloha_real.main
```
Terminal window 2:
```bash
-roslaunch --wait aloha ros_nodes.launch
+roslaunch aloha ros_nodes.launch
```
Terminal window 3:
@@ -123,4 +123,4 @@ This task involves opening a tupperware filled with food and pouring the content
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoints asset directory within the AssetsConfig.


@@ -155,7 +155,7 @@ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, n
        # load one compressed image after the other in RAM and uncompress
        imgs_array = []
        for data in ep[f"/observations/images/{camera}"]:
-           imgs_array.append(cv2.imdecode(data, 1))
+           imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
        imgs_array = np.array(imgs_array)
        imgs_per_cam[camera] = imgs_array


@@ -27,7 +27,7 @@ uv run scripts/serve_policy.py --env=DROID
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
-6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
+6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
```bash


@@ -6,7 +6,7 @@ import datetime
import faulthandler
import os
import signal
import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
@@ -19,6 +19,9 @@ import tyro
faulthandler.enable()

+# DROID data collection frequency -- we slow down execution to match this frequency
+DROID_CONTROL_FREQUENCY = 15
@dataclasses.dataclass
class Args:
@@ -95,6 +98,7 @@ def main(args: Args):
    bar = tqdm.tqdm(range(args.max_timesteps))
    print("Running rollout... press Ctrl+C to stop early.")
    for t_step in bar:
+       start_time = time.time()
        try:
            # Get the current observation
            curr_obs = _extract_observation(
@@ -145,6 +149,11 @@
            action = np.clip(action, -1, 1)
            env.step(action)
+
+           # Sleep to match DROID data collection frequency
+           elapsed_time = time.time() - start_time
+           if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
+               time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
        except KeyboardInterrupt:
            break
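The rate limiting added in this diff is a standard fixed-rate control loop. As a standalone sketch (using `time.monotonic` for robustness to clock adjustments, a deviation from the `time.time` call in the diff):

```python
import time

CONTROL_FREQUENCY = 15  # Hz, matching the DROID data collection frequency

def run_rate_limited(num_steps, step_fn, freq=CONTROL_FREQUENCY):
    """Run step_fn at most `freq` times per second by sleeping off leftover time each step."""
    period = 1.0 / freq
    start = time.monotonic()
    for _ in range(num_steps):
        t0 = time.monotonic()
        step_fn()
        elapsed = time.monotonic() - t0
        if elapsed < period:
            time.sleep(period - elapsed)
    return time.monotonic() - start

# Five no-op steps at 50 Hz should take at least 5 * (1/50) = 0.1 s.
total = run_rate_limited(5, lambda: None, freq=50)
```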


@@ -2,7 +2,7 @@
A minimal client that sends observations to the server and prints the inference rate.
-You can specifiy which runtime environment to use using the `--env` flag. You can see the available options by running:
+You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
@@ -27,4 +27,4 @@ Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```

examples/ur5/README.md

@@ -0,0 +1,151 @@
# UR5 Example
Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.
```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    action_dim: int
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        mask_padding = self.model_type == _model.ModelType.PI0

        # First, concatenate the joints and gripper into the state vector.
        # Pad to the expected input dimensionality of the model (same as action_dim).
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Parse images to uint8 (H,W,C) if needed, since LeRobot automatically
        # stores them as float32 (C,H,W); this step is skipped for policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist, replace with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # Since the "slot" for the right wrist is not used, this mask
                # is set to False.
                "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
            },
        }

        # Pad actions to the model action dimension.
        if "actions" in data:
            # The robot produces 7D actions (6 DoF + 1 gripper), and we pad these.
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
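For reference, `transforms.pad_to_dim` zero-pads the trailing dimensions. A plain-Python sketch of the behavior assumed here (the actual implementation operates on numpy arrays and may differ):

```python
def pad_to_dim(x, dim):
    """Zero-pad a 1-D sequence to `dim` entries (plain-Python sketch of transforms.pad_to_dim)."""
    x = list(x)
    return x + [0.0] * max(0, dim - len(x))

# UR5: 6 joint angles + 1 gripper position, padded to the model's action_dim (e.g. 32).
state = pad_to_dim([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 1.0], 32)
```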
Next, we will define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming needed here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
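A hedged sketch of what the delta conversion with `make_bool_mask(6, -1)` implies per step: masked dimensions become offsets from the current state, while the gripper stays absolute. The real transforms operate on whole action chunks; this per-step simplification is for illustration only:

```python
def make_bool_mask(*counts):
    """Sketch of make_bool_mask: positive counts emit True entries, negative counts emit False."""
    mask = []
    for c in counts:
        mask += [c > 0] * abs(c)
    return mask

def to_delta(actions, state, mask):
    """Masked dims become deltas relative to the current state; unmasked dims pass through."""
    return [a - s if m else a for a, s, m in zip(actions, state, mask)]

def to_absolute(deltas, state, mask):
    """Inverse of to_delta: add the state back onto the masked dims."""
    return [d + s if m else d for d, s, m in zip(deltas, state, mask)]

mask = make_bool_mask(6, -1)  # 6 joint dims as deltas, gripper absolute
state = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.0]
action = [0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 1.0]
delta = to_delta(action, state, mask)
recovered = to_absolute(delta, state, mask)
```

The round trip recovers the original action, and the gripper dimension is left untouched in both directions.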
Finally, we define the `TrainConfig` for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            local_files_only=True,  # True, if dataset is saved locally.
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # `task` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```


@@ -1,10 +1,10 @@
 # Run with:
-# docker compose -f scripts/compose.yml up --build
+# docker compose -f scripts/docker/compose.yml up --build
 services:
   openpi_server:
     image: openpi_server
     build:
-      context: ..
+      context: ../..
       dockerfile: scripts/docker/serve_policy.Dockerfile
     init: true
     tty: true


@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
)
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
rng = jax.random.key(config.seed)


@@ -0,0 +1,466 @@
import math
from typing import Literal
import chex
from einops import einops
from flax import linen as nn
from flax.linen.module import Module
from flax.linen.module import compact
from flax.struct import dataclass
from flax.typing import Array
import jax
import jax.numpy as jnp
class FsqCodebook(nn.Module):
input_dim: int
target_codebook_size: int
codebook_type: Literal["fsq", "lfq", "custom"]
_bins_per_dim: tuple[int, ...] | None = None
@property
def bins_per_dim(self):
if self._bins_per_dim is not None:
return self._bins_per_dim
if self.codebook_type == "fsq":
return self._get_bins_fsq(self.target_codebook_size)
elif self.codebook_type == "lfq": # noqa: RET505
return self._get_bins_lfq(self.target_codebook_size)
elif self.codebook_type == "custom":
return self._get_bins_custom(self.target_codebook_size)
else:
raise ValueError(f"Codebook type {self.codebook_type} not supported.")
@property
def place_values(self):
place_values = [1]
for b in self.bins_per_dim[:-1]:
place_values.append(place_values[-1] * b)
return jnp.array(place_values)
@staticmethod
def _get_bins_fsq(target_codebook_size):
"""
Get bins per dimension based on codebook size, from the original FSQ paper.
"""
if target_codebook_size == 2**8:
return (8, 6, 5)
elif target_codebook_size == 2**10: # noqa: RET505
return (8, 5, 5, 5)
elif target_codebook_size == 2**12:
return (7, 5, 5, 5, 5)
elif target_codebook_size == 2**14:
return (8, 8, 8, 6, 5)
elif target_codebook_size == 2**16:
return (8, 8, 8, 5, 5, 5)
else:
raise ValueError(f"Codebook size {target_codebook_size} not supported.")
@staticmethod
def _get_bins_custom(target_codebook_size):
if target_codebook_size == 2**8:
return (16, 16)
elif target_codebook_size == 2**10: # noqa: RET505
return (32, 32)
elif target_codebook_size == 2**12:
return (64, 64)
elif target_codebook_size == 2**14:
return (128, 128)
elif target_codebook_size == 2**16:
return (256, 256)
raise ValueError(f"Codebook size {target_codebook_size} not supported.")
@staticmethod
def _get_bins_lfq(target_codebook_size):
"""
Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension).
"""
assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
return (2,) * int(math.log2(target_codebook_size))
def setup(self):
self.proj_down = nn.Dense(len(self.bins_per_dim))
self.proj_up = nn.Dense(self.input_dim)
def __call__(self, inputs):
tokens, z = self.encode(inputs)
output = self.decode(tokens, z_grad=z)
return tokens, output
def encode(self, inputs):
bases = jnp.array(self.bins_per_dim)
x = self.proj_down(inputs)
z = jnp.tanh(x)
# Quantize
digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
tokens = self.undigitize(digits)
return tokens, z
def decode(self, tokens, z_grad: jax.Array | None = None):
bases = jnp.array(self.bins_per_dim)
digits = self.digitize(tokens)
z_q = digits / (bases - 1) * 2 - 1
if z_grad is not None:
chex.assert_equal_shape([z_q, z_grad])
z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
return self.proj_up(z_q)
def undigitize(self, digits):
return jnp.sum(digits * jnp.array(self.place_values), axis=-1)
def digitize(self, tokens):
return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)
@property
def vocab_size(self):
return math.prod(self.bins_per_dim)
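The `digitize`/`undigitize` pair above implements a mixed-radix integer encoding of per-dimension bin indices; a standalone NumPy sketch of the same packing logic:

```python
import numpy as np

bins_per_dim = (8, 6, 5)  # FSQ bins for a 2**8-sized codebook
place_values = np.cumprod((1,) + bins_per_dim[:-1])  # [1, 8, 48]

def undigitize(digits):
    # Pack per-dimension bin indices into a single token id (mixed-radix).
    return np.sum(digits * place_values, axis=-1)

def digitize(tokens):
    # Unpack a token id back into per-dimension bin indices.
    tokens = np.asarray(tokens)
    return (tokens[..., None] // place_values) % np.array(bins_per_dim)

digits = np.array([7, 5, 4])  # the largest digit in each dimension
token = undigitize(digits)
assert token == 7 * 1 + 5 * 8 + 4 * 48  # 239 == prod(bins_per_dim) - 1
assert np.array_equal(digitize(token), digits)
```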
class ResNetDownBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1 or x.shape[-1] != self.n_filters:
skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
class ResNetUpBlock(nn.Module):
stride: int = 1
n_filters: int = 64
dropout_rate: float = 0.0
group_size: int = 32
@nn.compact
def __call__(self, x, *, train=True):
skip = x
if self.stride > 1:
skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = nn.relu(x)
x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
return skip + x
@dataclass
class LfqCodebookOutput:
tokens: jnp.ndarray
z: jnp.ndarray
z_q: jnp.ndarray
token_log_probs: jnp.ndarray
commit_loss: jnp.ndarray
class LookupFreeQuantization(nn.Module):
num_dims: int
latent_dim: int
def setup(self):
self.codebook = jnp.array([-1, 1])
# self.activation = lambda x: x
self.activation = nn.tanh
self.project_down = nn.Dense(self.num_dims)
self.project_up = nn.Dense(self.latent_dim)
def encode(self, z):
z = self.project_down(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
token_bits = jnp.argmin(token_squared_distances, axis=-1)
return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)
def decode(self, tokens):
token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
return self.project_up(self.codebook[token_bits])
def loss(self, x):
z = self.project_down(x)
z = self.activation(z)
token_squared_distances = jnp.square(z[..., None] - self.codebook)
tokens = jnp.argmin(token_squared_distances, axis=-1)
token_bit_log_probs = -token_squared_distances # jax.nn.log_softmax(-token_squared_distances, axis=-1)
# Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
token_bit_expansions = jnp.bitwise_and(
jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
).astype(jnp.int32)
token_log_probs = (
token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
+ token_bit_log_probs[..., 1] @ token_bit_expansions
) # (batch_size, num_tokens, 2 ** num_dims)
token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
z_q = self.codebook[tokens]
commit_loss = jnp.square(z - z_q).mean()
z_q = jax.lax.stop_gradient(z_q - z) + z
z_q = self.project_up(z_q)
z = self.project_up(z)
tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
return LfqCodebookOutput(
tokens=tokens,
z=z,
z_q=z_q,
token_log_probs=jnp.zeros(()),
commit_loss=commit_loss,
)
def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))
class GeGLU(Module):
"""Gated Linear Unit with GELU (GeGLU) activation function.
GeGLU is a Flax layer that combines a linear transformation with a GELU
activation function in a gating mechanism. It is often used in Transformer models
to provide non-linear capabilities while preserving a strong linear component.
Example usage::
>>> import flax.linen as nn
>>> class TransformerBlock(nn.Module):
... @nn.compact
... def __call__(self, x):
... x = nn.Dense(2)(x)
... x = GeGLU()(x)
... return x
Attributes:
output_dim: the number of output features (default: -1, i.e. the size of the input's last dimension).
"""
output_dim: int = -1
@compact
def __call__(self, inputs: Array) -> Array:
"""Applies the GeGLU activation to the inputs.
Args:
inputs: the nd-array to apply the GeGLU activation function to.
Returns:
The transformed input.
"""
if self.output_dim == -1:
output_dim = inputs.shape[-1]
else:
output_dim = self.output_dim
x = nn.Dense(output_dim * 2)(inputs)
x, gate = x[..., :output_dim], x[..., output_dim:]
return x * nn.gelu(gate)
class CrossAttentionLayer(nn.Module):
dropout_rate: float = 0.0
num_heads: int | None = None
causal: bool = False
mlp_ratio: float = 4.0
@nn.compact
def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
d_embed = x.shape[-1]
seq_len_q = x.shape[-2]
seq_len_k = y.shape[-2]
if self.causal:
# One block size will be 1
bs_q = max(seq_len_q // seq_len_k, 1)
bs_k = max(seq_len_k // seq_len_q, 1)
mask_self = nn.make_causal_mask(x[..., 0])
mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)
# Self-attention block
skip = x
x = nn.LayerNorm()(x)
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
)(x, x, x, mask=mask_self)
x = skip + x
# Cross-attention block
skip = x
x = nn.LayerNorm()(x)
# bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads or d_embed // 64,
dropout_rate=self.dropout_rate,
deterministic=not train,
# attention_fn=partial(nn.dot_product_attention, bias=bias),
)(x, y, y, mask=mask_cross)
x = skip + x
# MLP block
skip = x
x = nn.LayerNorm()(x)
x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
x = GeGLU()(x)
x = nn.Dense(d_embed)(x)
return skip + x
def sinusoidal_pe_init(_, shape):
seq_len, d_embed = shape
position = jnp.arange(0, seq_len, 1)
div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
return jnp.concatenate(
[
jnp.sin(position[:, jnp.newaxis] * div_term),
jnp.cos(position[:, jnp.newaxis] * div_term),
],
axis=-1,
)
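As a sanity check of the layout: this initializer concatenates a sine block and a cosine block along the feature axis, rather than interleaving them as in the original Transformer. A standalone NumPy version (assuming even `d_embed`):

```python
import numpy as np

def sinusoidal_pe(seq_len, d_embed):
    # Standalone version of sinusoidal_pe_init (assumes even d_embed).
    position = np.arange(seq_len)
    div_term = np.exp(np.arange(0, d_embed, 2) * -(np.log(10000.0) / d_embed))
    # Sines fill the first d_embed // 2 columns, cosines the rest
    # (block layout, not the interleaved layout of the original Transformer).
    return np.concatenate(
        [np.sin(position[:, None] * div_term), np.cos(position[:, None] * div_term)],
        axis=-1,
    )

pe = sinusoidal_pe(16, 8)
assert pe.shape == (16, 8)
assert np.allclose(pe[0, :4], 0.0) and np.allclose(pe[0, 4:], 1.0)  # position 0: sin=0, cos=1
```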
class TokenizerEncoderDecoder(nn.Module):
num_tokens: int
num_cross_tokens: int
num_layers: int
causal: bool
mlp_ratio: float = 4.0
use_state_conditioning: bool = False
@nn.compact
def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])
if mask is not None:
# mask is (batch_dims..., num_cross_tokens)
chex.assert_equal_shape([y[..., 0], mask])
attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
else:
attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))
if self.use_state_conditioning:
assert state_conditioning is not None, "State conditioning is required for this model."
state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
y = jnp.concatenate([y, state_embed], axis=-2)
attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)
y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
for _ in range(self.num_layers):
x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
x, y, train=train, mask_self=None, mask_cross=attn_mask
)
return x
class FsqAttentionTokenizer(nn.Module):
embed_dim: int
data_dim: int
data_horizon: int
num_tokens: int
num_layers: int
target_codebook_size: int
causal: bool = False
mlp_ratio: float = 2.0
bound: float | None = None
use_state_conditioning: bool = False
@property
def vocab_size(self):
return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))
def setup(self):
self.proj = nn.Dense(self.embed_dim)
self.encoder = TokenizerEncoderDecoder(
num_tokens=self.num_tokens,
num_cross_tokens=self.data_horizon,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.codebook = FsqCodebook(
input_dim=self.embed_dim,
target_codebook_size=self.target_codebook_size,
codebook_type="custom",
)
self.decoder = TokenizerEncoderDecoder(
num_tokens=self.data_horizon,
num_cross_tokens=self.num_tokens,
num_layers=self.num_layers,
causal=self.causal,
use_state_conditioning=self.use_state_conditioning,
mlp_ratio=self.mlp_ratio,
)
self.proj_mean = nn.Dense(self.data_dim)
self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))
def tokenize(self, action, *, obs=None, train=False):
if self.bound is not None:
action = jnp.clip(action, -self.bound, self.bound)
x = self.proj(action)
x = self.encoder(x, train=train, state_conditioning=obs)
return self.codebook.encode(x)
def detokenize(self, tokens, *, obs=None):
x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
mean = self.proj_mean(x)
return mean * self.out_scale
def loss(self, action, *, obs=None, train=True):
# Encode
x = self.proj(action)
z = self.encoder(x, train=train, state_conditioning=obs)
# Quantize
tokens, z = self.codebook(z)
# Decode
x = self.decoder(z, train=train, state_conditioning=obs)
mean = self.proj_mean(x) * self.out_scale
mse = jnp.mean(jnp.square(action - mean))
mae = jnp.mean(jnp.abs(action - mean))
return mse, {
"mse": mse,
"mae": mae,
}
def __call__(self, *args, **kwargs):
"""
Dummy for .init
"""
return self.loss(*args, **kwargs)


@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
Used for FAST autoregressive policies.
"""
import dataclasses
from typing import Literal, TypeAlias
import einops
@@ -25,9 +26,10 @@ import jax
import jax.numpy as jnp
import ml_collections
import openpi.models.lora as lora
import openpi.shared.array_typing as at
Variant = Literal["gemma_2b"]
Variant = Literal["gemma_2b", "gemma_2b_lora"]
def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
"remat_policy": "nothing_saveable",
}
)
if variant == "gemma_2b_lora":
return ml_collections.ConfigDict(
{
"variant": variant,
"width": 2048,
"depth": 18,
"mlp_dim": 16_384,
"num_heads": 8,
"num_kv_heads": 1,
"head_dim": 256,
"norm_eps": 1e-6,
"vocab_size": 257_152,
"scan": True,
"remat_policy": "nothing_saveable",
"lora_configs": {
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
},
}
)
raise ValueError(f"Unknown variant: {variant}")
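The `lora_configs` added here enable standard low-rank adapters on the attention and FFN weights. As a reminder of the parameterization (a generic sketch of the LoRA math, not openpi's `lora.Einsum` implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

def lora_linear(x, w, a, b, alpha, rank):
    # Frozen base weight w plus a trainable low-rank update, scaled by alpha / rank.
    return x @ w + (alpha / rank) * (x @ a @ b)

d_in, d_out, rank, alpha = 2048, 256, 16, 16.0  # rank/alpha match the config above
w = rng.normal(size=(d_in, d_out))
a = rng.normal(size=(d_in, rank)) * 0.01  # "A" is small at initialization
b = np.zeros((rank, d_out))               # "B" starts at zero, so the update starts at zero
x = rng.normal(size=(4, d_in))
assert np.allclose(lora_linear(x, w, a, b, alpha, rank), x @ w)  # identity at init
```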
@@ -110,21 +132,34 @@ class Attention(nn.Module):
cache_dtype: str | None = None
lora_config: lora.LoRAConfig | None = None
def setup(self):
if self.num_kv_heads == self.num_heads:
self.qkv_einsum = Einsum(
self.qkv_einsum = lora.Einsum(
shape=(3, self.num_heads, self.features, self.head_dim),
name="qkv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
else:
# MQA
self.q_einsum = Einsum(
self.q_einsum = lora.Einsum(
shape=(self.num_heads, self.features, self.head_dim),
name="q_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
self.kv_einsum = Einsum(
self.kv_einsum = lora.Einsum(
shape=(2, self.num_kv_heads, self.features, self.head_dim),
name="kv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
self.attn_vec_einsum = Einsum(
self.attn_vec_einsum = lora.Einsum(
shape=(self.num_heads, self.head_dim, self.features),
name="attn_vec_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
def _init_cache(self, k, v, cache_size):
@@ -189,37 +224,6 @@ class Attention(nn.Module):
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
@at.typecheck
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
w_gating = self.param(
"gating_einsum",
nn.initializers.zeros_init(),
((2, self.features, self.hidden_dim)),
).astype(dtype)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
"linear",
nn.initializers.zeros_init(),
(self.hidden_dim, self.features),
).astype(dtype)
outputs = jnp.dot(activations, w_linear)
assert outputs.dtype == dtype
return outputs
@at.typecheck
class Block(nn.Module):
"""Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = ()
cache_dtype: str | None = None
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
def setup(self):
self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
features=self.embed_dim,
head_dim=self.head_dim,
cache_dtype=self.cache_dtype,
lora_config=self.lora_configs.get("attn"),
)
self.pre_ffw_norm = RMSNorm()
self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
self.mlp = lora.FeedForward(
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
)
if self.dropout:
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
else:
@@ -289,6 +297,7 @@ class Module(nn.Module):
scan: bool = False
remat_policy: str = "none"
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
@nn.compact
def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
"dropout": self.dropout,
"dropout_bdims": self.dropout_bdims,
"cache_dtype": self.cache_dtype,
"lora_configs": self.lora_configs,
}
layers = self.scope.push("layers")
blocks = [


@@ -1,3 +1,4 @@
from flax import nnx
import jax
import pytest
@@ -53,6 +54,27 @@ def test_pi0_fast_model():
assert actions.shape == (batch_size, 256)
def test_pi0_fast_lora_model():
key = jax.random.key(0)
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size,)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
assert actions.shape == (batch_size, 256)
lora_filter = nnx_utils.PathRegex(".*lora.*")
model_state = nnx.state(model)
lora_state_elems = list(model_state.filter(lora_filter))
assert len(lora_state_elems) > 0
@pytest.mark.manual
def test_model_restore():
key = jax.random.key(0)


@@ -12,6 +12,7 @@ from openpi.models import model as _model
import openpi.models.gemma_fast as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
import openpi.shared.nnx_utils as nnx_utils
logger = logging.getLogger("openpi")
@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):
return observation_spec, action_spec
def get_freeze_filter(self) -> nnx.filterlib.Filter:
"""Returns the freeze filter based on the model config."""
if "lora" in self.paligemma_variant:
return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
return nnx.Nothing
class Pi0FAST(_model.BaseModel):
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):


@@ -1,4 +1,5 @@
import logging
import os
import numpy as np
import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
class BinningTokenizer:
def __init__(self, max_len: int = 256, n_bins: int = 256):
self._max_len = max_len
self._n_bins = n_bins
# Download base PaliGemma tokenizer
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
if len(action_tokens) < action_horizon * action_dim:
return np.zeros([action_horizon, action_dim], dtype=np.float32)
action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
return action_tokens / self._n_bins * 2 - 1
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
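The 256-bin state discretization convention used in `tokenize` (uniform bins over the normalized range [-1, 1]) can be checked in isolation:

```python
import numpy as np

def discretize_state(state, n_bins=256):
    # Map normalized state values in [-1, 1] to integer bins 0..n_bins-1,
    # using the left bin edges exactly as in tokenize() above.
    return np.digitize(state, bins=np.linspace(-1, 1, n_bins + 1)[:-1]) - 1

state = np.array([-1.0, -0.5, 0.0, 0.999])
bins = discretize_state(state)
assert bins.tolist() == [0, 64, 128, 255]
assert np.all((bins >= 0) & (bins < 256))
```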
class FSQTokenizer:
def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
import jax
import orbax.checkpoint as ocp
import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer
self._max_len = max_len
assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
# Download tokenizer
path = download.maybe_download(fsq_tokenizer_path)
tok_path = os.path.join(path, os.listdir(path)[0])
# Split step from path
step = int(tok_path.split("/")[-1])
base_path = tok_path.rsplit("/", 1)[0]
mgr = ocp.CheckpointManager(
base_path,
item_handlers={
"params": ocp.StandardCheckpointHandler(),
"opt_state": ocp.StandardCheckpointHandler(),
"config": ocp.JsonCheckpointHandler(),
},
options=ocp.CheckpointManagerOptions(max_to_keep=1),
)
try:
restored = mgr.restore(
step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
)
config = restored["config"]
self._params = restored["params"]
self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
except Exception as e:
raise RuntimeError(
f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
) from e
# Compile tokenize and detokenize functions
self._tokenize_fn = jax.jit(
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
)
self._detokenize_fn = jax.jit(
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
)
# Download base PaliGemma tokenizer
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
def tokenize(
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
cleaned_text = prompt.lower().strip().replace("_", " ")
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Convention: prefix includes prompt and string-representation of state, followed by ';'
state_str = " ".join(map(str, discretized_state))
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
if actions is not None:
raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
postfix_tokens = []
# Create output token sequence & masks
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
tokens = prefix_tokens + postfix_tokens
token_mask = [True] * len(tokens)
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
# Pad tokens to max length
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [False] * (self._max_len - tokens_len)
tokens = tokens + padding
token_mask = token_mask + padding
ar_mask = ar_mask + padding
loss_mask = loss_mask + padding
else:
if len(tokens) > self._max_len:
logging.warning(
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
"Consider increasing the `max_token_len` in your model config if this happens frequently."
)
tokens = tokens[: self._max_len]
token_mask = token_mask[: self._max_len]
ar_mask = ar_mask[: self._max_len]
loss_mask = loss_mask[: self._max_len]
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
# Decode predicted output tokens
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
# Extract actions from FAST model outputs
if "Action: " not in decoded_tokens:
return np.zeros((action_horizon, action_dim), dtype=np.float32)
# Extract actions from decoded tokens
raw_action_tokens = np.array(
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
)
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
try:
import jax
# Move computation to CPU and compile on-demand
device = jax.devices("cpu")[0]
with jax.default_device(device):
detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
except Exception as e:
logging.warning(f"Error decoding FSQ: {e}")
return np.zeros((action_horizon, action_dim))
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
if isinstance(tokens, list):
tokens = np.array(tokens)
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens


@@ -28,45 +28,72 @@ def _parse_image(image) -> np.ndarray:
@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
"""
This class is used to convert inputs to the model to the expected format. It is used for both training and inference.
For your own dataset, you can copy this class and modify the keys based on the comments below to pipe
the correct elements of your dataset into the model.
"""
# The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
# Do not change this for your own dataset.
action_dim: int
# Determines which model will be used.
# Do not change this for your own dataset.
model_type: _model.ModelType = _model.ModelType.PI0
def __call__(self, data: dict) -> dict:
mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST.
# We only mask padding for pi0 model, not pi0-FAST. Do not change this for your own dataset.
mask_padding = self.model_type == _model.ModelType.PI0
# Get the state. We are padding from 8 to the model action dim.
# For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped).
# We pad the proprioceptive input to the action dimension of the model.
# For pi0-FAST, we don't pad the state. For Libero, we don't need to differentiate
# since the pi0-FAST action_dim = 7, which is < state_dim = 8, so pad is skipped.
# Keep this for your own dataset, but if your dataset stores the proprioceptive input
# in a different key than "observation/state", you should change it below.
state = transforms.pad_to_dim(data["observation/state"], self.action_dim)
# Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
# stores as float32 (C,H,W), gets skipped for policy inference
# stores as float32 (C,H,W), gets skipped for policy inference.
# Keep this for your own dataset, but if your dataset stores the images
# in a different key than "observation/image" or "observation/wrist_image",
# you should change it below.
# Pi0 models support three image inputs at the moment: one third-person view,
# and two wrist views (left and right). If your dataset does not have a particular type
# of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the
# right wrist image below.
base_image = _parse_image(data["observation/image"])
wrist_image = _parse_image(data["observation/wrist_image"])
# Create inputs dict. Do not change the keys in the dict below.
inputs = {
"state": state,
"image": {
"base_0_rgb": base_image,
"left_wrist_0_rgb": wrist_image,
# Pad any non-existent images with zero-arrays of the appropriate shape.
"right_wrist_0_rgb": np.zeros_like(base_image),
},
"image_mask": {
"base_0_rgb": np.True_,
"left_wrist_0_rgb": np.True_,
# Mask any non-existent images with False (if ``mask_padding`` is True).
"right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
},
}
# Pad actions to the model action dimension. Keep this for your own dataset.
# Actions are only available during training.
if "actions" in data:
# We are padding from 7 to the model action dim.
# We are padding to the model action dim.
# For pi0-FAST, this is a no-op (since action_dim = 7).
actions = transforms.pad_to_dim(data["actions"], self.action_dim)
inputs["actions"] = actions
# Pass the prompt (aka language instruction) to the model.
# Keep this for your own dataset (but modify the key if the instruction is not
# stored in "prompt"; the output dict always needs to have the key "prompt").
if "prompt" in data:
inputs["prompt"] = data["prompt"]
@@ -75,6 +102,16 @@ class LiberoInputs(transforms.DataTransformFn):
@dataclasses.dataclass(frozen=True)
class LiberoOutputs(transforms.DataTransformFn):
"""
This class is used to convert outputs from the model back to the dataset-specific format. It is
used for inference only.
For your own dataset, you can copy this class and modify the action dimension based on the comments below.
"""
def __call__(self, data: dict) -> dict:
# Only return the first 7 dims.
# Only return the first N actions -- since we padded actions above to fit the model action
# dimension, we need to now parse out the correct number of actions in the return dict.
# For Libero, we only return the first 7 actions (since the rest is padding).
# For your own dataset, replace `7` with the action dimension of your dataset.
return {"actions": np.asarray(data["actions"][:, :7])}
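The inverse of the training-time padding is just a slice at inference time; a minimal sketch (dimensions assumed from the comments above):

```python
import numpy as np

# Model outputs are padded to 32 dims here (assumed); Libero uses only the first 7.
model_output = {"actions": np.ones((10, 32))}
actions = np.asarray(model_output["actions"][:, :7])
```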


@@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path:
return cache_dir
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
"""Download a file or directory from a remote filesystem to the local cache, and return the local path.
If the local file already exists, it will be returned directly.
@@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
Args:
url: URL to the file to download.
force_download: If True, the file will be downloaded even if it already exists in the cache.
**kwargs: Additional arguments to pass to fsspec.
Returns:
@@ -67,30 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
local_path = local_path.resolve()
# Check if the cache should be invalidated.
invalidate_cache = False
if local_path.exists():
if force_download or _should_invalidate_cache(cache_dir, local_path):
invalidate_cache = True
else:
return local_path
try:
lock_path = local_path.with_suffix(".lock")
with filelock.FileLock(lock_path):
# Ensure consistent permissions for the lock file.
_ensure_permissions(lock_path)
# First, remove the existing cache if it is expired.
if invalidate_cache:
logger.info(f"Removing expired cached entry: {local_path}")
if local_path.is_dir():
shutil.rmtree(local_path)
else:
local_path.unlink()
# Download the data to a local cache.
logger.info(f"Downloading {url} to {local_path}")
scratch_path = local_path.with_suffix(".partial")
if _is_openpi_url(url) or _is_openpi_simeval_url(url):
# Download without credentials.
_download_boto3(
url,
scratch_path,
boto_session=boto3.Session(
region_name="us-west-1",
),
botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
)
elif url.startswith("s3://"):
# Download with default boto3 credentials.
_download_boto3(url, scratch_path)
else:
_download_fsspec(url, scratch_path, **kwargs)
shutil.move(scratch_path, local_path)
_ensure_permissions(local_path)
except PermissionError as e:
msg = (
f"Local file permission error was encountered while downloading {url}. "
f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
)
raise PermissionError(msg) from e
return local_path
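The scratch-then-move pattern in this hunk is worth isolating. Below is a self-contained sketch (the helper name and callback are hypothetical, not openpi's API) of why writing to a `.partial` file and then calling `shutil.move` keeps the cache free of truncated downloads:

```python
import pathlib
import shutil

def atomic_fetch(fetch, local_path: pathlib.Path) -> pathlib.Path:
    # Write to a .partial scratch path first; only a completed download
    # is moved into place, so a crash never leaves a truncated file.
    if local_path.exists():
        return local_path
    local_path.parent.mkdir(parents=True, exist_ok=True)
    scratch = local_path.with_suffix(".partial")
    fetch(scratch)  # the caller writes the data to the scratch path
    shutil.move(scratch, local_path)
    return local_path
```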
@@ -209,7 +236,8 @@ def _download_boto3(
def _get_s3_transfer_manager(
session: boto3.Session, workers: int, botocore_config: botocore.config.Config | None = None
) -> s3_transfer.TransferManager:
# Add a few extra connections to prevent exceeding the pool size.
config = botocore.config.Config(max_pool_connections=workers + 2)
if botocore_config is not None:
config = config.merge(botocore_config)
s3client = session.client("s3", config=config)
@@ -271,6 +299,9 @@ def _is_openpi_url(url: str) -> bool:
"""Check if the url is an OpenPI S3 bucket url."""
return url.startswith("s3://openpi-assets/")
def _is_openpi_simeval_url(url: str) -> bool:
"""Check if the url is an OpenPI simeval S3 bucket url."""
return url.startswith("s3://openpi-assets-simeval/")
def _get_mtime(year: int, month: int, day: int) -> float:
"""Get the mtime of a given date at midnight UTC."""
@@ -282,11 +313,13 @@ def _get_mtime(year: int, month: int, day: int) -> float:
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
"""Return True if the cached entry is expired and should be invalidated."""
assert local_path.exists(), f"File not found at {local_path}"
@@ -295,13 +328,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path)
for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
if pattern.match(relative_path):
# Invalidate if the entry is not newer than the expiration timestamp.
return local_path.stat().st_mtime <= expire_time
return False
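The expiry predicate can be exercised in isolation; a sketch with a hypothetical pattern table and epoch timestamps (the table contents below are illustrative):

```python
import re

# Hypothetical expiry table: entries under checkpoints/ expire at t=1000.
_EXPIRY = {re.compile(r"assets/checkpoints/"): 1000.0}

def should_invalidate(relative_path: str, mtime: float) -> bool:
    # First matching pattern wins; an entry is stale if its mtime is not
    # newer than the expiration timestamp.
    for pattern, expire_time in _EXPIRY.items():
        if pattern.match(relative_path):
            return mtime <= expire_time
    return False
```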


@@ -48,8 +48,8 @@ def initialize_checkpoint_dir(
),
)
# Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
# not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
# checkpoint, since it will fail.
if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")


@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):
# If provided, will determine the default prompt that will be used by the model.
default_prompt: str | None = None
fast_model_tokenizer: Any | None = None
fast_model_tokenizer_kwargs: dict[str, Any] | None = None
def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
],
)
case _model.ModelType.PI0_FAST:
tokenizer_cls = (
_tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
)
tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
return _transforms.Group(
inputs=[
_transforms.InjectDefaultPrompt(self.default_prompt),
_transforms.ResizeImages(224, 224),
_transforms.TokenizeFASTInputs(
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
),
],
outputs=[
_transforms.ExtractFASTActions(
tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
action_horizon=model_config.action_horizon,
action_dim=model_config.action_dim,
)
@@ -251,9 +257,22 @@ class LeRobotAlohaDataConfig(DataConfigFactory):
@dataclasses.dataclass(frozen=True)
class LeRobotLiberoDataConfig(DataConfigFactory):
"""
This config is used to configure transforms that are applied at various parts of the data pipeline.
For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
comments below.
"""
@override
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
# The repack transform is *only* applied to the data coming from the dataset,
# and *not* during inference. We can use it to make inputs from the dataset look
# as close as possible to those coming from the inference environment (e.g. match the keys).
# Below, we match the keys in the dataset (which we defined in the data conversion script) to
# the keys we use in our inference pipeline (defined in the inference script for Libero).
# For your own dataset, first figure out what keys your environment passes to the policy server
# and then modify the mappings below so your dataset's keys get matched to those target keys.
# The repack transform simply remaps key names here.
repack_transform = _transforms.Group(
inputs=[
_transforms.RepackTransform(
@@ -268,13 +287,29 @@ class LeRobotLiberoDataConfig(DataConfigFactory):
]
)
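In effect, the repack transform is a key-renaming dictionary. A minimal sketch — both the mapping entries and the direction of the mapping are assumptions for illustration:

```python
# Map dataset keys (right) to the keys the inference pipeline expects (left).
KEY_MAP = {
    "observation/image": "image",
    "observation/wrist_image": "wrist_image",
    "observation/state": "state",
    "prompt": "prompt",
}

def repack(sample: dict) -> dict:
    # Pure key renaming: values pass through untouched.
    return {target: sample[source] for target, source in KEY_MAP.items()}
```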
# The data transforms are applied to the data coming from the dataset *and* during inference.
# Below, we define the transforms for data going into the model (``inputs``) and the transforms
# for data coming out of the model (``outputs``) (the latter is only used during inference).
# We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
# how to modify the transforms to match your dataset. Once you have created your own transforms, you can
# replace the transforms below with your own.
data_transforms = _transforms.Group(
inputs=[libero_policy.LiberoInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
outputs=[libero_policy.LiberoOutputs()],
)
# One additional data transform: pi0 models are trained on delta actions (relative to the first
# state in each action chunk). If your data has ``absolute`` actions (e.g. target joint angles)
# you can uncomment the following line to convert the actions to delta actions. The only exception
# is the gripper action, which is always absolute.
# In the example below, we would apply the delta conversion to the first 6 actions (joints) and
# leave the 7th action (gripper) unchanged, i.e. absolute.
# In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
# apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
# transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.
# TODO(karl): comment this out once we have updated the Libero checkpoints to not use
# the delta action transform
delta_action_mask = _transforms.make_bool_mask(6, -1)
data_transforms = data_transforms.push(
inputs=[_transforms.DeltaActions(delta_action_mask)],
@@ -282,8 +317,10 @@ class LeRobotLiberoDataConfig(DataConfigFactory):
)
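The delta conversion with a boolean mask can be sketched as follows. This is a simplified reading of `DeltaActions` with `make_bool_mask(6, -1)`; how the real transform consumes the chunk is an assumption here:

```python
import numpy as np

# True for the 6 arm joints (convert to deltas), False for the gripper
# (keep absolute) -- the equivalent of make_bool_mask(6, -1).
mask = np.array([True] * 6 + [False])

def to_delta(actions: np.ndarray, first_state: np.ndarray) -> np.ndarray:
    # Subtract the first state from the masked dims of every action in the chunk.
    out = actions.copy()
    out[:, mask] = actions[:, mask] - first_state[mask]
    return out

chunk = to_delta(np.ones((4, 7)), np.full(7, 2.0))
```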
# Model transforms include things like tokenizing the prompt and action targets
# You do not need to change anything here for your own dataset.
model_transforms = ModelTransformFactory()(model_config)
# We return all data transforms for training and inference. No need to change anything here.
return dataclasses.replace(
self.create_base_config(assets_dirs),
repack_transforms=repack_transform,
@@ -439,24 +476,178 @@ _CONFIGS = [
),
),
),
TrainConfig(
name="pi0_droid_jointpos",
model=pi0.Pi0Config(action_horizon=10),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)), droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="pi0_fast_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
droid_policy.DroidOutputs(),
],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="paligemma_binning_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
),
),
TrainConfig(
name="paligemma_binning_droid_jointpos",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[
_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
droid_policy.DroidOutputs(),
],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.BinningTokenizer,
),
),
),
TrainConfig(
name="paligemma_fast_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
TrainConfig(
name="paligemma_fast_specialist_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.FASTTokenizer,
fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
),
),
),
TrainConfig(
name="paligemma_vq_droid",
model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
model_transforms=ModelTransformFactory(
fast_model_tokenizer=_tokenizer.FSQTokenizer,
fast_model_tokenizer_kwargs={
"fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
},
),
),
),
TrainConfig(
name="paligemma_diffusion_droid",
model=pi0.Pi0Config(action_horizon=10, action_dim=8),
data=SimpleDataConfig(
assets=AssetsConfig(asset_id="droid"),
data_transforms=lambda model: _transforms.Group(
inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
outputs=[droid_policy.DroidOutputs()],
),
base_config=DataConfig(
prompt_from_task=True,
),
),
),
#
# Fine-tuning Libero configs.
#
# These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
# They are used to define key elements like the dataset you are training on, the base checkpoint you
# are using, and other hyperparameters like how many training steps to run or what learning rate to use.
# For your own dataset, you can copy this config and modify the dataset name and data transforms based on
# the comments below.
TrainConfig(
# Change the name to reflect your model and dataset.
name="pi0_libero",
# Here you define the model config -- in this example we use pi0 as the model
# architecture and perform *full* finetuning. In the examples below we show how to modify
# this to perform *low-memory* (LoRA) finetuning and use pi0-FAST as an alternative architecture.
model=pi0.Pi0Config(),
# Here you define the dataset you are training on. In this example we use the Libero
# dataset. For your own dataset, you can change the repo_id to point to your dataset.
# Also modify the DataConfig to use the new config you made for your dataset above.
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(
local_files_only=False, # Set to True for local-only datasets.
# This flag determines whether we load the prompt (i.e. the task instruction) from the
# ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
# a field called ``prompt`` in the input dict. The recommended setting is True.
prompt_from_task=True,
),
),
# Here you define which pre-trained checkpoint you want to load to initialize the model.
# This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
# Below you can define other hyperparameters like the learning rate, number of training steps, etc.
# Check the base TrainConfig class for a full list of available hyperparameters.
num_train_steps=30_000,
),
TrainConfig(
name="pi0_libero_low_mem_finetune",
# Here is an example of loading a pi0 model for LoRA fine-tuning.
model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
@@ -467,13 +658,28 @@ _CONFIGS = [
),
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
num_train_steps=30_000,
# The freeze filter defines which parameters should be frozen during training.
# We have a convenience function in the model config that returns the default freeze filter
# for the given model config for LoRA finetuning. Just make sure it matches the model config
# you chose above.
freeze_filter=pi0.Pi0Config(
paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
).get_freeze_filter(),
# Turn off EMA for LoRA finetuning.
ema_decay=None,
),
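Conceptually, a freeze filter is a predicate over parameter names: everything it matches stays frozen, so with LoRA only the adapter parameters remain trainable. A toy sketch (the parameter naming is hypothetical, not openpi's):

```python
def freeze_filter(param_name: str) -> bool:
    # Freeze every parameter that is not a LoRA adapter.
    return "lora" not in param_name

params = ["llm/q_proj/w", "llm/q_proj/lora_a", "llm/q_proj/lora_b"]
trainable = [p for p in params if not freeze_filter(p)]
```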
TrainConfig(
name="pi0_fast_libero",
# Here is an example of loading a pi0-FAST model for full finetuning.
# Modify action_dim and action_horizon to match your dataset (action horizon is equal to
# the desired action chunk length).
# The max_token_len is the maximum number of (non-image) tokens the model can handle.
# This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
# Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
# a warning), while choosing it too large will waste memory (since we pad each batch element to the
# max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
# two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
# you see many warnings being thrown during training.
model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
@@ -482,8 +688,33 @@ _CONFIGS = [
prompt_from_task=True,
),
),
# Note that we load the pi0-FAST base model checkpoint here.
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
num_train_steps=30_000,
),
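A back-of-the-envelope check for `max_token_len` — all counts below are illustrative assumptions, not openpi's exact token accounting:

```python
prompt_tokens = 20   # tokenized language instruction (assumed)
state_tokens = 7     # proprioceptive state, one token per dim (assumed)
action_tokens = 35   # 10x7 action chunk after FAST compression (assumed ~2x)
total = prompt_tokens + state_tokens + action_tokens

# 62 fits comfortably within the single-arm rule of thumb of max_token_len=180;
# sequences longer than max_token_len would be chopped off with a warning.
fits = total <= 180
```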
TrainConfig(
name="pi0_fast_libero_low_mem_finetune",
# Here is an example of loading a pi0-FAST model for LoRA finetuning.
# For setting action_dim, action_horizon, and max_token_len, see the comments above.
model=pi0_fast.Pi0FASTConfig(
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
),
data=LeRobotLiberoDataConfig(
repo_id="physical-intelligence/libero",
base_config=DataConfig(
local_files_only=False, # Set to True for local-only datasets.
prompt_from_task=True,
),
),
weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
num_train_steps=30_000,
# Again, make sure to match the model config above when extracting the freeze filter
# that specifies which parameters should be frozen during LoRA finetuning.
freeze_filter=pi0_fast.Pi0FASTConfig(
action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
).get_freeze_filter(),
# Turn off EMA for LoRA finetuning.
ema_decay=None,
),
#
# Fine-tuning Aloha configs.


@@ -214,7 +214,12 @@ class TorchDataLoader:
raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")
if sharding is None:
# Use data parallel sharding by default.
sharding = jax.sharding.NamedSharding(
jax.sharding.Mesh(jax.devices(), ("B",)),
jax.sharding.PartitionSpec("B"),
)
self._sharding = sharding
self._num_batches = num_batches
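Data-parallel sharding splits the batch axis across devices; a NumPy-only sketch of the idea behind partitioning along a single "B" mesh axis (a conceptual analogy, not the JAX mechanism itself):

```python
import numpy as np

def shard_batch(batch: np.ndarray, num_devices: int) -> list[np.ndarray]:
    # Each device receives an equal slice of the batch dimension, which is
    # what partitioning along one mesh axis amounts to.
    return np.split(batch, num_devices, axis=0)

shards = shard_batch(np.zeros((8, 3)), 4)
```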