Compare commits

51 Commits

`02-05-defa` ... `karl/droid`
| SHA1 |
|---|
| b84cc75031 |
| c23bc86a0a |
| fe5d5580a4 |
| 650b02e4ca |
| e43516e719 |
| 20d63d47b7 |
| 1ce9ffe134 |
| 36dc3c037e |
| bb7a3b4a3e |
| f37e6dd7fb |
| eb28153241 |
| 16affa3bee |
| 581e07d73a |
| 6c514a6d8a |
| 92b1082442 |
| f1b9f4ab71 |
| a4b1bf92f1 |
| 31289dbd72 |
| cd0e9a2e0e |
| 620a56a399 |
| 42e4838aca |
| 3409be890e |
| d139c700e4 |
| d0b6231bd3 |
| 4a10482dfb |
| bf25a4d9c4 |
| 7dccd73b37 |
| 29068dd274 |
| ba68b3d97b |
| cd82848a99 |
| 8d288e4b41 |
| 90b87cc42c |
| 0a67d46b0d |
| 80d346ea0d |
| 16788f847e |
| f00207b91c |
| b7c8bf24d4 |
| ed05e55074 |
| 007e2b91ed |
| fa5cf91df1 |
| 2a13ed7eff |
| 9c1376bcc1 |
| 9675e12c4e |
| bf30fa3d4c |
| f543cb1d87 |
| 6104624aca |
| 06cdf3a27f |
| f8ce5c9479 |
| 153e34cefe |
| f61cd24a15 |
| ed11c29742 |
`.github/workflows/test.yml` (vendored, 2 changed lines):

```diff
@@ -7,7 +7,7 @@ on:
 jobs:
   run_tests:
     name: Run Tests
-    runs-on: verylarge
+    runs-on: openpi-verylarge
     env:
       GIT_LFS_SKIP_SMUDGE: true
     steps:
```
`.gitmodules` (vendored, 4 changed lines):

```diff
@@ -1,6 +1,6 @@
 [submodule "third_party/aloha"]
 	path = third_party/aloha
-	url = git@github.com:Physical-Intelligence/aloha.git
+	url = https://github.com/Physical-Intelligence/aloha.git
 [submodule "third_party/libero"]
 	path = third_party/libero
-	url = git@github.com:Lifelong-Robot-Learning/LIBERO.git
+	url = https://github.com/Lifelong-Robot-Learning/LIBERO.git
```
````diff
@@ -38,6 +38,7 @@ We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [
 
 ```bash
 GIT_LFS_SKIP_SMUDGE=1 uv sync
+GIT_LFS_SKIP_SMUDGE=1 uv pip install -e .
 ```
 
 NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
````
```diff
@@ -66,7 +67,7 @@ We also provide "expert" checkpoints for various robot platforms and tasks. Thes
 | $\pi_0$-DROID | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well | `s3://openpi-assets/checkpoints/pi0_droid` |
 | $\pi_0$-ALOHA-towel | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms | `s3://openpi-assets/checkpoints/pi0_aloha_towel` |
 | $\pi_0$-ALOHA-tupperware | Inference | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container | `s3://openpi-assets/checkpoints/pi0_aloha_tupperware` |
-| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](XXX), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
+| $\pi_0$-ALOHA-pen-uncap | Inference | $\pi_0$ model fine-tuned on [public ALOHA data](https://dit-policy.github.io/), can uncap a pen | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap` |
 
 By default, checkpoints are automatically downloaded from `s3://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can overwrite the download path by setting the `OPENPI_DATA_HOME` environment variable.
```
```diff
@@ -148,6 +149,8 @@ XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp
 
 The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of the GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).
 
+**Note:** We provide functionality for *reloading* normalization statistics for state / action normalization from pre-training. This can be beneficial if you are fine-tuning to a new task on a robot that was part of our pre-training mixture. For more details on how to reload normalization statistics, see the [norm_stats.md](docs/norm_stats.md) file.
+
 ### 3. Spinning up a policy server and running inference
 
 Once training is complete, we can run inference by spinning up a policy server and then querying it from a Libero evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):
@@ -158,12 +161,16 @@ uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero
 
 This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. For instructions on how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md).
 
+If you want to embed a policy server call in your own robot runtime, we have a minimal example of how to do so in the [remote inference docs](docs/remote_inference.md).
+
 ### More Examples
 
 We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
 - [ALOHA Simulator](examples/aloha_sim)
 - [ALOHA Real](examples/aloha_real)
+- [UR5](examples/ur5)
```
````diff
@@ -2,6 +2,24 @@
 
 All of the examples in this repo provide instructions for running both natively and with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
 
-Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
+- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
+- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
+- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
+- The version of Docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
+- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
+
+If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
 
 Build the Docker image and start the container with the following command:
 ```bash
 docker compose -f scripts/docker/compose.yml up --build
 ```
 
 To build and run the Docker image for a specific example, use the following command:
 ```bash
 docker compose -f examples/<example_name>/compose.yml up --build
 ```
 where `<example_name>` is the name of the example you want to run.
 
 During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
````
`docs/norm_stats.md` (new file, 69 lines):
# Normalization statistics

Following common practice, our models normalize the proprioceptive state inputs and action targets during policy training and inference. The statistics used for normalization are computed over the training data and stored alongside the model checkpoint.

## Reloading normalization statistics

When you fine-tune one of our models on a new dataset, you need to decide whether to (A) reuse existing normalization statistics or (B) compute new statistics over your new training data. Which option is better depends on how similar your robot and task are to the robot and task distribution in the pre-training dataset. Below, we list all the available pre-training normalization statistics for each model.

**If your target robot matches one of these pre-training statistics, consider reloading the same normalization statistics.** By reloading the normalization statistics, the actions in your dataset will be more "familiar" to the model, which can lead to better performance. You can reload the normalization statistics by adding an `AssetsConfig` to your training config that points to the corresponding checkpoint directory and normalization statistics ID, like below for the `Trossen` (aka ALOHA) robot statistics of the `pi0_base` checkpoint:

```python
TrainConfig(
    ...
    data=LeRobotAlohaDataConfig(
        ...
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="trossen",
        ),
    ),
)
```

For an example of a full training config that reloads normalization statistics, see the `pi0_aloha_pen_uncap` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

**Note:** To successfully reload normalization statistics, your robot and dataset must follow the action space definitions used in pre-training. We provide a detailed description of our action space definitions below.

**Note #2:** Whether reloading normalization statistics is beneficial depends on how similar your robot and task are to the robot and task distribution in the pre-training dataset. We recommend trying both options, reloading the statistics and training with a fresh set of statistics computed on your new dataset (see the [main README](../README.md) for instructions on how to compute new statistics), and picking the one that works better for your task.

## Provided Pre-training Normalization Statistics

Below is a list of all the pre-training normalization statistics we provide, for both the `pi0_base` and `pi0_fast_base` models. For `pi0_base`, set `assets_dir` to `s3://openpi-assets/checkpoints/pi0_base/assets`; for `pi0_fast_base`, set it to `s3://openpi-assets/checkpoints/pi0_fast_base/assets`.

| Robot | Description | Asset ID |
|-------|-------------|----------|
| ALOHA | 6-DoF dual arm robot with parallel grippers | trossen |
| Mobile ALOHA | Mobile version of ALOHA mounted on a Slate base | trossen_mobile |
| Franka Emika (DROID) | 7-DoF arm with parallel gripper based on the DROID setup | droid |
| Franka Emika (non-DROID) | Franka FR3 arm with Robotiq 2F-85 gripper | franka |
| UR5e | 6-DoF UR5e arm with Robotiq 2F-85 gripper | ur5e |
| UR5e bi-manual | Bi-manual UR5e setup with Robotiq 2F-85 grippers | ur5e_dual |
| ARX | Bi-manual ARX-5 robot arm setup with parallel gripper | arx |
| ARX mobile | Mobile version of bi-manual ARX-5 robot arm setup mounted on a Slate base | arx_mobile |
| Fibocom mobile | Fibocom mobile robot with 2x ARX-5 arms | fibocom_mobile |

## Pi0 Model Action Space Definitions

Out of the box, both the `pi0_base` and `pi0_fast_base` models use the following action space definitions (left and right are defined looking from behind the robot towards the workspace):

```
"dim_0:dim_5": "left arm joint angles",
"dim_6": "left arm gripper position",
"dim_7:dim_12": "right arm joint angles (for bi-manual only)",
"dim_13": "right arm gripper position (for bi-manual only)",

# For mobile robots:
"dim_14:dim_15": "x-y base velocity (for mobile robots only)",
```

The proprioceptive state uses the same definitions as the action space, except for the base x-y position (the last two dimensions) for mobile robots, which we don't include in the proprioceptive state.

For 7-DoF robots (e.g. Franka), we use the first 7 dimensions of the action space for the joint actions, and the 8th dimension for the gripper action.

General info for Pi robots:
- Joint angles are expressed in radians, with position zero corresponding to the zero position reported by each robot's interface library, except for ALOHA, where the standard ALOHA code uses a slightly different convention (see the [ALOHA example code](../examples/aloha_real/README.md) for details).
- Gripper positions are in [0.0, 1.0], with 0.0 corresponding to fully open and 1.0 corresponding to fully closed.
- Control frequencies are 20 Hz for UR5e and Franka, and 50 Hz for ARX and Trossen (ALOHA) arms.

For DROID, we use the original DROID action configuration, with joint velocity actions in the first 7 dimensions, gripper actions in the 8th dimension, and a control frequency of 15 Hz.
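To make these definitions concrete, here is a small illustrative helper that slices a flat action vector according to the table above. The function name and the 32-dim padded input are hypothetical, not part of openpi:

```python
# Hypothetical helper: slice a flat action vector per the bi-manual mobile
# layout documented above ("dim_0:dim_5" is inclusive, i.e. six joints).

def split_bimanual_mobile_action(action):
    return {
        "left_arm_joints": action[0:6],      # dim_0:dim_5
        "left_gripper": action[6],           # dim_6
        "right_arm_joints": action[7:13],    # dim_7:dim_12
        "right_gripper": action[13],         # dim_13
        "base_velocity_xy": action[14:16],   # dim_14:dim_15
    }

# e.g. a model action padded to 32 dims; the trailing padding dims are unused.
parts = split_bimanual_mobile_action(list(range(32)))
```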
````diff
@@ -33,10 +33,39 @@ pip install -e .
 Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
 
 ```python
 from openpi_client import image_tools
 from openpi_client import websocket_client_policy
 
-policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
-action_chunk = policy_client.infer(example)["actions"]
+# Outside of episode loop, initialize the policy client.
+# Point to the host and port of the policy server (localhost and 8000 are the defaults).
+client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
+
+for step in range(num_steps):
+    # Inside the episode loop, construct the observation.
+    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
+    # We provide utilities for resizing images + uint8 conversion so you match the training routines.
+    # The typical resize_size for pre-trained pi0 models is 224.
+    # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
+    observation = {
+        "observation/image": image_tools.convert_to_uint8(
+            image_tools.resize_with_pad(img, 224, 224)
+        ),
+        "observation/wrist_image": image_tools.convert_to_uint8(
+            image_tools.resize_with_pad(wrist_img, 224, 224)
+        ),
+        "observation/state": state,
+        "prompt": task_instruction,
+    }
+
+    # Call the policy server with the current observation.
+    # This returns an action chunk of shape (action_horizon, action_dim).
+    # Note that you typically only need to call the policy every N steps and execute steps
+    # from the predicted action chunk open-loop in the remaining steps.
+    action_chunk = client.infer(observation)["actions"]
+
+    # Execute the actions in the environment.
+    ...
 ```
 
-Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
+Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
````
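The "call the policy every N steps" comment above can be made concrete with a small sketch. The stub below stands in for the real `client.infer(observation)["actions"]` call and returns scalar "actions" purely for illustration:

```python
# Sketch of open-loop chunk execution: query the policy once per chunk
# (one server round-trip every ACTION_HORIZON steps) and execute the
# intermediate actions from the cached chunk.

ACTION_HORIZON = 10

def stub_infer(step):
    # Stand-in for client.infer(...)["actions"]: a chunk of scalar actions.
    return [step + i for i in range(ACTION_HORIZON)]

executed = []
chunk = None
for t in range(25):
    if t % ACTION_HORIZON == 0:
        chunk = stub_infer(t)                   # re-plan at chunk boundaries
    executed.append(chunk[t % ACTION_HORIZON])  # open-loop within the chunk
```

In a real loop, each chunk entry would be a full action vector of shape `(action_dim,)` rather than a scalar.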
````diff
@@ -1,6 +1,6 @@
 # Run Aloha (Real Robot)
 
-This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
+This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
 
 ## Prerequisites
@@ -28,13 +28,13 @@ uv pip sync examples/aloha_real/requirements.txt
 uv pip install -e packages/openpi-client
 
 # Run the robot
-python examples/aloha_real/main.py
+python -m examples.aloha_real.main
 ```
 
 Terminal window 2:
 
 ```bash
-roslaunch --wait aloha ros_nodes.launch
+roslaunch aloha ros_nodes.launch
 ```
 
 Terminal window 3:
@@ -123,4 +123,4 @@ This task involves opening a tupperware filled with food and pouring the content
 
 We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
 
 IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint's asset directory within the AssetsConfig.
````
```diff
@@ -155,7 +155,7 @@ def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, n
         # load one compressed image after the other in RAM and uncompress
         imgs_array = []
         for data in ep[f"/observations/images/{camera}"]:
-            imgs_array.append(cv2.imdecode(data, 1))
+            imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
        imgs_array = np.array(imgs_array)

        imgs_per_cam[camera] = imgs_array
```
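The change above matters because OpenCV decodes JPEGs into BGR channel order, while the training pipeline expects RGB; converting is equivalent to reversing the channel axis of each pixel. A minimal sketch without OpenCV or NumPy:

```python
# OpenCV's imdecode returns pixels in BGR order; reversing each pixel's
# channel tuple yields RGB. cv2.cvtColor(img, cv2.COLOR_BGR2RGB) does the
# same (plus handling dtypes and memory layout) for real image arrays.

def bgr_to_rgb(image):
    """image: nested lists of [B, G, R] pixels -> [R, G, B] pixels."""
    return [[pixel[::-1] for pixel in row] for row in image]

bgr = [[[255, 0, 0], [0, 255, 0]]]   # one row: pure blue, pure green (in BGR)
rgb = bgr_to_rgb(bgr)
```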
````diff
@@ -27,7 +27,7 @@ uv run scripts/serve_policy.py --env=DROID
 3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
 4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
 5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
-6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
+6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
 7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
 
 ```bash
````
```diff
@@ -6,7 +6,7 @@ import datetime
 import faulthandler
 import os
 import signal
+import time
 
 from moviepy.editor import ImageSequenceClip
 import numpy as np
 from openpi_client import image_tools
@@ -19,6 +19,9 @@ import tyro
 
 faulthandler.enable()
 
+# DROID data collection frequency -- we slow down execution to match this frequency
+DROID_CONTROL_FREQUENCY = 15
+
 
 @dataclasses.dataclass
 class Args:
@@ -95,6 +98,7 @@ def main(args: Args):
     bar = tqdm.tqdm(range(args.max_timesteps))
+    print("Running rollout... press Ctrl+C to stop early.")
     for t_step in bar:
         start_time = time.time()
         try:
             # Get the current observation
             curr_obs = _extract_observation(
@@ -145,6 +149,11 @@ def main(args: Args):
             action = np.clip(action, -1, 1)
 
             env.step(action)
+
+            # Sleep to match DROID data collection frequency
+            elapsed_time = time.time() - start_time
+            if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
+                time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
         except KeyboardInterrupt:
             break
```
````diff
@@ -2,7 +2,7 @@
 
 A minimal client that sends observations to the server and prints the inference rate.
 
-You can specifiy which runtime environment to use using the `--env` flag. You can see the available options by running:
+You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
 
 ```bash
 uv run examples/simple_client/main.py --help
@@ -27,4 +27,4 @@ Terminal window 2:
 
 ```bash
 uv run scripts/serve_policy.py --env DROID
 ```
````
`examples/ur5/README.md` (new file, 151 lines):
# UR5 Example

Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.

First, we define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding classes in `src/openpi/policies/libero_policy.py` for comments explaining each line.

```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    action_dim: int
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        mask_padding = self.model_type == _model.ModelType.PI0

        # First, concatenate the joints and gripper into the state vector.
        # Pad to the expected input dimensionality of the model (same as action_dim).
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Parse images to uint8 (H,W,C), since LeRobot automatically stores them
        # as float32 (C,H,W). This step is skipped during policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create the inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist camera, replace it with zeros.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # Since the "slot" for the right wrist image is not used, this mask
                # is set to False.
                "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
            },
        }

        # Pad actions to the model action dimension.
        if "actions" in data:
            # The robot produces 7D actions (6 DoF + 1 gripper), which we pad.
            actions = transforms.pad_to_dim(data["actions"], self.action_dim)
            inputs["actions"] = actions

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
|
||||
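`transforms.pad_to_dim` above pads the 7-dim UR5 state and actions up to the model's action dimension. Assuming zero-padding (the usual convention for padding to a model dimension), a minimal list-based stand-in behaves like this:

```python
# Minimal stand-in for the padding assumed above: append zeros until the
# vector reaches the model's expected dimensionality. The real openpi helper
# operates on arrays; this sketch uses plain lists for clarity.

def pad_to_dim(vector, target_dim):
    assert len(vector) <= target_dim, "vector already exceeds target dimension"
    return list(vector) + [0.0] * (target_dim - len(vector))

# 6 joint angles + 1 gripper value, padded to a hypothetical 32-dim model input.
state = pad_to_dim([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.9], 32)
```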
Next, we define the `LeRobotUR5DataConfig` class, which defines how to process raw UR5 data from a LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).

```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
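The delta-action conversion above can be illustrated with simple stand-ins. Assuming `make_bool_mask(6, -1)` yields six `True` entries followed by one `False` (positive counts marking `True` runs, negative counts marking `False` runs), and that the delta transform subtracts the current state on masked dimensions only, a sketch looks like this; these helpers are illustrative, not openpi's implementations:

```python
# Stand-ins for the delta-action convention assumed above: joints (masked True)
# become deltas relative to the current state, the gripper (masked False)
# stays absolute, and the absolute transform inverts the delta transform.

def make_bool_mask(*counts):
    mask = []
    for c in counts:
        mask += [c > 0] * abs(c)
    return mask

def to_delta(actions, state, mask):
    return [a - s if m else a for a, s, m in zip(actions, state, mask)]

def to_absolute(deltas, state, mask):
    return [d + s if m else d for d, s, m in zip(deltas, state, mask)]

mask = make_bool_mask(6, -1)
state = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0]
action = [1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 0.8]
delta = to_delta(action, state, mask)        # joints relative, gripper absolute
roundtrip = to_absolute(delta, state, mask)  # recovers the original action
```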
Finally, we define the `TrainConfig` for our UR5 dataset; here, a config for fine-tuning pi0. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.

```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="s3://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            local_files_only=True,  # Set to True if the dataset is stored locally.
            # This flag determines whether we load the prompt (i.e. the task instruction)
            # from the ``task`` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
```diff
@@ -1,10 +1,10 @@
 # Run with:
-# docker compose -f scripts/compose.yml up --build
+# docker compose -f scripts/docker/compose.yml up --build
 services:
   openpi_server:
     image: openpi_server
     build:
-      context: ..
+      context: ../..
       dockerfile: scripts/docker/serve_policy.Dockerfile
     init: true
     tty: true
```
```diff
@@ -199,7 +199,6 @@ def main(config: _config.TrainConfig):
         f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
     )
 
-    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
     jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
 
     rng = jax.random.key(config.seed)
```
`src/openpi/models/fsq_tokenizer_v2.py` (new file, 466 lines):
import math
|
||||
from typing import Literal
|
||||
|
||||
import chex
|
||||
from einops import einops
|
||||
from flax import linen as nn
|
||||
from flax.linen.module import Module
|
||||
from flax.linen.module import compact
|
||||
from flax.struct import dataclass
|
||||
from flax.typing import Array
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class FsqCodebook(nn.Module):
    input_dim: int
    target_codebook_size: int
    codebook_type: Literal["fsq", "lfq"]

    _bins_per_dim: tuple[int] | None = None

    @property
    def bins_per_dim(self):
        if self._bins_per_dim is not None:
            return self._bins_per_dim

        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        elif self.codebook_type == "lfq":  # noqa: RET505
            return self._get_bins_lfq(self.target_codebook_size)
        elif self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        else:
            raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self):
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size):
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        if target_codebook_size == 2**8:
            return (8, 6, 5)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (8, 5, 5, 5)
        elif target_codebook_size == 2**12:
            return (7, 5, 5, 5, 5)
        elif target_codebook_size == 2**14:
            return (8, 8, 8, 6, 5)
        elif target_codebook_size == 2**16:
            return (8, 8, 8, 5, 5, 5)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_custom(target_codebook_size):
        if target_codebook_size == 2**8:
            return (16, 16)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (32, 32)
        elif target_codebook_size == 2**12:
            return (64, 64)
        elif target_codebook_size == 2**14:
            return (128, 128)
        elif target_codebook_size == 2**16:
            return (256, 256)
        return None

    @staticmethod
    def _get_bins_lfq(target_codebook_size):
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension)
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"

        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs):
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs):
        bases = jnp.array(self.bins_per_dim)

        x = self.proj_down(inputs)
        z = jnp.tanh(x)

        # Quantize
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)

        return tokens, z

    def decode(self, tokens, z_grad: jax.Array | None = None):
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)

        z_q = digits / (bases - 1) * 2 - 1

        if z_grad is not None:
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad

        return self.proj_up(z_q)

    def undigitize(self, digits):
        return jnp.sum(digits * jnp.array(self.place_values), axis=-1)

    def digitize(self, tokens):
        return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self):
        return math.prod(self.bins_per_dim)

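An editorial aside on `digitize`/`undigitize` above: together with `place_values` they implement a mixed-radix packing of per-dimension FSQ digits into a single integer token. A standalone numpy sketch for intuition (the `(8, 6, 5)` bins come from `_get_bins_fsq(2**8)`; the helper names here are illustrative, not repo code):

```python
import numpy as np

bins_per_dim = (8, 6, 5)             # 8 * 6 * 5 = 240 distinct codes
place_values = np.array([1, 8, 48])  # cumulative products of earlier bins

def undigitize(digits):
    # Pack per-dimension digits into one integer token.
    return np.sum(digits * place_values, axis=-1)

def digitize(tokens):
    # Unpack a token back into per-dimension digits.
    return (np.asarray(tokens)[..., None] // place_values) % np.array(bins_per_dim)

digits = np.array([7, 5, 4])         # largest valid digit in each dimension
token = undigitize(digits)           # 7 + 5*8 + 4*48 = 239
assert token == 239
assert np.array_equal(digitize(token), digits)  # exact round trip
```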
class ResNetDownBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1 or x.shape[-1] != self.n_filters:
            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x


class ResNetUpBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x, *, train=True):
        skip = x

        if self.stride > 1:
            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)

        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)

        return skip + x

@dataclass
class LfqCodebookOutput:
    tokens: jnp.ndarray
    z: jnp.ndarray
    z_q: jnp.ndarray
    token_log_probs: jnp.ndarray
    commit_loss: jnp.ndarray


class LookupFreeQuantization(nn.Module):
    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        # self.activation = lambda x: x
        self.activation = nn.tanh

        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z):
        z = self.project_down(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens):
        token_bits = (tokens[..., None] & (2 ** jnp.arange(self.num_dims))).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])

    def loss(self, x):
        z = self.project_down(x)
        z = self.activation(z)

        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)

        token_bit_log_probs = -token_squared_distances  # jax.nn.log_softmax(-token_squared_distances, axis=-1)
        # Compute token log probs for tokens 0..2^num_dims-1 by summing corresponding log-probs
        token_bit_expansions = jnp.bitwise_and(
            jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]
        ).astype(jnp.int32)
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))

        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        z_q = jax.lax.stop_gradient(z_q - z) + z

        z_q = self.project_up(z_q)
        z = self.project_up(z)

        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            token_log_probs=jnp.zeros(()),
            commit_loss=commit_loss,
        )

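A note on the `jax.lax.stop_gradient(z_q - z) + z` pattern that appears in both `FsqCodebook.decode` and the loss above: it is the standard straight-through estimator. A minimal check of the idea (illustrative; `jnp.round` stands in for the actual quantizer):

```python
import jax
import jax.numpy as jnp

def quantize_st(z):
    # Forward pass returns the quantized value; the stop_gradient wrapper
    # makes the backward pass treat quantization as the identity map.
    z_q = jnp.round(z)  # stand-in for the non-differentiable quantizer
    return jax.lax.stop_gradient(z_q - z) + z

z = jnp.array([0.2, 1.7])
grad = jax.grad(lambda v: quantize_st(v).sum())(z)
assert jnp.allclose(quantize_st(z), jnp.array([0.0, 2.0]))  # quantized forward
assert jnp.allclose(grad, 1.0)                              # gradient passes straight through
```

Without the trick, `jax.grad` through `jnp.round` alone would be zero everywhere, so the encoder would receive no learning signal.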
def make_block_causal_attention_matrix(q, k, bs_q, bs_k):
    return nn.make_attention_mask(q, k, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_k, y // bs_q))

class GeGLU(Module):
    """Gated Linear Unit with GELU (GeGLU) activation function.

    GeGLU is a Flax layer that combines a linear transformation with a GELU
    activation function in a gating mechanism. It is often used in Transformer models
    to provide non-linear capabilities while preserving a strong linear component.

    Example usage::

        >>> import flax.linen as nn
        >>> class TransformerBlock(nn.Module):
        ...     @nn.compact
        ...     def __call__(self, x):
        ...         x = nn.Dense(2)(x)
        ...         x = nn.GeGLU()(x)  # initialized
        ...         return x

    Attributes:
        output_dim: the number of output features (default: -1, meaning same as the input).
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the GeGLU activation to the inputs.

        Args:
            inputs: the nd-array to apply the GeGLU activation function to.

        Returns:
            The transformed input.
        """
        if self.output_dim == -1:
            output_dim = inputs.shape[-1]
        else:
            output_dim = self.output_dim

        x = nn.Dense(output_dim * 2)(inputs)
        x, gate = x[..., :output_dim], x[..., output_dim:]
        return x * nn.gelu(gate)

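The forward pass above (project to twice the width, split into value and gate halves, gate with GELU) can be re-derived in plain numpy; the explicit weight matrix `w` here folds in what `nn.Dense` learns and is purely illustrative:

```python
import numpy as np

def gelu(x):
    # tanh approximation, matching flax.linen.gelu's default
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def geglu(x, w):
    # Project to 2 * output_dim, split into value and gate, gate with GELU.
    h = x @ w
    d = h.shape[-1] // 2
    return h[..., :d] * gelu(h[..., d:])

x = np.array([[1.0, 2.0]])
w = np.concatenate([np.eye(2), np.eye(2)], axis=1)  # value half == gate half == x
out = geglu(x, w)
assert out.shape == (1, 2)
assert np.allclose(out, x * gelu(x))  # with this w, GeGLU reduces to x * gelu(x)
```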
class CrossAttentionLayer(nn.Module):
    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0

    @nn.compact
    def __call__(self, x, y, *, mask_self=None, mask_cross=None, train=True):
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]

        if self.causal:
            # One block size will be 1
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)

            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)

        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x

        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        # bias = -jnp.abs(jnp.linspace(0, 1, seq_len_q)[:, None] - jnp.linspace(0, 1, seq_len_k)) * 5
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
            # attention_fn=partial(nn.dot_product_attention, bias=bias),
        )(x, y, y, mask=mask_cross)
        x = skip + x

        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x

def sinusoidal_pe_init(_, shape):
    seq_len, d_embed = shape

    position = jnp.arange(0, seq_len, 1)
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
    return jnp.concatenate(
        [
            jnp.sin(position[:, jnp.newaxis] * div_term),
            jnp.cos(position[:, jnp.newaxis] * div_term),
        ],
        axis=-1,
    )

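`sinusoidal_pe_init` above is the standard sinusoidal positional encoding with base-10000 frequencies, sines in the first half of the embedding and cosines in the second. A numpy re-derivation for intuition:

```python
import numpy as np

def sinusoidal_pe(seq_len, d_embed):
    # Mirrors sinusoidal_pe_init: sin block then cos block, concatenated
    # along the feature axis (so d_embed must be even).
    position = np.arange(seq_len)
    div_term = np.exp(np.arange(0, d_embed, 2) * -(np.log(10000.0) / d_embed))
    return np.concatenate(
        [np.sin(position[:, None] * div_term), np.cos(position[:, None] * div_term)],
        axis=-1,
    )

pe = sinusoidal_pe(16, 8)
assert pe.shape == (16, 8)
assert pe[0, 0] == 0.0  # sin(0) in the first half
assert pe[0, 4] == 1.0  # cos(0) in the second half
```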
class TokenizerEncoderDecoder(nn.Module):
    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool

    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False

    @nn.compact
    def __call__(self, y, *, train=True, state_conditioning=None, mask=None):
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones(y.shape[:-2] + (1, self.num_tokens, self.num_cross_tokens))

        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)

        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])

        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )

        return x

class FsqAttentionTokenizer(nn.Module):
    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0

    bound: float | None = None

    use_state_conditioning: bool = False

    @property
    def vocab_size(self):
        return math.prod(FsqCodebook._get_bins_fsq(self.target_codebook_size))

    def setup(self):
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )

        self.proj_mean = nn.Dense(self.data_dim)
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(self, action, *, obs=None, train=False):
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)

        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)

        return self.codebook.encode(x)

    def detokenize(self, tokens, *, obs=None):
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(self, action, *, obs=None, train=True):
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)

        # Quantize
        tokens, z = self.codebook(z)

        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale

        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))

        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args, **kwargs):
        """
        Dummy for .init
        """
        return self.loss(*args, **kwargs)
@@ -17,6 +17,7 @@ Gemma model implementation from big_vision/models/ppp/gemma.py (with small modif
Used for FAST autoregressive policies.
"""

+import dataclasses
from typing import Literal, TypeAlias

import einops
@@ -25,9 +26,10 @@ import jax
import jax.numpy as jnp
import ml_collections

+import openpi.models.lora as lora
import openpi.shared.array_typing as at

-Variant = Literal["gemma_2b"]
+Variant = Literal["gemma_2b", "gemma_2b_lora"]


def get_config(variant):
@@ -48,6 +50,26 @@ def get_config(variant):
                "remat_policy": "nothing_saveable",
            }
        )
+    if variant == "gemma_2b_lora":
+        return ml_collections.ConfigDict(
+            {
+                "variant": variant,
+                "width": 2048,
+                "depth": 18,
+                "mlp_dim": 16_384,
+                "num_heads": 8,
+                "num_kv_heads": 1,
+                "head_dim": 256,
+                "norm_eps": 1e-6,
+                "vocab_size": 257_152,
+                "scan": True,
+                "remat_policy": "nothing_saveable",
+                "lora_configs": {
+                    "attn": lora.LoRAConfig(rank=16, alpha=16.0),
+                    "ffn": lora.LoRAConfig(rank=16, alpha=16.0),
+                },
+            }
+        )
    raise ValueError(f"Unknown variant: {variant}")

@@ -110,21 +132,34 @@ class Attention(nn.Module):

    cache_dtype: str | None = None

+    lora_config: lora.LoRAConfig | None = None
+
    def setup(self):
        if self.num_kv_heads == self.num_heads:
-            self.qkv_einsum = Einsum(
+            self.qkv_einsum = lora.Einsum(
                shape=(3, self.num_heads, self.features, self.head_dim),
                name="qkv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
            )
        else:
            # MQA
-            self.q_einsum = Einsum(
+            self.q_einsum = lora.Einsum(
                shape=(self.num_heads, self.features, self.head_dim),
                name="q_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+                lora_config=self.lora_config,
            )
-            self.kv_einsum = Einsum(
+            self.kv_einsum = lora.Einsum(
                shape=(2, self.num_kv_heads, self.features, self.head_dim),
                name="kv_einsum",
                init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
+                lora_config=self.lora_config,
            )
-        self.attn_vec_einsum = Einsum(
+        self.attn_vec_einsum = lora.Einsum(
            shape=(self.num_heads, self.head_dim, self.features),
            name="attn_vec_einsum",
            init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
+            lora_config=self.lora_config,
        )

    def _init_cache(self, k, v, cache_size):
@@ -189,37 +224,6 @@ class Attention(nn.Module):
        return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache


-@at.typecheck
-class FeedForward(nn.Module):
-    """Feed forward module."""
-
-    features: int
-    hidden_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        dtype = x.dtype  # original dtype, could be half-precision
-        w_gating = self.param(
-            "gating_einsum",
-            nn.initializers.zeros_init(),
-            ((2, self.features, self.hidden_dim)),
-        ).astype(dtype)
-        ff_gate = jnp.dot(x, w_gating[0])
-        gate_value = nn.gelu(ff_gate)
-
-        ff1 = jnp.dot(x, w_gating[1])
-        activations = gate_value * ff1
-
-        w_linear = self.param(
-            "linear",
-            nn.initializers.zeros_init(),
-            (self.hidden_dim, self.features),
-        ).astype(dtype)
-        outputs = jnp.dot(activations, w_linear)
-        assert outputs.dtype == dtype
-        return outputs
-
-
@at.typecheck
class Block(nn.Module):
    """Transformer block."""
@@ -233,6 +237,7 @@ class Block(nn.Module):
    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    cache_dtype: str | None = None
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    def setup(self):
        self.pre_attention_norm = RMSNorm()
@@ -242,9 +247,12 @@ class Block(nn.Module):
            features=self.embed_dim,
            head_dim=self.head_dim,
            cache_dtype=self.cache_dtype,
+            lora_config=self.lora_configs.get("attn"),
        )
        self.pre_ffw_norm = RMSNorm()
-        self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
+        self.mlp = lora.FeedForward(
+            features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
+        )
        if self.dropout:
            self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
        else:
@@ -289,6 +297,7 @@ class Module(nn.Module):

    scan: bool = False
    remat_policy: str = "none"
+    lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)

    @nn.compact
    def __call__(
@@ -380,6 +389,7 @@ class Module(nn.Module):
            "dropout": self.dropout,
            "dropout_bdims": self.dropout_bdims,
            "cache_dtype": self.cache_dtype,
+            "lora_configs": self.lora_configs,
        }
        layers = self.scope.push("layers")
        blocks = [

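For context, the `lora.Einsum` and `lora.FeedForward` wrappers threaded through the hunks above apply the usual LoRA parameterization: a frozen base weight plus a scaled low-rank trainable update. A generic numpy sketch under that assumption (names here are illustrative, not the repo's actual `openpi.models.lora` API):

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 8, 4, 2, 16.0

W = rng.normal(size=(d_in, d_out))  # frozen base weight
A = rng.normal(size=(d_in, rank))   # trainable down-projection
B = np.zeros((rank, d_out))         # trainable up-projection, zero-initialized

def lora_linear(x):
    # Base projection plus the scaled low-rank correction (alpha / rank scaling).
    return x @ W + (alpha / rank) * (x @ A @ B)

x = rng.normal(size=(3, d_in))
# With B == 0 the adapted layer is exactly the base layer, so fine-tuning
# starts from the pretrained model's behavior.
assert np.allclose(lora_linear(x), x @ W)
```

This is also why the freeze filter later in this branch keeps only parameters whose path matches `.*lora.*` trainable: `W` stays fixed while `A` and `B` learn.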
@@ -1,3 +1,4 @@
+from flax import nnx
import jax
import pytest

@@ -53,6 +54,27 @@ def test_pi0_fast_model():
    assert actions.shape == (batch_size, 256)


+def test_pi0_fast_lora_model():
+    key = jax.random.key(0)
+    config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
+    model = config.create(key)
+
+    batch_size = 2
+    obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
+
+    loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
+    assert loss.shape == (batch_size,)
+
+    actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
+    assert actions.shape == (batch_size, 256)
+
+    lora_filter = nnx_utils.PathRegex(".*lora.*")
+    model_state = nnx.state(model)
+
+    lora_state_elems = list(model_state.filter(lora_filter))
+    assert len(lora_state_elems) > 0
+
+
@pytest.mark.manual
def test_model_restore():
    key = jax.random.key(0)

@@ -12,6 +12,7 @@ from openpi.models import model as _model
import openpi.models.gemma_fast as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
+import openpi.shared.nnx_utils as nnx_utils

logger = logging.getLogger("openpi")

@@ -117,6 +118,12 @@ class Pi0FASTConfig(_model.BaseModelConfig):

        return observation_spec, action_spec

+    def get_freeze_filter(self) -> nnx.filterlib.Filter:
+        """Returns the freeze filter based on the model config."""
+        if "lora" in self.paligemma_variant:
+            return nnx.All(nnx_utils.PathRegex(".*llm.*"), nnx.Not(nnx_utils.PathRegex(".*lora.*")))
+        return nnx.Nothing
+

class Pi0FAST(_model.BaseModel):
    def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):

@@ -1,4 +1,5 @@
import logging
+import os

import numpy as np
import sentencepiece
@@ -125,3 +126,215 @@ class FASTTokenizer:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

class BinningTokenizer:
    def __init__(self, max_len: int = 256, n_bins: int = 256):
        self._max_len = max_len
        self._n_bins = n_bins

        # Download base PaliGemma tokenizer
        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

        self._fast_skip_tokens = 128  # Skip last 128 tokens in PaliGemma vocab since they are special tokens

    def tokenize(
        self, prompt: str, state: np.ndarray, actions: np.ndarray | None
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        cleaned_text = prompt.lower().strip().replace("_", " ")

        # Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
        discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1

        # Convention: prefix includes prompt and string-representation of state, followed by ';'
        state_str = " ".join(map(str, discretized_state))
        prefix = f"Task: {cleaned_text}, State: {state_str};\n"
        prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)

        if actions is not None:
            raise NotImplementedError("BinningTokenizer does not support encoding actions atm (only for inference use)")
        postfix_tokens = []

        # Create output token sequence & masks
        # AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
        tokens = prefix_tokens + postfix_tokens
        token_mask = [True] * len(tokens)
        ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
        loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens)  # Loss on postfix only

        # Pad tokens to max length
        tokens_len = len(tokens)
        if tokens_len < self._max_len:
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        if len(action_tokens) < action_horizon * action_dim:
            return np.zeros([action_horizon, action_dim], dtype=np.float32)
        action_tokens = action_tokens[: (action_horizon * action_dim)].reshape([action_horizon, action_dim])
        return action_tokens / self._n_bins * 2 - 1

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens

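An aside on the binning convention above: `tokenize()` maps a normalized value in [-1, 1] to one of 256 bin indices via `np.digitize`, and `extract_actions()` inverts it with `bins / n_bins * 2 - 1`. The round trip can be checked in isolation (a sketch using only numpy, no repo code):

```python
import numpy as np

n_bins = 256
# Same edge construction as the np.linspace call in tokenize()
edges = np.linspace(-1, 1, n_bins + 1)[:-1]

state = np.array([-1.0, 0.0, 0.999])
bins = np.digitize(state, bins=edges) - 1  # forward: value -> bin index
assert bins.tolist() == [0, 128, 255]

recovered = bins / n_bins * 2 - 1          # inverse used in extract_actions
# Round trip is exact up to one bin width (2 / n_bins)
assert np.all(np.abs(recovered - state) <= 2 / n_bins)
```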
class FSQTokenizer:
|
||||
def __init__(self, max_len: int = 256, fsq_tokenizer_path: str | None = None):
|
||||
import jax
|
||||
import orbax.checkpoint as ocp
|
||||
|
||||
import openpi.models.fsq_tokenizer_v2 as fsq_tokenizer
|
||||
|
||||
self._max_len = max_len
|
||||
|
||||
assert fsq_tokenizer_path is not None, "fsq_tokenizer_path must be provided"
|
||||
# Download tokenizer
|
||||
path = download.maybe_download(fsq_tokenizer_path)
|
||||
tok_path = os.path.join(path, os.listdir(path)[0])
|
||||
|
||||
# Split step from path
|
||||
step = int(tok_path.split("/")[-1])
|
||||
base_path = tok_path.rsplit("/", 1)[0]
|
||||
|
||||
mgr = ocp.CheckpointManager(
|
||||
base_path,
|
||||
item_handlers={
|
||||
"params": ocp.StandardCheckpointHandler(),
|
||||
"opt_state": ocp.StandardCheckpointHandler(),
|
||||
"config": ocp.JsonCheckpointHandler(),
|
||||
},
|
||||
options=ocp.CheckpointManagerOptions(max_to_keep=1),
|
||||
)
|
||||
|
||||
try:
|
||||
restored = mgr.restore(
|
||||
step, args=ocp.args.Composite(config=ocp.args.JsonRestore(), params=ocp.args.StandardRestore())
|
||||
)
|
||||
config = restored["config"]
|
||||
self._params = restored["params"]
|
||||
self._fsq_tokenizer = fsq_tokenizer.FsqAttentionTokenizer(**config)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to load FSQ tokenizer checkpoint from {fsq_tokenizer_path}. Error: {e!s}"
|
||||
) from e
|
||||
|
||||
# Compile tokenize and detokenize functions
|
||||
self._tokenize_fn = jax.jit(
|
||||
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.tokenize)
|
||||
)
|
||||
self._detokenize_fn = jax.jit(
|
||||
lambda params, x: self._fsq_tokenizer.apply({"params": params}, x, method=self._fsq_tokenizer.detokenize)
|
||||
)
|
||||
|
||||
# Download base PaliGemma tokenizer
|
||||
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
|
||||
with path.open("rb") as f:
|
||||
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
|
||||
|
||||
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
|
||||
def tokenize(
|
||||
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
cleaned_text = prompt.lower().strip().replace("_", " ")
|
||||
|
||||
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
|
||||
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Convention: prefix includes prompt and string-representation of state, followed by ';'
|
||||
state_str = " ".join(map(str, discretized_state))
|
||||
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
|
||||
|
||||
if actions is not None:
|
||||
raise NotImplementedError("FSQTokenizer does not support encoding actions atm (only for inference use)")
|
||||
postfix_tokens = []
|
||||
|
||||
# Create output token sequence & masks
|
||||
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
|
||||
tokens = prefix_tokens + postfix_tokens
|
||||
token_mask = [True] * len(tokens)
|
||||
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
|
||||
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
|
||||
|
||||
# Pad tokens to max length
|
||||
tokens_len = len(tokens)
|
||||
if tokens_len < self._max_len:
|
||||
            padding = [False] * (self._max_len - tokens_len)
            tokens = tokens + padding
            token_mask = token_mask + padding
            ar_mask = ar_mask + padding
            loss_mask = loss_mask + padding
        else:
            if len(tokens) > self._max_len:
                logging.warning(
                    f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
                    "Consider increasing the `max_token_len` in your model config if this happens frequently."
                )
            tokens = tokens[: self._max_len]
            token_mask = token_mask[: self._max_len]
            ar_mask = ar_mask[: self._max_len]
            loss_mask = loss_mask[: self._max_len]

        return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)

    def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
        # Decode predicted output tokens
        decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())

        # Extract actions from FAST model outputs
        if "Action: " not in decoded_tokens:
            return np.zeros((action_horizon, action_dim), dtype=np.float32)

        # Extract actions from decoded tokens
        raw_action_tokens = np.array(
            self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
        )
        action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
        try:
            import jax

            # Move computation to CPU and compile on-demand
            device = jax.devices("cpu")[0]
            with jax.default_device(device):
                detok_act = self._detokenize_fn(self._params, action_tokens[None, ...])[0]
            return detok_act[: action_horizon * action_dim].reshape([action_horizon, action_dim])
        except Exception as e:
            logging.warning(f"Error decoding FSQ: {e}")
            return np.zeros((action_horizon, action_dim))

    def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
        if isinstance(tokens, list):
            tokens = np.array(tokens)
        return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
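The pad-or-truncate logic above (pad all four parallel lists to `max_len`, or truncate them and warn) can be sketched in isolation. This is a minimal stand-alone version, not the tokenizer's actual API; `pad_or_truncate` and its zero/False padding values are illustrative:

```python
import numpy as np

def pad_or_truncate(tokens: list[int], mask: list[bool], max_len: int) -> tuple[np.ndarray, np.ndarray]:
    """Pad token/mask lists up to max_len (with 0 / False), or truncate if too long."""
    if len(tokens) < max_len:
        pad = max_len - len(tokens)
        tokens = tokens + [0] * pad
        mask = mask + [False] * pad
    else:
        # Truncation silently drops the tail; the real code logs a warning here.
        tokens = tokens[:max_len]
        mask = mask[:max_len]
    return np.asarray(tokens), np.asarray(mask)
```

The mask lets downstream attention ignore the padded positions, which is why it is padded with `False` rather than `True`.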
@@ -28,45 +28,72 @@ def _parse_image(image) -> np.ndarray:

@dataclasses.dataclass(frozen=True)
class LiberoInputs(transforms.DataTransformFn):
    """
    This class is used to convert inputs into the format expected by the model. It is used for both training and inference.

    For your own dataset, you can copy this class and modify the keys based on the comments below to pipe
    the correct elements of your dataset into the model.
    """

    # The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
    # Do not change this for your own dataset.
    action_dim: int

    # Determines which model will be used.
    # Do not change this for your own dataset.
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        mask_padding = self.model_type == _model.ModelType.PI0  # We don't mask for pi0-FAST.
        # We only mask padding for pi0 model, not pi0-FAST. Do not change this for your own dataset.
        mask_padding = self.model_type == _model.ModelType.PI0

        # Get the state. We are padding from 8 to the model action dim.
        # For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped).
        # We pad the proprioceptive input to the action dimension of the model.
        # For pi0-FAST, we don't pad the state. For Libero, we don't need to differentiate
        # since the pi0-FAST action_dim = 7, which is < state_dim = 8, so pad is skipped.
        # Keep this for your own dataset, but if your dataset stores the proprioceptive input
        # in a different key than "observation/state", you should change it below.
        state = transforms.pad_to_dim(data["observation/state"], self.action_dim)

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
        # stores as float32 (C,H,W), gets skipped for policy inference
        # stores as float32 (C,H,W), gets skipped for policy inference.
        # Keep this for your own dataset, but if your dataset stores the images
        # in a different key than "observation/image" or "observation/wrist_image",
        # you should change it below.
        # Pi0 models support three image inputs at the moment: one third-person view,
        # and two wrist views (left and right). If your dataset does not have a particular type
        # of image, e.g. wrist images, you can comment it out here and replace it with zeros like we do for the
        # right wrist image below.
        base_image = _parse_image(data["observation/image"])
        wrist_image = _parse_image(data["observation/wrist_image"])

        # Create inputs dict. Do not change the keys in the dict below.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Pad any non-existent images with zero-arrays of the appropriate shape.
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # Mask any non-existent images with False (if ``mask_padding`` is True).
                "right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
            },
        }

        # Pad actions to the model action dimension. Keep this for your own dataset.
        # Actions are only available during training.
        if "actions" in data:
            # We are padding from 7 to the model action dim.
            # We are padding to the model action dim.
            # For pi0-FAST, this is a no-op (since action_dim = 7).
            actions = transforms.pad_to_dim(data["actions"], self.action_dim)
            inputs["actions"] = actions

        # Pass the prompt (aka language instruction) to the model.
        # Keep this for your own dataset (but modify the key if the instruction is not
        # stored in "prompt"; the output dict always needs to have the key "prompt").
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

@@ -75,6 +102,16 @@ class LiberoInputs(transforms.DataTransformFn):

@dataclasses.dataclass(frozen=True)
class LiberoOutputs(transforms.DataTransformFn):
    """
    This class is used to convert outputs from the model back to the dataset-specific format. It is
    used for inference only.

    For your own dataset, you can copy this class and modify the action dimension based on the comments below.
    """

    def __call__(self, data: dict) -> dict:
        # Only return the first 7 dims.
        # Only return the first N actions -- since we padded actions above to fit the model action
        # dimension, we need to now parse out the correct number of actions in the return dict.
        # For Libero, we only return the first 7 actions (since the rest is padding).
        # For your own dataset, replace `7` with the action dimension of your dataset.
        return {"actions": np.asarray(data["actions"][:, :7])}
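`LiberoInputs` pads the 8-dim state and 7-dim actions up to the model's action dimension, and `LiberoOutputs` slices the first 7 dims back out. A self-contained sketch of that round trip, where `pad_to_dim` is an illustrative stand-in for `transforms.pad_to_dim` (not openpi's actual implementation):

```python
import numpy as np

def pad_to_dim(x: np.ndarray, target_dim: int) -> np.ndarray:
    """Zero-pad the last axis of x up to target_dim; no-op if already >= target_dim."""
    current = x.shape[-1]
    if current >= target_dim:
        return x
    pad_width = [(0, 0)] * (x.ndim - 1) + [(0, target_dim - current)]
    return np.pad(x, pad_width)

# Pad a chunk of 7-dim Libero actions to a 32-dim model action space, then slice back.
actions = np.ones((10, 7), dtype=np.float32)   # (action_horizon, dataset action_dim)
padded = pad_to_dim(actions, 32)               # (10, 32), zeros in dims 7..31
recovered = padded[:, :7]                      # what LiberoOutputs returns
```

The `current >= target_dim` early return is what makes the pi0-FAST case (action_dim 7 or 8) a no-op, as the comments in the diff describe.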
@@ -37,7 +37,7 @@ def get_cache_dir() -> pathlib.Path:
    return cache_dir


def maybe_download(url: str, **kwargs) -> pathlib.Path:
def maybe_download(url: str, *, force_download: bool = False, **kwargs) -> pathlib.Path:
    """Download a file or directory from a remote filesystem to the local cache, and return the local path.

    If the local file already exists, it will be returned directly.
@@ -47,6 +47,7 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:

    Args:
        url: URL to the file to download.
        force_download: If True, the file will be downloaded even if it already exists in the cache.
        **kwargs: Additional arguments to pass to fsspec.

    Returns:
@@ -67,30 +68,56 @@ def maybe_download(url: str, **kwargs) -> pathlib.Path:
    local_path = cache_dir / parsed.netloc / parsed.path.strip("/")
    local_path = local_path.resolve()

    # Check if file already exists in cache.
    if local_path.exists() and not _invalidate_expired_cache(cache_dir, local_path):
        return local_path

    # Download file from remote file system.
    logger.info(f"Downloading {url} to {local_path}")
    with filelock.FileLock(local_path.with_suffix(".lock")):
        scratch_path = local_path.with_suffix(".partial")

        if _is_openpi_url(url):
            # Download without credentials.
            _download_boto3(
                url,
                scratch_path,
                botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
            )
        elif url.startswith("s3://"):
            # Download with default boto3 credentials.
            _download_boto3(url, scratch_path)
    # Check if the cache should be invalidated.
    invalidate_cache = False
    if local_path.exists():
        if force_download or _should_invalidate_cache(cache_dir, local_path):
            invalidate_cache = True
        else:
            _download_fsspec(url, scratch_path, **kwargs)
            return local_path

        shutil.move(scratch_path, local_path)
        _ensure_permissions(local_path)
    try:
        lock_path = local_path.with_suffix(".lock")
        with filelock.FileLock(lock_path):
            # Ensure consistent permissions for the lock file.
            _ensure_permissions(lock_path)
            # First, remove the existing cache if it is expired.
            if invalidate_cache:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()

            # Download the data to a local cache.
            logger.info(f"Downloading {url} to {local_path}")
            scratch_path = local_path.with_suffix(".partial")

            if _is_openpi_url(url) or _is_openpi_simeval_url(url):
                # Download without credentials.
                _download_boto3(
                    url,
                    scratch_path,
                    boto_session=boto3.Session(
                        region_name="us-west-1",
                    ),
                    botocore_config=botocore.config.Config(signature_version=botocore.UNSIGNED),
                )
            elif url.startswith("s3://"):
                # Download with default boto3 credentials.
                _download_boto3(url, scratch_path)
            else:
                _download_fsspec(url, scratch_path, **kwargs)

            shutil.move(scratch_path, local_path)
            _ensure_permissions(local_path)

    except PermissionError as e:
        msg = (
            f"Local file permission error was encountered while downloading {url}. "
            f"Please try again after removing the cached data using: `rm -rf {local_path}*`"
        )
        raise PermissionError(msg) from e

    return local_path

@@ -209,7 +236,8 @@ def _download_boto3(
def _get_s3_transfer_manager(
    session: boto3.Session, workers: int, botocore_config: botocore.config.Config | None = None
) -> s3_transfer.TransferManager:
    config = botocore.config.Config(max_pool_connections=workers)
    # Add a few extra connections to prevent exceeding the pool size.
    config = botocore.config.Config(max_pool_connections=workers + 2)
    if botocore_config is not None:
        config = config.merge(botocore_config)
    s3client = session.client("s3", config=config)
@@ -271,6 +299,9 @@ def _is_openpi_url(url: str) -> bool:
    """Check if the url is an OpenPI S3 bucket url."""
    return url.startswith("s3://openpi-assets/")


def _is_openpi_simeval_url(url: str) -> bool:
    """Check if the url is an OpenPI simeval S3 bucket url."""
    return url.startswith("s3://openpi-assets-simeval/")


def _get_mtime(year: int, month: int, day: int) -> float:
    """Get the mtime of a given date at midnight UTC."""
@@ -282,11 +313,13 @@ def _get_mtime(year: int, month: int, day: int) -> float:
# Partial matching will be used from top to bottom and the first match will be chosen.
# Cached entries will be retained only if they are newer than the expiration timestamp.
_INVALIDATE_CACHE_DIRS: dict[re.Pattern, float] = {
    re.compile("openpi-assets/checkpoints/pi0_aloha_pen_uncap"): _get_mtime(2025, 2, 17),
    re.compile("openpi-assets/checkpoints/pi0_libero"): _get_mtime(2025, 2, 6),
    re.compile("openpi-assets/checkpoints/"): _get_mtime(2025, 2, 3),
}


def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
def _should_invalidate_cache(cache_dir: pathlib.Path, local_path: pathlib.Path) -> bool:
    """Return True if the cached entry is expired and should be invalidated."""

    assert local_path.exists(), f"File not found at {local_path}"
@@ -295,13 +328,6 @@ def _invalidate_expired_cache(cache_dir: pathlib.Path, local_path: pathlib.Path)
    for pattern, expire_time in _INVALIDATE_CACHE_DIRS.items():
        if pattern.match(relative_path):
            # Remove if not newer than the expiration timestamp.
            if local_path.stat().st_mtime <= expire_time:
                logger.info(f"Removing expired cached entry: {local_path}")
                if local_path.is_dir():
                    shutil.rmtree(local_path)
                else:
                    local_path.unlink()
                return True
            return False
            return local_path.stat().st_mtime <= expire_time

    return False
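The refactored expiration check keys on file mtimes: a cached entry is stale when its mtime is not newer than the cutoff of the first matching pattern. A self-contained sketch of that predicate (the pattern table and cutoff values here are illustrative, not the real `_INVALIDATE_CACHE_DIRS`):

```python
import re

# Illustrative pattern -> expiration-timestamp table; matched top to bottom, first match wins.
_EXPIRY: dict[re.Pattern, float] = {
    re.compile("assets/checkpoints/special"): 2000.0,
    re.compile("assets/checkpoints/"): 1000.0,
}

def should_invalidate(relative_path: str, mtime: float) -> bool:
    """Return True if the cached entry at relative_path is not newer than its cutoff."""
    for pattern, expire_time in _EXPIRY.items():
        if pattern.match(relative_path):
            return mtime <= expire_time
    return False
```

Separating the "should invalidate?" decision from the actual deletion (now done under the file lock in `maybe_download`) is what lets `force_download` reuse the same removal path.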
@@ -48,8 +48,8 @@ def initialize_checkpoint_dir(
        ),
    )

    # special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. in this case, we don't actually want the train script to try and restore a
    # Special case: the checkpoint directory exists and the user requests to resume training, but the training run did
    # not get to the first checkpoint saved. In this case, we don't actually want the train script to try and restore a
    # checkpoint, since it will fail.
    if resuming and tuple(mngr.all_steps()) in [(), (0,)]:
        logging.info("Checkpoint directory exists, but does not contain any checkpoints. Aborting resume.")

@@ -102,6 +102,8 @@ class ModelTransformFactory(GroupFactory):

    # If provided, will determine the default prompt that will be used by the model.
    default_prompt: str | None = None
    fast_model_tokenizer: Any | None = None
    fast_model_tokenizer_kwargs: dict[str, Any] | None = None

    def __call__(self, model_config: _model.BaseModelConfig) -> _transforms.Group:
        match model_config.model_type:
@@ -116,17 +118,21 @@ class ModelTransformFactory(GroupFactory):
                ],
            )
            case _model.ModelType.PI0_FAST:
                tokenizer_cls = (
                    _tokenizer.FASTTokenizer if self.fast_model_tokenizer is None else self.fast_model_tokenizer
                )
                tokenizer_kwargs = {} if self.fast_model_tokenizer_kwargs is None else self.fast_model_tokenizer_kwargs
                return _transforms.Group(
                    inputs=[
                        _transforms.InjectDefaultPrompt(self.default_prompt),
                        _transforms.ResizeImages(224, 224),
                        _transforms.TokenizeFASTInputs(
                            _tokenizer.FASTTokenizer(model_config.max_token_len),
                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                        ),
                    ],
                    outputs=[
                        _transforms.ExtractFASTActions(
                            _tokenizer.FASTTokenizer(model_config.max_token_len),
                            tokenizer_cls(model_config.max_token_len, **tokenizer_kwargs),
                            action_horizon=model_config.action_horizon,
                            action_dim=model_config.action_dim,
                        )
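The new `fast_model_tokenizer` / `fast_model_tokenizer_kwargs` fields let a train config swap the tokenizer class without touching the factory itself. The selection logic reduces to the following sketch; both stub tokenizer classes and `make_tokenizer` are hypothetical stand-ins, not openpi's API:

```python
from typing import Any

class FASTTokenizer:
    """Stand-in for _tokenizer.FASTTokenizer (the default)."""
    def __init__(self, max_token_len: int):
        self.max_token_len = max_token_len

class BinningTokenizer:
    """Stand-in for an alternative tokenizer injected via the config."""
    def __init__(self, max_token_len: int, num_bins: int = 256):
        self.max_token_len = max_token_len
        self.num_bins = num_bins

def make_tokenizer(max_token_len: int, tokenizer_cls=None, tokenizer_kwargs=None):
    """Mirror of the factory logic: default to FASTTokenizer unless a class is injected."""
    cls = FASTTokenizer if tokenizer_cls is None else tokenizer_cls
    kwargs: dict[str, Any] = {} if tokenizer_kwargs is None else tokenizer_kwargs
    return cls(max_token_len, **kwargs)
```

Passing the class (plus a kwargs dict) rather than an instance lets the factory construct the tokenizer with the model's own `max_token_len`, which is only known at `__call__` time.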
@@ -251,9 +257,22 @@ class LeRobotAlohaDataConfig(DataConfigFactory):

@dataclasses.dataclass(frozen=True)
class LeRobotLiberoDataConfig(DataConfigFactory):
    """
    This config is used to configure transforms that are applied at various parts of the data pipeline.
    For your own dataset, you can copy this class and modify the transforms to match your dataset based on the
    comments below.
    """

    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Make inputs look like they come from the Libero environment
        # The repack transform is *only* applied to the data coming from the dataset,
        # and *not* during inference. We can use it to make inputs from the dataset look
        # as close as possible to those coming from the inference environment (e.g. match the keys).
        # Below, we match the keys in the dataset (which we defined in the data conversion script) to
        # the keys we use in our inference pipeline (defined in the inference script for libero).
        # For your own dataset, first figure out what keys your environment passes to the policy server
        # and then modify the mappings below so your dataset's keys get matched to those target keys.
        # The repack transform simply remaps key names here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
@@ -268,13 +287,29 @@ class LeRobotLiberoDataConfig(DataConfigFactory):
            ]
        )

        # Prepare data for policy training
        # Convert images to uint8 numpy arrays, add masks
        # The data transforms are applied to the data coming from the dataset *and* during inference.
        # Below, we define the transforms for data going into the model (``inputs``) and the transforms
        # for data coming out of the model (``outputs``) (the latter is only used during inference).
        # We defined these transforms in `libero_policy.py`. You can check the detailed comments there for
        # how to modify the transforms to match your dataset. Once you have created your own transforms, you can
        # replace the transforms below with your own.
        data_transforms = _transforms.Group(
            inputs=[libero_policy.LiberoInputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[libero_policy.LiberoOutputs()],
        )
        # Use delta actions (not for gripper)

        # One additional data transform: pi0 models are trained on delta actions (relative to the first
        # state in each action chunk). If your data has ``absolute`` actions (e.g. target joint angles)
        # you can uncomment the following line to convert the actions to delta actions. The only exception
        # is for the gripper actions which are always absolute.
        # In the example below, we would apply the delta conversion to the first 6 actions (joints) and
        # leave the 7th action (gripper) unchanged, i.e. absolute.
        # In Libero, the raw actions in the dataset are already delta actions, so we *do not* need to
        # apply a separate delta conversion (that's why it's commented out). Choose whether to apply this
        # transform based on whether your dataset uses ``absolute`` or ``delta`` actions out of the box.

        # TODO(karl): comment this out once we have updated the Libero checkpoints to not use
        # the delta action transform
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
@@ -282,8 +317,10 @@ class LeRobotLiberoDataConfig(DataConfigFactory):
        )

        # Model transforms include things like tokenizing the prompt and action targets
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
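The `make_bool_mask(6, -1)` / `DeltaActions` pattern above converts the first six action dims (joints) to deltas relative to the current state and leaves the seventh (gripper) absolute. A numpy sketch of the mask construction and the delta/absolute round trip; the function names are illustrative, not openpi's actual `_transforms` implementations:

```python
import numpy as np

def make_bool_mask(*sizes: int) -> np.ndarray:
    """make_bool_mask(6, -1) -> six True entries followed by one False."""
    return np.concatenate([np.full(abs(n), n > 0) for n in sizes])

def to_delta(actions: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Subtract the current state from masked dims of every action in the chunk."""
    out = actions.copy()
    out[..., mask] -= state[mask]
    return out

def to_absolute(actions: np.ndarray, state: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Inverse of to_delta: add the current state back to the masked dims."""
    out = actions.copy()
    out[..., mask] += state[mask]
    return out
```

Applying `to_delta` on inputs during training and `to_absolute` on outputs at inference keeps the environment interface in absolute actions while the model only ever sees deltas.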
@@ -439,24 +476,178 @@ _CONFIGS = [
            ),
        ),
    ),
    TrainConfig(
        name="pi0_droid_jointpos",
        model=pi0.Pi0Config(action_horizon=10),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                outputs=[_transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)), droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        ),
    ),
    TrainConfig(
        name="pi0_fast_droid_jointpos",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=10),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[
                    _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
                    droid_policy.DroidOutputs(),
                ],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_binning_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_binning_droid_jointpos",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15, max_token_len=400),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[
                    _transforms.AbsoluteActions(_transforms.make_bool_mask(7, -1)),
                    droid_policy.DroidOutputs(),
                ],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.BinningTokenizer,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_fast_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_fast_specialist_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.FASTTokenizer,
                fast_model_tokenizer_kwargs={"fast_tokenizer_path": "KarlP/fast_droid_specialist"},
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_vq_droid",
        model=pi0_fast.Pi0FASTConfig(action_dim=8, action_horizon=15),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim, model_type=ModelType.PI0_FAST)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
            model_transforms=ModelTransformFactory(
                fast_model_tokenizer=_tokenizer.FSQTokenizer,
                fast_model_tokenizer_kwargs={
                    "fsq_tokenizer_path": "s3://openpi-assets-simeval/tokenizers/droid_fsq_tokenizer"
                },
            ),
        ),
    ),
    TrainConfig(
        name="paligemma_diffusion_droid",
        model=pi0.Pi0Config(action_horizon=10, action_dim=8),
        data=SimpleDataConfig(
            assets=AssetsConfig(asset_id="droid"),
            data_transforms=lambda model: _transforms.Group(
                inputs=[droid_policy.DroidInputs(action_dim=model.action_dim)],
                outputs=[droid_policy.DroidOutputs()],
            ),
            base_config=DataConfig(
                prompt_from_task=True,
            ),
        ),
    ),
    #
    # Fine-tuning Libero configs.
    #
    # These train configs define the hyperparameters for fine-tuning the base model on your own dataset.
    # They are used to define key elements like the dataset you are training on, the base checkpoint you
    # are using, and other hyperparameters like how many training steps to run or what learning rate to use.
    # For your own dataset, you can copy this config and modify the dataset name and data transforms based on
    # the comments below.
    TrainConfig(
        # Change the name to reflect your model and dataset.
        name="pi0_libero",
        # Here you define the model config -- In this example we use pi0 as the model
        # architecture and perform *full* finetuning. In the examples below we show how to modify
        # this to perform *low-memory* (LoRA) finetuning and use pi0-FAST as an alternative architecture.
        model=pi0.Pi0Config(),
        # Here you define the dataset you are training on. In this example we use the Libero
        # dataset. For your own dataset, you can change the repo_id to point to your dataset.
        # Also modify the DataConfig to use the new config you made for your dataset above.
        data=LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
            base_config=DataConfig(
                local_files_only=False,  # Set to True for local-only datasets.
                # This flag determines whether we load the prompt (i.e. the task instruction) from the
                # ``task`` field in the LeRobot dataset. If set to True, the prompt will show up in
                # a field called ``prompt`` in the input dict. The recommended setting is True.
                prompt_from_task=True,
            ),
        ),
        # Here you define which pre-trained checkpoint you want to load to initialize the model.
        # This should match the model config you chose above -- i.e. in this case we use the pi0 base model.
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
        # Below you can define other hyperparameters like the learning rate, number of training steps, etc.
        # Check the base TrainConfig class for a full list of available hyperparameters.
        num_train_steps=30_000,
    ),
    TrainConfig(
        name="pi0_libero_low_mem_finetune",
        # Here is an example of loading a pi0 model for LoRA fine-tuning.
        model=pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"),
        data=LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
@@ -467,13 +658,28 @@ _CONFIGS = [
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_base/params"),
        num_train_steps=30_000,
        # The freeze filter defines which parameters should be frozen during training.
        # We have a convenience function in the model config that returns the default freeze filter
        # for the given model config for LoRA finetuning. Just make sure it matches the model config
        # you chose above.
        freeze_filter=pi0.Pi0Config(
            paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora"
        ).get_freeze_filter(),
        # Turn off EMA for LoRA finetuning.
        ema_decay=None,
    ),
    TrainConfig(
        name="pi0_fast_libero",
        # Here is an example of loading a pi0-FAST model for full finetuning.
        # Modify action_dim and action_horizon to match your dataset (action horizon is equal to
        # the desired action chunk length).
        # The max_token_len is the maximum number of (non-image) tokens the model can handle.
        # This includes the tokenized prompt, proprioceptive state, and (FAST-tokenized) action tokens.
        # Choosing this value too small may chop off tokens at the end of your sequence (the code will throw
        # a warning), while choosing it too large will waste memory (since we pad each batch element to the
        # max_token_len). A good rule of thumb is to use approx 180 for single-arm robots, and approx 250 for
        # two-arm robots. Generally, err on the lower side here first, and potentially increase the value if
        # you see many warnings being thrown during training.
        model=pi0_fast.Pi0FASTConfig(action_dim=7, action_horizon=10, max_token_len=180),
        data=LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
@@ -482,8 +688,33 @@ _CONFIGS = [
                prompt_from_task=True,
            ),
        ),
        # Note that we load the pi0-FAST base model checkpoint here.
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
        num_train_steps=30_000,
    ),
    TrainConfig(
        name="pi0_fast_libero_low_mem_finetune",
        # Here is an example of loading a pi0-FAST model for LoRA finetuning.
        # For setting action_dim, action_horizon, and max_token_len, see the comments above.
        model=pi0_fast.Pi0FASTConfig(
            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
        ),
        data=LeRobotLiberoDataConfig(
            repo_id="physical-intelligence/libero",
            base_config=DataConfig(
                local_files_only=False,  # Set to True for local-only datasets.
                prompt_from_task=True,
            ),
        ),
        weight_loader=weight_loaders.CheckpointWeightLoader("s3://openpi-assets/checkpoints/pi0_fast_base/params"),
        num_train_steps=30_000,
        # Again, make sure to match the model config above when extracting the freeze filter
        # that specifies which parameters should be frozen during LoRA finetuning.
        freeze_filter=pi0_fast.Pi0FASTConfig(
            action_dim=7, action_horizon=10, max_token_len=180, paligemma_variant="gemma_2b_lora"
        ).get_freeze_filter(),
        # Turn off EMA for LoRA finetuning.
        ema_decay=None,
    ),
    #
    # Fine-tuning Aloha configs.
@@ -214,7 +214,12 @@ class TorchDataLoader:
            raise ValueError(f"Local batch size ({local_batch_size}) is larger than the dataset size ({len(dataset)}).")

        if sharding is None:
            sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
            # Use data parallel sharding by default.
            sharding = jax.sharding.NamedSharding(
                jax.sharding.Mesh(jax.devices(), ("B",)),
                jax.sharding.PartitionSpec("B"),
            )

        self._sharding = sharding
        self._num_batches = num_batches
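The sharding change above replaces single-device placement with data-parallel sharding: each batch is partitioned across all devices along the batch axis ("B"). The effect on array layout can be sketched without JAX as splitting a global batch into equal per-device shards (`shard_batch` is an illustrative helper, not part of openpi):

```python
import numpy as np

def shard_batch(batch: np.ndarray, num_devices: int) -> np.ndarray:
    """Split a global batch along axis 0 into one equal shard per device."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError("Global batch size must be divisible by the device count.")
    per_device = batch.shape[0] // num_devices
    # Leading axis becomes the device axis, mirroring a PartitionSpec("B") layout.
    return batch.reshape(num_devices, per_device, *batch.shape[1:])
```

This is also why the data loader's global batch size must be divisible by the number of devices when using the data-parallel default.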