Initial commit

Ury Zhilinsky
2024-12-23 13:38:06 -08:00
commit 385780ecc3
121 changed files with 15572 additions and 0 deletions

3
.dockerignore Normal file

@@ -0,0 +1,3 @@
.venv
checkpoints
data

16
.github/CODEOWNERS vendored Normal file

@@ -0,0 +1,16 @@
# The CODEOWNERS file defines individuals or teams that are automatically requested for
# review when someone opens a pull request that modifies certain code. When a draft pull
# request is marked as ready for review, code owners are automatically notified.
#
# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
#
# This is a comment.
# Each line is a file pattern followed by one or more owners.
# Global owners.
* @jimmyt857 @Michael-Equi @uzhilinsky
src/openpi/models/ @kvablack @uzhilinsky
src/openpi/training/ @kvablack @uzhilinsky
scripts/ @jimmyt857 @kvablack @uzhilinsky

17
.github/workflows/pre-commit.yml vendored Normal file

@@ -0,0 +1,17 @@
name: pre-commit
on:
push:
branches:
- main
pull_request:
branches:
- "*"
jobs:
pre-commit:
runs-on: ubuntu-latest
env:
GIT_LFS_SKIP_SMUDGE: true
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v3
- uses: pre-commit/action@v3.0.1

26
.github/workflows/test.yml vendored Normal file

@@ -0,0 +1,26 @@
name: Test
on:
pull_request:
branches:
- "*"
jobs:
run_tests:
name: Run Tests
runs-on: ubuntu-latest
env:
GIT_LFS_SKIP_SMUDGE: true
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
- name: Set up Python
run: uv python install
- name: Install the project
run: uv sync --all-extras --dev
- name: Run tests
run: uv run pytest src scripts

168
.gitignore vendored Normal file

@@ -0,0 +1,168 @@
# Data directories.
assets/
checkpoints/
data/
wandb/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

9
.gitmodules vendored Normal file

@@ -0,0 +1,9 @@
[submodule "third_party/aloha"]
path = third_party/aloha
url = git@github.com:Physical-Intelligence/aloha.git
[submodule "third_party/calvin"]
path = third_party/calvin
url = git@github.com:mees/calvin.git
[submodule "third_party/libero"]
path = third_party/libero
url = git@github.com:Lifelong-Robot-Learning/LIBERO.git

16
.pre-commit-config.yaml Normal file

@@ -0,0 +1,16 @@
exclude: third_party/
repos:
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.9
hooks:
- id: uv-lock
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.1
hooks:
# Run the linter.
- id: ruff
args: [--fix]
- id: ruff-format

1
.python-version Normal file

@@ -0,0 +1 @@
3.11

11
.vscode/settings.json vendored Normal file

@@ -0,0 +1,11 @@
{
"[python]": {
"editor.defaultFormatter": "charliermarsh.ruff",
"editor.formatOnSave": true,
},
"python.testing.pytestArgs": [
"src"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

201
LICENSE Normal file

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

116
README.md Normal file

@@ -0,0 +1,116 @@
# openpi
openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).
Currently, it is focused on the `pi0` model described in [this blog post](https://www.physicalintelligence.company/blog/pi0).
## Setup
When cloning this repo, make sure to update submodules:
```bash
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git
# Or if you already cloned the repo:
git submodule update --init --recursive
```
### Using uv
We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up.
Once uv is installed, run the following to set up the environment:
```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
```
NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.
### Docker Setup
All of the examples in this repo include instructions for running them both directly and with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for the examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/install_docker_ubuntu22.sh` and `scripts/install_nvidia_container_toolkit.sh`.
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
### Downloading checkpoints
By default, checkpoints are downloaded from `s3://openpi-assets` and cached in `~/.cache/openpi` when needed. You can override this cache location by setting the `OPENPI_DATA_HOME` environment variable.
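A minimal sketch of how that cache location can be resolved in your own tooling (the helper name is hypothetical; this is not the openpi implementation):
```python
import os
import pathlib


def openpi_cache_root() -> pathlib.Path:
    """Resolve the local cache directory for downloaded checkpoints/assets.

    Hypothetical sketch: honor OPENPI_DATA_HOME if set, otherwise fall back
    to ~/.cache/openpi as described above.
    """
    override = os.environ.get("OPENPI_DATA_HOME")
    if override:
        return pathlib.Path(override).expanduser()
    return pathlib.Path.home() / ".cache" / "openpi"
```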
## Running Training
Training configs are defined in [src/openpi/training/config.py](src/openpi/training/config.py) and the training script is in [scripts/train.py](scripts/train.py).
Each registered config is available as a command line argument to `scripts/train.py`. To find all available command line arguments for your config, run `uv run scripts/train.py <config-name> --help`, or look at the `TrainConfig` class in [src/openpi/training/config.py](src/openpi/training/config.py).
For example, to train with the `pi0_aloha_sim` config, run the following:
First, compute the norm stats for the training data (this only needs to be done once):
```bash
uv run scripts/compute_norm_stats.py --config-name pi0_aloha_sim
```
Then run training:
```bash
uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
```
The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.
The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.
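If you launch JAX from your own Python entry point rather than the provided scripts, the memory fraction variable must be set before JAX initializes its GPU backend; a minimal sketch:
```python
import os

# Must be set before JAX initializes the GPU backend; otherwise the default
# preallocation fraction (0.75) is used.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9")

import jax  # noqa: E402

print(jax.devices())
```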
## Running examples
We provide example integrations with several robotics platforms. See the README in each example for more details:
- [ALOHA Sim](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)
- [CALVIN](examples/calvin)
- [LIBERO](examples/libero)
## Running the openpi server
The server can be configured to serve openpi policies in the following ways:
- Serve a default policy for the given environment.
- Serve a trained policy from a checkpoint.
- Serve an exported model.
### Serve the default policy for the LIBERO environment
```bash
uv run scripts/serve_policy.py --env LIBERO --default_prompt "my task"
```
### Serve a trained policy from an openpi checkpoint
This option allows serving a model that was trained using the openpi training code.
```bash
uv run scripts/serve_policy.py --env ALOHA_SIM policy:checkpoint --policy.config=pi0_aloha_sim --policy.dir=checkpoints/pi0_aloha_sim/exp_name/10000
```
The training config is used to determine which data transformations should be applied to the runtime data before feeding into the model. The norm stats, which are used to normalize the transformed data, are loaded from the checkpoint directory.
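For intuition, normalization with stored statistics typically looks like the sketch below (the `mean`/`std` names and the exact transform are assumptions, not the openpi implementation):
```python
import numpy as np


def normalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardize runtime data using statistics loaded from the checkpoint directory."""
    return (x - mean) / (std + eps)


def unnormalize(x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
    """Map model outputs (e.g. actions) back to the original scale."""
    return x * std + mean
```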
### Serve an exported model
There are also a number of checkpoints that are available as exported JAX graphs, which we trained ourselves using our internal training code. These can be served using the following command:
```bash
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model --policy.processor=trossen_biarm_single_base_cam_24dim
```
In this case, the data transformations are taken from the default policy, and the processor name is used to determine which norm stats should be used to normalize the transformed data.
### Running with Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM --default_prompt 'my task'"
docker compose -f scripts/compose.yml up --build
```

70
examples/aloha_real/Dockerfile Normal file

@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cmake \
curl \
libffi-dev \
python3-rosdep \
python3-rosinstall \
python3-rosinstall-generator \
whiptail \
git \
wget \
openssh-client \
ros-noetic-cv-bridge \
ros-noetic-usb-cam \
ros-noetic-realsense2-camera \
keyboard-configuration
WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
cd /python && \
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
tar -zxvf Python-3.10.14.tgz && \
cd Python-3.10.14 && \
ls -lhR && \
./configure --enable-optimizations && \
make install && \
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
cd ~ && rm -rf /python && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app
# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]

73
examples/aloha_real/README.md Normal file

@@ -0,0 +1,73 @@
# Run Aloha (Real Robot)
This example demonstrates how to run openpi policies on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha).
## Prerequisites
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
## With Docker
```bash
export SERVER_ARGS="--env ALOHA --default_prompt='toast out of toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
python examples/aloha_real/main.py
```
Terminal window 2:
```bash
roslaunch --wait aloha ros_nodes.launch
```
Terminal window 3:
```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='toast out of toaster'
```
## Model Guide
The Pi0 Base Model is an out-of-the-box model for general tasks. You can find more details in the [technical report](https://www.physicalintelligence.company/download/pi0.pdf).
While we strongly recommend fine-tuning the model to your own data to adapt it to particular tasks, it may be possible to prompt the model to attempt some tasks that were in the pre-training data. For example, below is a video of the model attempting the "toast out of toaster" task.
<p align="center">
<img src="https://github.com/Physical-Intelligence/openpi/blob/main/examples/aloha_real/toast.gif" alt="toast out of toaster"/>
</p>
## Training on your own Aloha dataset
openpi supports training on data collected in the default ALOHA HDF5 format. To do so, you must first convert the data to the Hugging Face dataset format; we include `scripts/aloha_hd5.py` to help with this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/config.py` and set the `repo_id` to the ID assigned to your dataset during conversion, as in the example below (a short sketch of what the `delta_action_mask` does follows the example).
```python
TrainConfig(
name=<your-config-name>,
data=LeRobotAlohaDataConfig(
repo_id=<your-repo-id>,
delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
),
),
```
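The `delta_action_mask` marks which action dimensions are treated as deltas. As a rough illustration of the pattern `[True] * 6 + [False] + [True] * 6 + [False]` (my reading, not the repository's implementation): True entries, the arm joints, would be expressed relative to the current joint positions, while False entries, the two grippers, stay absolute.
```python
import numpy as np

# Hypothetical illustration for a bimanual 14-dim action:
# True  -> dimension handled as a delta relative to the current joint position
# False -> dimension kept absolute (here, the two gripper dimensions)
delta_action_mask = np.array([True] * 6 + [False] + [True] * 6 + [False])


def to_deltas(actions: np.ndarray, current_qpos: np.ndarray) -> np.ndarray:
    """Convert absolute targets to deltas on the masked dimensions (sketch only)."""
    out = actions.copy()
    out[..., delta_action_mask] -= current_qpos[delta_action_mask]
    return out
```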
Run the training script:
```bash
uv run scripts/train.py <your-config-name>
```

63
examples/aloha_real/compose.yml Normal file

@@ -0,0 +1,63 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
runtime:
image: aloha_real
depends_on:
- aloha_ros_nodes
- ros_master
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
aloha_ros_nodes:
image: aloha_real
depends_on:
- ros_master
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- /dev:/dev
command: roslaunch --wait aloha ros_nodes.launch
ros_master:
image: ros:noetic-robot
network_mode: host
privileged: true
command:
- roscore
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

71
examples/aloha_real/constants.py Normal file

@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
### Task parameters
### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = (
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
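A quick usage sketch of the helpers above (arbitrary input values, shown only to illustrate how the lambdas compose; not part of the file):
```python
from examples.aloha_real import constants

# Map a master gripper joint reading into the equivalent puppet gripper joint command.
master_joint = 0.0  # somewhere between MASTER_GRIPPER_JOINT_CLOSE and MASTER_GRIPPER_JOINT_OPEN
puppet_joint = constants.MASTER2PUPPET_JOINT_FN(master_joint)

# Normalize/unnormalize round-trip (up to floating point error).
x = 0.25
x_roundtrip = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(constants.PUPPET_GRIPPER_JOINT_NORMALIZE_FN(x))
assert abs(x_roundtrip - x) < 1e-9
```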

52
examples/aloha_real/env.py Normal file

@@ -0,0 +1,52 @@
import einops
import numpy as np
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from examples.aloha_real import real_env as _real_env
class AlohaRealEnvironment(_environment.Environment):
"""An environment for an Aloha robot on real hardware."""
def __init__(self, render_height: int = 480, render_width: int = 640) -> None:
self._env = _real_env.make_real_env(init_node=True)
self._render_height = render_height
self._render_width = render_width
self._ts = None
@override
def reset(self) -> None:
self._ts = self._env.reset()
@override
def done(self) -> bool:
return False
@override
def get_observation(self) -> dict:
if self._ts is None:
raise RuntimeError("Timestep is not set. Call reset() first.")
obs = self._ts.observation
for k in list(obs["images"].keys()):
if "_depth" in k:
del obs["images"][k]
images = []
for cam_name in obs["images"]:
curr_image = obs["images"][cam_name]
curr_image = einops.rearrange(curr_image, "h w c -> c h w")
images.append(curr_image)
stacked_images = np.stack(images, axis=0).astype(np.uint8)
# TODO: Consider removing these transformations.
return {
"qpos": obs["qpos"],
"image": stacked_images,
}
@override
def apply_action(self, action: dict) -> None:
self._ts = self._env.step(action["qpos"])

42
examples/aloha_real/main.py Normal file

@@ -0,0 +1,42 @@
import dataclasses
import logging
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro
from examples.aloha_real import env as _env
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
action_horizon: int = 25
def main(args: Args) -> None:
runtime = _runtime.Runtime(
environment=_env.AlohaRealEnvironment(),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
subscribers=[],
max_hz=50,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)
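The `ActionChunkBroker` wraps the websocket policy so the robot does not wait on a model round trip at every control step: it queries the policy once and replays the returned chunk of actions for `action_horizon` steps before querying again. A minimal sketch of that idea follows (this is not the `openpi_client` implementation; the `infer` signature and the `{"actions": ...}` / `{"qpos": ...}` keys are assumptions):
```python
import numpy as np


class NaiveChunkBroker:
    """Replays a chunk of actions between policy queries (illustrative only)."""

    def __init__(self, policy, action_horizon: int) -> None:
        self._policy = policy
        self._action_horizon = action_horizon
        self._chunk: np.ndarray | None = None
        self._step = 0

    def infer(self, observation: dict) -> dict:
        if self._chunk is None or self._step >= self._action_horizon:
            # Assumed to return {"actions": (horizon, action_dim)}.
            self._chunk = np.asarray(self._policy.infer(observation)["actions"])
            self._step = 0
        action = self._chunk[self._step]
        self._step += 1
        return {"qpos": action}
```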

167
examples/aloha_real/real_env.py Normal file

@@ -0,0 +1,167 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
from examples.aloha_real import constants
from examples.aloha_real import robot_utils
class RealEnv:
"""
Environment for real robot bi-manual manipulation
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
"""
def __init__(self, init_node, *, setup_robots: bool = True):
self.puppet_bot_left = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_left",
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()
self.recorder_left = robot_utils.Recorder("left", init_node=False)
self.recorder_right = robot_utils.Recorder("right", init_node=False)
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
self.gripper_command = JointSingleCommand(name="gripper")
def setup_robots(self):
robot_utils.setup_puppet_bot(self.puppet_bot_left)
robot_utils.setup_puppet_bot(self.puppet_bot_right)
def get_qpos(self):
left_qpos_raw = self.recorder_left.qpos
right_qpos_raw = self.recorder_right.qpos
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
] # this is position not joint
right_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
] # this is position not joint
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
def get_qvel(self):
left_qvel_raw = self.recorder_left.qvel
right_qvel_raw = self.recorder_right.qvel
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
def get_effort(self):
left_effort_raw = self.recorder_left.effort
right_effort_raw = self.recorder_right.effort
left_robot_effort = left_effort_raw[:7]
right_robot_effort = right_effort_raw[:7]
return np.concatenate([left_robot_effort, right_robot_effort])
def get_images(self):
return self.image_recorder.get_images()
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
self.gripper_command.cmd = left_gripper_desired_joint
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
right_gripper_desired_pos_normalized
)
self.gripper_command.cmd = right_gripper_desired_joint
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
def _reset_joints(self):
# reset_position = START_ARM_POSE[:6]
reset_position = [0, -1.5, 1.5, 0, 0, 0]
robot_utils.move_arms(
[self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1
)
def _reset_gripper(self):
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
)
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
)
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def get_reward(self):
return 0
def reset(self, *, fake=False):
if not fake:
# Reboot puppet robot gripper motors
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
self._reset_joints()
self._reset_gripper()
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def step(self, action):
state_len = int(len(action) / 2)
left_action = action[:state_len]
right_action = action[state_len:]
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
self.set_gripper_pose(left_action[-1], right_action[-1])
time.sleep(constants.DT)
return dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
# Arm actions
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
# Gripper actions
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
return action
def make_real_env(init_node, *, setup_robots: bool = True) -> RealEnv:
return RealEnv(init_node, setup_robots=setup_robots)
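For reference, the 14-dimensional action vector used by `RealEnv.step()` and `get_action()` above splits as follows; a tiny helper (not part of the repo) makes the layout explicit:
```python
import numpy as np


def split_bimanual_action(action: np.ndarray) -> dict:
    """Name the slices of the 14-dim bimanual action (indices follow get_action())."""
    assert action.shape == (14,)
    return {
        "left_arm_qpos": action[0:6],
        "left_gripper": action[6],      # normalized: 0 = close, 1 = open
        "right_arm_qpos": action[7:13],
        "right_gripper": action[13],
    }
```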

18
examples/aloha_real/requirements.in Normal file

@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets

156
examples/aloha_real/requirements.txt Normal file

@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
catkin-pkg==1.0.0
# via rospkg
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
contourpy==1.1.1
# via matplotlib
cycler==0.12.1
# via matplotlib
distro==1.9.0
# via rospkg
dm-control==1.0.23
# via -r examples/aloha_real/requirements.in
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
docutils==0.20.1
# via catkin-pkg
einops==0.8.0
# via -r examples/aloha_real/requirements.in
etils==1.3.0
# via mujoco
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
h5py==3.11.0
# via -r examples/aloha_real/requirements.in
idna==3.10
# via requests
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.7.5
# via -r examples/aloha_real/requirements.in
mdurl==0.1.2
# via markdown-it-py
modern-robotics==1.1.1
# via -r examples/aloha_real/requirements.in
msgpack==1.1.0
# via -r examples/aloha_real/requirements.in
mujoco==3.2.3
# via dm-control
numpy==1.24.4
# via
# -r examples/aloha_real/requirements.in
# contourpy
# dm-control
# dm-env
# h5py
# labmaze
# matplotlib
# modern-robotics
# mujoco
# opencv-python
# pyquaternion
# scipy
opencv-python==4.10.0.84
# via -r examples/aloha_real/requirements.in
packaging==24.2
# via
# -r examples/aloha_real/requirements.in
# matplotlib
pexpect==4.9.0
# via -r examples/aloha_real/requirements.in
pillow==10.4.0
# via
# -r examples/aloha_real/requirements.in
# matplotlib
protobuf==5.29.1
# via dm-control
ptyprocess==0.7.0
# via pexpect
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.1.4
# via
# catkin-pkg
# dm-control
# matplotlib
pyquaternion==0.9.9
# via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
# via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
# via
# catkin-pkg
# matplotlib
pyyaml==6.0.2
# via
# -r examples/aloha_real/requirements.in
# rospkg
requests==2.32.3
# via
# -r examples/aloha_real/requirements.in
# dm-control
rich==13.9.4
# via tyro
rospkg==1.5.1
# via -r examples/aloha_real/requirements.in
scipy==1.10.1
# via dm-control
setuptools==75.3.0
# via
# catkin-pkg
# dm-control
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_real/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_real/requirements.in
zipp==3.20.2
# via etils

275
examples/aloha_real/robot_utils.py Normal file

@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time
from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState
from examples.aloha_real import constants
class ImageRecorder:
def __init__(self, init_node=True, is_debug=False):
self.is_debug = is_debug
self.bridge = CvBridge()
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
if init_node:
rospy.init_node("image_recorder", anonymous=True)
for cam_name in self.camera_names:
setattr(self, f"{cam_name}_rgb_image", None)
setattr(self, f"{cam_name}_depth_image", None)
setattr(self, f"{cam_name}_timestamp", 0.0)
if cam_name == "cam_high":
callback_func = self.image_cb_cam_high
elif cam_name == "cam_low":
callback_func = self.image_cb_cam_low
elif cam_name == "cam_left_wrist":
callback_func = self.image_cb_cam_left_wrist
elif cam_name == "cam_right_wrist":
callback_func = self.image_cb_cam_right_wrist
else:
raise NotImplementedError
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
if self.is_debug:
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
time.sleep(0.5)
def image_cb(self, cam_name, data):
setattr(
self,
f"{cam_name}_rgb_image",
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
)
# setattr(
# self,
# f"{cam_name}_depth_image",
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
# )
setattr(
self,
f"{cam_name}_timestamp",
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
)
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
if self.is_debug:
getattr(self, f"{cam_name}_timestamps").append(
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
)
def image_cb_cam_high(self, data):
cam_name = "cam_high"
return self.image_cb(cam_name, data)
def image_cb_cam_low(self, data):
cam_name = "cam_low"
return self.image_cb(cam_name, data)
def image_cb_cam_left_wrist(self, data):
cam_name = "cam_left_wrist"
return self.image_cb(cam_name, data)
def image_cb_cam_right_wrist(self, data):
cam_name = "cam_right_wrist"
return self.image_cb(cam_name, data)
def get_images(self):
image_dict = {}
for cam_name in self.camera_names:
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
time.sleep(0.00001)
rgb_image = getattr(self, f"{cam_name}_rgb_image")
depth_image = getattr(self, f"{cam_name}_depth_image")
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
image_dict[cam_name] = rgb_image
image_dict[f"{cam_name}_depth"] = depth_image
return image_dict
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
for cam_name in self.camera_names:
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
print(f"{cam_name} {image_freq=:.2f}")
print()
class Recorder:
def __init__(self, side, init_node=True, is_debug=False):
self.secs = None
self.nsecs = None
self.qpos = None
self.effort = None
self.arm_command = None
self.gripper_command = None
self.is_debug = is_debug
if init_node:
rospy.init_node("recorder", anonymous=True)
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_group",
JointGroupCommand,
self.puppet_arm_commands_cb,
)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_single",
JointSingleCommand,
self.puppet_gripper_commands_cb,
)
if self.is_debug:
self.joint_timestamps = deque(maxlen=50)
self.arm_command_timestamps = deque(maxlen=50)
self.gripper_command_timestamps = deque(maxlen=50)
time.sleep(0.1)
def puppet_state_cb(self, data):
self.qpos = data.position
self.qvel = data.velocity
self.effort = data.effort
self.data = data
if self.is_debug:
self.joint_timestamps.append(time.time())
def puppet_arm_commands_cb(self, data):
self.arm_command = data.cmd
if self.is_debug:
self.arm_command_timestamps.append(time.time())
def puppet_gripper_commands_cb(self, data):
self.gripper_command = data.cmd
if self.is_debug:
self.gripper_command_timestamps.append(time.time())
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
joint_freq = 1 / dt_helper(self.joint_timestamps)
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
def get_arm_joint_positions(bot):
return bot.arm.core.joint_states.position[:6]
def get_arm_gripper_positions(bot):
return bot.gripper.core.joint_states.position[6]
def move_arms(bot_list, target_pose_list, move_time=1):
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
for t in range(num_steps):
for bot_id, bot in enumerate(bot_list):
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
time.sleep(constants.DT)
def move_grippers(bot_list, target_pose_list, move_time):
print(f"Moving grippers to {target_pose_list=}")
gripper_command = JointSingleCommand(name="gripper")
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
for t in range(num_steps):
d = {}
for bot_id, bot in enumerate(bot_list):
gripper_command.cmd = traj_list[bot_id][t]
bot.gripper.core.pub_single.publish(gripper_command)
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
f.write(json.dumps(d) + "\n")
time.sleep(constants.DT)
def setup_puppet_bot(bot):
bot.dxl.robot_reboot_motors("single", "gripper", True)
bot.dxl.robot_set_operating_modes("group", "arm", "position")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_on(bot)
def setup_master_bot(bot):
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_off(bot)
def set_standard_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def set_low_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def torque_off(bot):
bot.dxl.robot_torque_enable("group", "arm", False)
bot.dxl.robot_torque_enable("single", "gripper", False)
def torque_on(bot):
bot.dxl.robot_torque_enable("group", "arm", True)
bot.dxl.robot_torque_enable("single", "gripper", True)
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
print("\nSyncing!")
# activate master arms
torque_on(master_bot_left)
torque_on(master_bot_right)
# get puppet arm positions
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
# get puppet gripper positions
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
# move master arms to puppet positions
move_arms(
[master_bot_left, master_bot_right],
[puppet_left_qpos, puppet_right_qpos],
move_time=1,
)
# move master grippers to puppet positions
move_grippers(
[master_bot_left, master_bot_right],
[puppet_left_gripper, puppet_right_gripper],
move_time=1,
)

examples/aloha_real/toast.gif (binary file, 23 MiB; not shown)

36
examples/aloha_real/video_display.py Normal file

@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoDisplay(_subscriber.Subscriber):
"""Displays video frames."""
def __init__(self) -> None:
self._ax: plt.Axes | None = None
self._plt_img: plt.Image | None = None
@override
def on_episode_start(self) -> None:
plt.ion()
self._ax = plt.subplot()
self._plt_img = None
@override
def on_step(self, observation: dict, action: dict) -> None:
assert self._ax is not None
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
if self._plt_img is None:
self._plt_img = self._ax.imshow(im)
else:
self._plt_img.set_data(im)
plt.pause(0.001)
@override
def on_episode_end(self) -> None:
plt.ioff()
plt.close()

41
examples/aloha_sim/Dockerfile Normal file

@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.
# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev
ENV MUJOCO_GL=egl
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]

36
examples/aloha_sim/README.md Normal file

@@ -0,0 +1,36 @@
# Run Aloha Sim
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client
# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```
Note: If you are seeing EGL errors, you may need to install the following dependencies:
```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```

39
examples/aloha_sim/compose.yml Normal file

@@ -0,0 +1,39 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
runtime:
image: aloha_sim
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_sim/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

56
examples/aloha_sim/env.py Normal file

@@ -0,0 +1,56 @@
import gym_aloha # noqa: F401
import gymnasium
import numpy as np
from openpi_client.runtime import environment as _environment
from typing_extensions import override
class AlohaSimEnvironment(_environment.Environment):
"""An environment for an Aloha robot in simulation."""
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
np.random.seed(seed)
self._rng = np.random.default_rng(seed)
self._gym = gymnasium.make(task, obs_type=obs_type)
self._last_obs = None
self._done = True
self._episode_reward = 0.0
@override
def reset(self) -> None:
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = False
self._episode_reward = 0.0
@override
def done(self) -> bool:
return self._done
@override
def get_observation(self) -> dict:
if self._last_obs is None:
raise RuntimeError("Observation is not set. Call reset() first.")
return self._last_obs # type: ignore
@override
def apply_action(self, action: dict) -> None:
gym_obs, reward, terminated, truncated, info = self._gym.step(action["qpos"])
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = terminated or truncated
self._episode_reward = max(self._episode_reward, reward)
def _convert_observation(self, gym_obs: dict) -> dict:
# Convert axis order from [H, W, C] --> [C, H, W]
img = np.transpose(gym_obs["pixels"]["top"], (2, 0, 1))
# Add multi-camera dimension, to match the way real aloha provides images as [cam_idx, C, H, W].
imgs = np.expand_dims(img, axis=0)
return {
"qpos": gym_obs["agent_pos"],
"image": imgs,
}

55
examples/aloha_sim/main.py Normal file

@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib
import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro
@dataclasses.dataclass
class Args:
out_path: pathlib.Path = pathlib.Path("out.mp4")
task: str = "gym_aloha/AlohaTransferCube-v0"
seed: int = 0
action_horizon: int = 10
host: str = "0.0.0.0"
port: int = 8000
display: bool = False
def main(args: Args) -> None:
runtime = _runtime.Runtime(
environment=_env.AlohaSimEnvironment(
task=args.task,
seed=args.seed,
),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
subscribers=[
_saver.VideoSaver(args.out_path),
],
max_hz=50,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)

View File

@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy
typing-extensions
tyro
websockets

View File

@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
cloudpickle==3.1.0
# via gymnasium
contourpy==1.3.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
farama-notifications==0.0.4
# via gymnasium
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
gym-aloha==0.1.1
# via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
# via gym-aloha
idna==3.10
# via requests
imageio==2.36.1
# via
# -r examples/aloha_sim/requirements.in
# gym-aloha
imageio-ffmpeg==0.5.1
# via imageio
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.9.3
# via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
# via markdown-it-py
msgpack==1.1.0
# via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
# via
# dm-control
# gym-aloha
numpy==1.26.4
# via
# -r examples/aloha_sim/requirements.in
# contourpy
# dm-control
# dm-env
# gymnasium
# imageio
# labmaze
# matplotlib
# mujoco
# scipy
packaging==24.2
# via matplotlib
pillow==11.0.0
# via
# imageio
# matplotlib
protobuf==5.29.1
# via dm-control
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.2.0
# via
# dm-control
# matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
requests==2.32.3
# via dm-control
rich==13.9.4
# via tyro
scipy==1.14.1
# via dm-control
setuptools==75.6.0
# via
# dm-control
# imageio-ffmpeg
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.1
# via tyro
typing-extensions==4.12.2
# via
# -r examples/aloha_sim/requirements.in
# gymnasium
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_sim/requirements.in

View File

@@ -0,0 +1,35 @@
import logging
import pathlib
import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoSaver(_subscriber.Subscriber):
"""Saves episode data."""
def __init__(self, out_path: pathlib.Path, subsample: int = 1) -> None:
self._out_path = out_path
self._images: list[np.ndarray] = []
self._subsample = subsample
@override
def on_episode_start(self) -> None:
self._images = []
@override
def on_step(self, observation: dict, action: dict) -> None:
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
self._images.append(im)
@override
def on_episode_end(self) -> None:
logging.info(f"Saving video to {self._out_path}")
imageio.mimwrite(
self._out_path,
[np.asarray(x) for x in self._images[:: self._subsample]],
fps=50 // max(1, self._subsample),
)

View File

@@ -0,0 +1,65 @@
# THIS DOCKERFILE DOES NOT YET WORK
# Dockerfile for the CALVIN benchmark.
# Build the container:
# docker build . -t calvin -f examples/calvin/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app --privileged --gpus all calvin /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
make \
g++ \
git \
wget \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6 \
unzip \
ffmpeg
# Install miniconda
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
/bin/bash ~/miniconda.sh -b -p $CONDA_DIR
ENV PATH=$CONDA_DIR/bin:$PATH
# Submodules don't work with calvin because it internally parses git metadata.
# So we have to clone it directly.
RUN git clone --recurse-submodules https://github.com/mees/calvin.git /root/calvin
RUN conda create -n calvin python=3.8
RUN source /opt/conda/bin/activate calvin && \
pip install setuptools==57.5.0 && \
cd /root/calvin && \
./install.sh && \
pip install \
imageio[ffmpeg] \
moviepy \
numpy==1.23.0 \
tqdm \
tyro \
websockets \
msgpack
ENV PYTHONPATH=/app:/app/packages/openpi-client/src
# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
RUN mkdir -p /datasets && cd /datasets && \
wget http://calvin.cs.uni-freiburg.de/dataset/calvin_debug_dataset.zip && \
unzip calvin_debug_dataset.zip && \
rm calvin_debug_dataset.zip
WORKDIR /app
CMD ["/bin/bash", "-c", "source /opt/conda/bin/activate calvin && python examples/calvin/main.py"]

47
examples/calvin/README.md Normal file
View File

@@ -0,0 +1,47 @@
# CALVIN Benchmark
This example runs the CALVIN benchmark: https://github.com/mees/calvin
## With Docker
```bash
export SERVER_ARGS="--env CALVIN"
docker compose -f examples/calvin/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
cd $OPENPI_ROOT
conda create -n calvin python=3.8
conda activate calvin
git clone --recurse-submodules https://github.com/mees/calvin.git
cd calvin
pip install setuptools==57.5.0
./install.sh
pip install imageio[ffmpeg] moviepy numpy==1.23.0 tqdm tyro websockets msgpack
export PYTHONPATH=$PYTHONPATH:$OPENPI_ROOT/packages/openpi-client/src
# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
export CALVIN_DATASETS_DIR=~/datasets
export CALVIN_DATASET=calvin_debug_dataset
mkdir -p $CALVIN_DATASETS_DIR && cd $CALVIN_DATASETS_DIR
wget http://calvin.cs.uni-freiburg.de/dataset/$CALVIN_DATASET.zip
unzip $CALVIN_DATASET.zip
rm $CALVIN_DATASET.zip
# Run the simulation
cd $OPENPI_ROOT
python examples/calvin/main.py --args.calvin_data_path=$CALVIN_DATASETS_DIR/$CALVIN_DATASET
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env CALVIN
```

View File

@@ -0,0 +1,46 @@
# Run with:
# docker compose -f examples/calvin/compose.yml up --build
services:
runtime:
image: calvin
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/calvin/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

175
examples/calvin/main.py Normal file
View File

@@ -0,0 +1,175 @@
"""Runs a model in a CALVIN simulation environment."""
import collections
from dataclasses import dataclass
import logging
import pathlib
import time
from calvin_agent.evaluation.multistep_sequences import get_sequences
from calvin_agent.evaluation.utils import get_env_state_for_initial_condition
import calvin_env
from calvin_env.envs.play_table_env import get_env
import hydra
import imageio
import numpy as np
from omegaconf import OmegaConf
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
@dataclass
class Args:
#################################################################################################################
# Model server parameters
#################################################################################################################
host: str = "0.0.0.0"
port: int = 8000
replan_steps: int = 5
#################################################################################################################
# CALVIN environment-specific parameters
#################################################################################################################
calvin_data_path: str = "/datasets/calvin_debug_dataset" # Path to CALVIN dataset for loading validation tasks
max_subtask_steps: int = 360 # Max number of steps per subtask
num_trials: int = 1000 # Number of multi-step task sequences to evaluate
#################################################################################################################
# Utils
#################################################################################################################
video_out_path: str = "data/calvin/videos" # Path to save videos
num_save_videos: int = 5 # Number of videos to be logged per task
video_temp_subsample: int = 5 # Temporal subsampling to make videos shorter
seed: int = 7 # Random Seed (for reproducibility)
def main(args: Args) -> None:
# Set random seed
np.random.seed(args.seed)
# Initialize CALVIN environment
env = get_env(pathlib.Path(args.calvin_data_path) / "validation", show_gui=False)
# Get CALVIN eval task set
task_definitions, task_instructions, task_reward = _get_calvin_tasks_and_reward(args.num_trials)
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
# Start evaluation.
episode_solved_subtasks = []
per_subtask_success = collections.defaultdict(list)
for i, (initial_state, task_sequence) in enumerate(tqdm.tqdm(task_definitions)):
logging.info(f"Starting episode {i+1}...")
logging.info(f"Task sequence: {task_sequence}")
# Reset env to initial position for task
robot_obs, scene_obs = get_env_state_for_initial_condition(initial_state)
env.reset(robot_obs=robot_obs, scene_obs=scene_obs)
rollout_images = []
solved_subtasks = 0
for subtask in task_sequence:
start_info = env.get_info()
action_plan = collections.deque()
obs = env.get_obs()
done = False
for _ in range(args.max_subtask_steps):
img = obs["rgb_obs"]["rgb_static"]
wrist_img = obs["rgb_obs"]["rgb_gripper"]
rollout_images.append(img.transpose(2, 0, 1))
if not action_plan:
# Finished executing previous action chunk -- compute new chunk
# Prepare observations dict
element = {
"observation/rgb_static": img,
"observation/rgb_gripper": wrist_img,
"observation/state": obs["robot_obs"],
"prompt": str(task_instructions[subtask][0]),
}
# Query model to get action
action_chunk = client.infer(element)["actions"]
assert (
len(action_chunk) >= args.replan_steps
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
action_plan.extend(action_chunk[: args.replan_steps])
action = action_plan.popleft()
# Binarize gripper action since env expects gripper_action in {-1, 1}
action[-1] = 1 if action[-1] > 0 else -1
# Step environment
obs, _, _, current_info = env.step(action)
# check if current step solves a task
current_task_info = task_reward.get_task_info_for_set(start_info, current_info, {subtask})
if len(current_task_info) > 0:
done = True
solved_subtasks += 1
break
per_subtask_success[subtask].append(int(done))
if not done:
# Subtask execution failed --> stop episode
break
episode_solved_subtasks.append(solved_subtasks)
if len(episode_solved_subtasks) < args.num_save_videos:
# Save rollout video.
idx = len(episode_solved_subtasks)
imageio.mimwrite(
pathlib.Path(args.video_out_path) / f"rollout_{idx}.mp4",
[np.asarray(x) for x in rollout_images[:: args.video_temp_subsample]],
fps=50 // args.video_temp_subsample,
)
# Print current performance after each episode
logging.info(f"Solved subtasks: {solved_subtasks}")
_calvin_print_performance(episode_solved_subtasks, per_subtask_success)
# Log final performance
logging.info(f"results/avg_num_subtasks: : {np.mean(episode_solved_subtasks)}")
for i in range(1, 6):
# Compute fraction of episodes that have *at least* i successful subtasks
logging.info(
f"results/avg_success_len_{i}: {np.sum(episode_solved_subtasks >= i) / len(episode_solved_subtasks)}"
)
for key in per_subtask_success:
logging.info(f"results/avg_success__{key}: {np.mean(per_subtask_success[key])}")
def _get_calvin_tasks_and_reward(num_sequences):
conf_dir = pathlib.Path(calvin_env.__file__).absolute().parents[2] / "calvin_models" / "conf"
task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
task_oracle = hydra.utils.instantiate(task_cfg)
val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml")
eval_sequences = get_sequences(num_sequences)
return eval_sequences, val_annotations, task_oracle
def _calvin_print_performance(episode_solved_subtasks, per_subtask_success):
# Compute avg success rate per task length
logging.info("#####################################################")
logging.info(f"Avg solved subtasks: {np.mean(episode_solved_subtasks)}\n")
logging.info("Per sequence_length avg success:")
for i in range(1, 6):
# Compute fraction of episodes that have *at least* i successful subtasks
logging.info(f"{i}: {np.sum(np.array(episode_solved_subtasks) >= i) / len(episode_solved_subtasks) * 100}%")
logging.info("\n Per subtask avg success:")
for key in per_subtask_success:
logging.info(f"{key}: \t\t\t {np.mean(per_subtask_success[key]) * 100}%")
logging.info("#####################################################")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
tyro.cli(main)

View File

@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.
# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
make \
g++ \
clang \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]

39
examples/libero/README.md Normal file
View File

@@ -0,0 +1,39 @@
# LIBERO Benchmark
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
Note: When regenerating `requirements.txt` in this directory, the `--extra-index-url https://download.pytorch.org/whl/cu113` flag must be added to the `uv pip compile` command.
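For reference, the full compile command would then look roughly like this (assuming, as in the header of the generated `requirements.txt`, that the inputs live in `examples/libero/requirements.in`):
```bash
uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt \
  --python-version 3.8 --index-strategy=unsafe-best-match \
  --extra-index-url https://download.pytorch.org/whl/cu113
```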
## With Docker
```bash
# Grant access to the X11 server:
sudo xhost +local:docker
export SERVER_ARGS="--env LIBERO"
docker compose -f examples/libero/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
# Run the simulation
python examples/libero/main.py
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```

View File

@@ -0,0 +1,49 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
runtime:
image: libero
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/libero/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
- /tmp/.X11-unix:/tmp/.X11-unix:ro
environment:
- DISPLAY=$DISPLAY
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

215
examples/libero/main.py Normal file
View File

@@ -0,0 +1,215 @@
import collections
import dataclasses
import logging
import math
import pathlib
import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
@dataclasses.dataclass
class Args:
#################################################################################################################
# Model server parameters
#################################################################################################################
host: str = "0.0.0.0"
port: int = 8000
resize_size: int = 224
replan_steps: int = 5
#################################################################################################################
# LIBERO environment-specific parameters
#################################################################################################################
task_suite_name: str = (
"libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
)
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
num_trials_per_task: int = 50 # Number of rollouts per task
#################################################################################################################
# Utils
#################################################################################################################
video_out_path: str = "data/libero/videos" # Path to save videos
seed: int = 7 # Random Seed (for reproducibility)
def eval_libero(args: Args) -> None:
# Set random seed
np.random.seed(args.seed)
# Initialize LIBERO task suite
benchmark_dict = benchmark.get_benchmark_dict()
task_suite = benchmark_dict[args.task_suite_name]()
num_tasks_in_suite = task_suite.n_tasks
logging.info(f"Task suite: {args.task_suite_name}")
pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
if args.task_suite_name == "libero_spatial":
max_steps = 220 # longest training demo has 193 steps
elif args.task_suite_name == "libero_object":
max_steps = 280 # longest training demo has 254 steps
elif args.task_suite_name == "libero_goal":
max_steps = 300 # longest training demo has 270 steps
elif args.task_suite_name == "libero_10":
max_steps = 520 # longest training demo has 505 steps
elif args.task_suite_name == "libero_90":
max_steps = 400 # longest training demo has 373 steps
else:
raise ValueError(f"Unknown task suite: {args.task_suite_name}")
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
# Start evaluation
total_episodes, total_successes = 0, 0
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
# Get task
task = task_suite.get_task(task_id)
# Get default LIBERO initial states
initial_states = task_suite.get_task_init_states(task_id)
# Initialize LIBERO environment and task description
env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
# Start episodes
task_episodes, task_successes = 0, 0
for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
logging.info(f"\nTask: {task_description}")
# Reset environment
env.reset()
action_plan = collections.deque()
# Set initial states
obs = env.set_init_state(initial_states[episode_idx])
# Setup
t = 0
replay_images = []
logging.info(f"Starting episode {task_episodes+1}...")
while t < max_steps + args.num_steps_wait:
try:
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
# and we need to wait for them to fall
if t < args.num_steps_wait:
obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
t += 1
continue
# Get preprocessed image
# IMPORTANT: rotate 180 degrees to match train preprocessing
img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
img = image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
wrist_img = image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
# Save preprocessed image for replay video
replay_images.append(img)
if not action_plan:
# Finished executing previous action chunk -- compute new chunk
# Prepare observations dict
element = {
"observation/image": img,
"observation/wrist_image": wrist_img,
"observation/state": np.concatenate(
(
obs["robot0_eef_pos"],
_quat2axisangle(obs["robot0_eef_quat"]),
obs["robot0_gripper_qpos"],
)
),
"prompt": str(task_description),
}
# Query model to get action
action_chunk = client.infer(element)["actions"]
assert (
len(action_chunk) >= args.replan_steps
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
action_plan.extend(action_chunk[: args.replan_steps])
action = action_plan.popleft()
# Execute action in environment
obs, reward, done, info = env.step(action.tolist())
if done:
task_successes += 1
total_successes += 1
break
t += 1
except Exception as e:
logging.error(f"Caught exception: {e}")
break
task_episodes += 1
total_episodes += 1
# Save a replay video of the episode
suffix = "success" if done else "failure"
task_segment = task_description.replace(" ", "_")
imageio.mimwrite(
pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
[np.asarray(x) for x in replay_images],
fps=10,
)
# Log current results
logging.info(f"Success: {done}")
logging.info(f"# episodes completed so far: {total_episodes}")
logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
# Log final results
logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
logging.info(f"Total episodes: {total_episodes}")
def _get_libero_env(task, resolution, seed):
"""Initializes and returns the LIBERO environment, along with the task description."""
task_description = task.language
task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
env = OffScreenRenderEnv(**env_args)
env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
return env, task_description
def _quat2axisangle(quat):
"""
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
"""
# clip quaternion
if quat[3] > 1.0:
quat[3] = 1.0
elif quat[3] < -1.0:
quat[3] = -1.0
den = np.sqrt(1.0 - quat[3] * quat[3])
if math.isclose(den, 0.0):
# This is (close to) a zero degree rotation, immediately return
return np.zeros(3)
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
tyro.cli(eval_libero)

View File

@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3

View File

@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
# via mujoco
certifi==2024.12.14
# via requests
charset-normalizer==3.4.0
# via requests
cycler==0.12.1
# via matplotlib
docstring-parser==0.16
# via tyro
etils==1.3.0
# via mujoco
eval-type-backport==0.2.0
# via tyro
evdev==1.7.1
# via pynput
fonttools==4.55.3
# via matplotlib
glfw==1.12.0
# via mujoco
idna==3.10
# via requests
imageio==2.35.1
# via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
# via imageio
importlib-metadata==8.5.0
# via typeguard
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
llvmlite==0.36.0
# via numba
markdown-it-py==3.0.0
# via rich
matplotlib==3.5.3
# via -r examples/libero/requirements.in
mdurl==0.1.2
# via markdown-it-py
mujoco==3.2.3
# via robosuite
numba==0.53.1
# via robosuite
numpy==1.22.4
# via
# -r examples/libero/requirements.in
# imageio
# matplotlib
# mujoco
# numba
# opencv-python
# robosuite
# scipy
# torchvision
opencv-python==4.6.0.66
# via
# -r examples/libero/requirements.in
# robosuite
packaging==24.2
# via matplotlib
pillow==10.4.0
# via
# imageio
# matplotlib
# robosuite
# torchvision
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pynput==1.7.7
# via robosuite
pyopengl==3.1.7
# via mujoco
pyparsing==3.1.4
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
python-xlib==0.33
# via pynput
pyyaml==6.0.2
# via -r examples/libero/requirements.in
requests==2.32.3
# via torchvision
rich==13.9.4
# via tyro
robosuite==1.4.1
# via -r examples/libero/requirements.in
scipy==1.10.1
# via robosuite
setuptools==75.3.0
# via
# imageio-ffmpeg
# numba
shtab==1.7.1
# via tyro
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
termcolor==2.4.0
# via robosuite
torch==1.11.0+cu113
# via
# -r examples/libero/requirements.in
# torchaudio
# torchvision
torchaudio==0.11.0+cu113
# via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
# via -r examples/libero/requirements.in
tqdm==4.67.1
# via -r examples/libero/requirements.in
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# torch
# torchvision
# typeguard
# tyro
tyro==0.9.2
# via -r examples/libero/requirements.in
urllib3==2.2.3
# via requests
zipp==3.20.2
# via
# etils
# importlib-metadata
# importlib-resources

View File

@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"import numpy as np\n",
"\n",
"record_path = pathlib.Path(\"../policy_records\")\n",
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
"\n",
"records = []\n",
"for i in range(num_steps):\n",
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
" records.append(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"length of records\", len(records))\n",
"print(\"keys in records\", records[0].keys())\n",
"\n",
"for k in records[0]:\n",
" print(f\"{k} shape: {records[0][k].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"def get_image(step: int, idx: int = 0):\n",
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
" return img[idx].transpose(1, 2, 0)\n",
"\n",
"\n",
"def show_image(step: int, idx_lst: list[int]):\n",
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
" return Image.fromarray(np.hstack(imgs))\n",
"\n",
"\n",
"for i in range(2):\n",
" display(show_image(i, [0]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def get_axis(name, axis):\n",
" return np.array([record[name][axis] for record in records])\n",
"\n",
"\n",
"# qpos is [..., 14] of type float:\n",
"# 0-5: left arm joint angles\n",
"# 6: left arm gripper\n",
"# 7-12: right arm joint angles\n",
"# 13: right arm gripper\n",
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
"\n",
"\n",
"def make_data():\n",
" cur_dim = 0\n",
" in_data = {}\n",
" out_data = {}\n",
" for name, dim_size in names:\n",
" for i in range(dim_size):\n",
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
" cur_dim += 1\n",
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
"\n",
"\n",
"in_data, out_data = make_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in in_data.columns:\n",
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
" data.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,32 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/simple_client/main.py"]

View File

@@ -0,0 +1,24 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
## With Docker
```bash
export SERVER_ARGS="--example aloha"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py
```

View File

@@ -0,0 +1,37 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
runtime:
image: simple_client
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/simple_client/Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
environment:
- SERVER_ARGS
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,81 @@
import dataclasses
import logging
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import tyro
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
example: str = "droid"
def main(args: Args) -> None:
obs_fn = {
"aloha": _random_observation_aloha,
"droid": _random_observation_droid,
"calvin": _random_observation_calvin,
"libero": _random_observation_libero,
}[args.example]
policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
# Send 1 observation to make sure the model is loaded.
policy.infer(obs_fn())
start = time.time()
for _ in range(100):
policy.infer(obs_fn())
end = time.time()
print(f"Total time taken: {end - start}")
# Note that each inference returns many action chunks.
print(f"Inference rate: {100 / (end - start)} Hz")
def _random_observation_aloha() -> dict:
return {
"qpos": np.ones((14,)),
"image": np.random.rand(4, 3, 480, 640).astype(np.float32),
}
def _random_observation_droid() -> dict:
return {
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
"observation/joint_position": np.random.rand(7),
"observation/gripper_position": np.random.rand(1),
"prompt": "do something",
}
def _random_observation_calvin() -> dict:
return {
"observation/state": np.random.rand(15),
"observation/rgb_static": np.random.rand(4, 3, 480, 640).astype(np.float32),
"observation/rgb_gripper": np.random.rand(4, 3, 480, 640).astype(np.float32),
"prompt": "do something",
}
def _random_observation_libero() -> dict:
return {
"observation/state": np.random.rand(8),
"observation/image": np.random.rand(4, 3, 480, 640).astype(np.float32),
"observation/wrist_image": np.random.rand(4, 3, 480, 640).astype(np.float32),
"prompt": "do something",
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
tyro.cli(main)

View File

@@ -0,0 +1,2 @@
numpy
tyro

View File

@@ -0,0 +1,27 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
backports-cached-property==1.0.2
# via tyro
docstring-parser==0.16
# via tyro
eval-type-backport==0.1.3
# via tyro
markdown-it-py==2.2.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.21.6
# via -r examples/simple_client/requirements.in
pygments==2.17.2
# via rich
rich==13.8.1
# via tyro
shtab==1.7.1
# via tyro
typing-extensions==4.7.1
# via
# markdown-it-py
# rich
# tyro
tyro==0.9.1
# via -r examples/simple_client/requirements.in

View File

@@ -0,0 +1,25 @@
[project]
name = "openpi-client"
version = "0.1.0"
requires-python = ">=3.7"
dependencies = [
"dm-tree>=0.1.8",
"msgpack>=1.0.5",
"numpy>=1.21.6",
"pillow>=9.0.0",
"tree>=0.2.4",
"websockets>=11.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
dev-dependencies = [
"pytest>=8.3.4",
]
[tool.ruff]
line-length = 120
target-version = "py37"

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@@ -0,0 +1,39 @@
from typing import Dict
import numpy as np
import tree
from typing_extensions import override
from openpi_client import base_policy as _base_policy
class ActionChunkBroker(_base_policy.BasePolicy):
"""Wraps a policy to return action chunks one-at-a-time.
Assumes that the first dimension of all action fields is the chunk size.
A new inference call to the inner policy is only made when the current
chunk of actions is exhausted.
"""
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
self._policy = policy
self._action_horizon = action_horizon
self._cur_step: int = 0
self._last_results: Dict[str, np.ndarray] | None = None
@override
def infer(self, obs: Dict) -> Dict: # noqa: UP006
if self._last_results is None:
self._last_results = self._policy.infer(obs)
self._cur_step = 0
results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
self._cur_step += 1
if self._cur_step >= self._action_horizon:
self._last_results = None
return results
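A rough usage sketch of the broker (not part of this commit; it assumes the `openpi_client` package and its dependencies are installed): any `BasePolicy` whose action fields carry the chunk dimension first can be wrapped, and one action slice is handed out per `infer` call.
```python
import numpy as np

from openpi_client import action_chunk_broker
from openpi_client import base_policy


class DummyChunkPolicy(base_policy.BasePolicy):
    """Toy policy that returns a chunk of 10 zero actions per inference call."""

    def infer(self, obs: dict) -> dict:
        return {"actions": np.zeros((10, 14))}


broker = action_chunk_broker.ActionChunkBroker(DummyChunkPolicy(), action_horizon=10)
for _ in range(25):  # only 3 of these calls reach the inner policy
    step = broker.infer({})
    assert step["actions"].shape == (14,)
```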

View File

@@ -0,0 +1,8 @@
import abc
from typing import Dict
class BasePolicy(abc.ABC):
@abc.abstractmethod
def infer(self, obs: Dict) -> Dict:
"""Infer actions from observations."""

View File

@@ -0,0 +1,48 @@
import numpy as np
from PIL import Image
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
Args:
images: A batch of images in [..., height, width, channel] format.
height: The target height of the image.
width: The target width of the image.
method: The interpolation method to use. Default is bilinear.
Returns:
The resized images in [..., height, width, channel].
"""
# If the images are already the correct size, return them as is.
if images.shape[-3:-1] == (height, width):
return images
original_shape = images.shape
images = images.reshape(-1, *original_shape[-3:])
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
width without distortion by padding with zeros.
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
"""
cur_width, cur_height = image.size
if cur_width == width and cur_height == height:
return image # No need to resize if the image is already the correct size.
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_image = image.resize((resized_width, resized_height), resample=method)
zero_image = Image.new(resized_image.mode, (width, height), 0)
pad_height = max(0, int((height - resized_height) / 2))
pad_width = max(0, int((width - resized_width) / 2))
zero_image.paste(resized_image, (pad_width, pad_height))
assert zero_image.size == (width, height)
return zero_image

View File

@@ -0,0 +1,37 @@
import numpy as np
import openpi_client.image_tools as image_tools
def test_resize_with_pad_shapes():
# Test case 1: Resize image with larger dimensions
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
height = 20
width = 20
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (2, height, width, 3)
assert np.all(resized_images == 0)
# Test case 2: Resize image with smaller dimensions
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
height = 15
width = 15
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (3, height, width, 3)
assert np.all(resized_images == 0)
# Test case 3: Resize image with the same dimensions
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
height = 50
width = 50
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (1, height, width, 3)
assert np.all(resized_images == 0)
# Test case 4: Resize image with odd-numbered padding
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
height = 60
width = 80
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (1, height, width, 3)
assert np.all(resized_images == 0)

View File

@@ -0,0 +1,57 @@
"""Adds NumPy array support to msgpack.
msgpack is good for (de)serializing data over a network for multiple reasons:
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
- msgpack is widely used and has good cross-language support
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
languages like Python and JavaScript
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
than pickle for serializing large arrays using the below strategy
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
that it falls back to pickle for object arrays.
"""
import functools
import msgpack
import numpy as np
def pack_array(obj):
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
raise ValueError(f"Unsupported dtype: {obj.dtype}")
if isinstance(obj, np.ndarray):
return {
b"__ndarray__": True,
b"data": obj.tobytes(),
b"dtype": obj.dtype.str,
b"shape": obj.shape,
}
if isinstance(obj, np.generic):
return {
b"__npgeneric__": True,
b"data": obj.item(),
b"dtype": obj.dtype.str,
}
return obj
def unpack_array(obj):
if b"__ndarray__" in obj:
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
if b"__npgeneric__" in obj:
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
return obj
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
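A quick round trip, as a sketch of how these helpers are intended to be used (it assumes only that `numpy` and `msgpack` are installed alongside `openpi_client`):
```python
import numpy as np

from openpi_client import msgpack_numpy

obs = {"qpos": np.ones(14, dtype=np.float32), "step": 3}
payload = msgpack_numpy.packb(obs)  # plain bytes, safe to send over a socket
restored = msgpack_numpy.unpackb(payload)
assert np.array_equal(restored["qpos"], obs["qpos"])
assert restored["step"] == 3
```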

View File

@@ -0,0 +1,45 @@
import numpy as np
import pytest
import tree
from openpi_client import msgpack_numpy
def _check(expected, actual):
if isinstance(expected, np.ndarray):
assert expected.shape == actual.shape
assert expected.dtype == actual.dtype
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
else:
assert expected == actual
@pytest.mark.parametrize(
"data",
[
1, # int
1.0, # float
"hello", # string
np.bool_(True), # boolean scalar
np.array([1, 2, 3])[0], # int scalar
np.str_("asdf"), # string scalar
[1, 2, 3], # list
{"key": "value"}, # dict
{"key": [1, 2, 3]}, # nested dict
np.array(1.0), # 0D array
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
np.array(["asdf", "qwer"]), # string array
np.array([True, False]), # boolean array
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
np.array([np.nan, np.inf, -np.inf]), # special float values
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
[np.array([1, 2]), np.array([3, 4])], # list of arrays
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
],
)
def test_pack_unpack(data):
packed = msgpack_numpy.packb(data)
unpacked = msgpack_numpy.unpackb(packed)
tree.map_structure(_check, data, unpacked)

View File

@@ -0,0 +1,13 @@
import abc
class Agent(abc.ABC):
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
Agents receive observations about the state of the world, and return actions
to take in response.
"""
@abc.abstractmethod
def get_action(self, observation: dict) -> dict:
"""Query the agent for the next action."""

View File

@@ -0,0 +1,15 @@
from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent
from typing_extensions import override
# TODO: Consider unifying policies and agents.
class PolicyAgent(_agent.Agent):
"""An agent that uses a policy to determine actions."""
def __init__(self, policy: _base_policy.BasePolicy) -> None:
self._policy = policy
@override
def get_action(self, observation: dict) -> dict:
return self._policy.infer(observation)

View File

@@ -0,0 +1,32 @@
import abc
class Environment(abc.ABC):
"""An Environment represents the robot and the environment it inhabits.
The primary contract of environments is that they can be queried for observations
about their state, and have actions applied to them to change that state.
"""
@abc.abstractmethod
def reset(self) -> None:
"""Reset the environment to its initial state.
This will be called once before starting each episode.
"""
@abc.abstractmethod
def done(self) -> bool:
"""Allow the environment to signal that the task is done.
This will be called after each step. It should return `True` if the task is
done (either successfully or unsuccessfully), and `False` otherwise.
"""
@abc.abstractmethod
def get_observation(self) -> dict:
"""Query the environment for the current state."""
@abc.abstractmethod
def apply_action(self, action: dict) -> None:
"""Take an action in the environment."""

View File

@@ -0,0 +1,78 @@
import logging
import threading
import time
from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import subscriber as _subscriber
class Runtime:
"""The core module orchestrating interactions between key components of the system."""
def __init__(
self,
environment: _environment.Environment,
agent: _agent.Agent,
subscribers: list[_subscriber.Subscriber],
max_hz: float = 0,
) -> None:
self._environment = environment
self._agent = agent
self._subscribers = subscribers
self._max_hz = max_hz
self._running = False
def run(self) -> None:
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
self._loop()
def run_in_new_thread(self) -> threading.Thread:
"""Runs the runtime loop in a new thread."""
thread = threading.Thread(target=self.run)
thread.start()
return thread
def stop(self) -> None:
"""Stops the runtime loop."""
self._running = False
def _loop(self) -> None:
"""The runtime loop."""
logging.info("Starting episode...")
self._environment.reset()
for subscriber in self._subscribers:
subscriber.on_episode_start()
self._running = True
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
last_step_time = time.time()
while self._running:
self._step()
# Sleep to maintain the desired frame rate
now = time.time()
dt = now - last_step_time
if dt < step_time:
time.sleep(step_time - dt)
last_step_time = time.time()
else:
last_step_time = now
logging.info("Episode completed.")
for subscriber in self._subscribers:
subscriber.on_episode_end()
def _step(self) -> None:
"""A single step of the runtime loop."""
observation = self._environment.get_observation()
action = self._agent.get_action(observation)
self._environment.apply_action(action)
for subscriber in self._subscribers:
subscriber.on_step(observation, action)
if self._environment.done():
self.stop()

View File

@@ -0,0 +1,20 @@
import abc
class Subscriber(abc.ABC):
"""Subscribes to events in the runtime.
Subscribers can be used to save data, visualize, etc.
"""
@abc.abstractmethod
def on_episode_start(self) -> None:
"""Called when an episode starts."""
@abc.abstractmethod
def on_step(self, observation: dict, action: dict) -> None:
"""Append a step to the episode."""
@abc.abstractmethod
def on_episode_end(self) -> None:
"""Called when an episode ends."""

View File

@@ -0,0 +1,40 @@
import logging
import time
from typing import Dict
from typing_extensions import override
import websockets.sync.client
from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
class WebsocketClientPolicy(_base_policy.BasePolicy):
"""Implements the Policy interface by communicating with a server over websocket.
See WebsocketPolicyServer for a corresponding server implementation.
"""
def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
self._uri = f"ws://{host}:{port}"
self._packer = msgpack_numpy.Packer()
self._ws = self._wait_for_server()
def _wait_for_server(self) -> websockets.sync.client.ClientConnection:
logging.info(f"Waiting for server at {self._uri}...")
while True:
try:
return websockets.sync.client.connect(self._uri, compression=None, max_size=None)
except ConnectionRefusedError:
logging.info("Still waiting for server...")
time.sleep(5)
@override
def infer(self, obs: Dict) -> Dict: # noqa: UP006
data = self._packer.pack(obs)
self._ws.send(data)
response = self._ws.recv()
if isinstance(response, str):
# we're expecting bytes; if the server sends a string, it's an error.
raise RuntimeError(f"Error in inference server:\n{response}")
return msgpack_numpy.unpackb(response)
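A client-side sketch (not part of this commit), assuming a policy server is already listening on localhost:8000, for example one started with `uv run scripts/serve_policy.py`. The exact observation keys depend on how the server is configured; the aloha-style payload below mirrors `examples/simple_client/main.py`:
```python
import numpy as np

from openpi_client import websocket_client_policy

# Blocks (retrying every 5 seconds) until the server accepts the connection.
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
result = client.infer(
    {
        "qpos": np.ones((14,)),
        "image": np.random.rand(4, 3, 480, 640).astype(np.float32),
    }
)
print({k: np.asarray(v).shape for k, v in result.items()})
```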

123
pyproject.toml Normal file
View File

@@ -0,0 +1,123 @@
[project]
name = "openpi"
version = "0.1.0"
description = "Physical Intelligence open source repo"
readme = "README.md"
requires-python = ">=3.11"
license = { file = "LICENSE" }
dependencies = [
"augmax>=0.3.4",
"dm-tree>=0.1.8",
"einops>=0.8.0",
"equinox>=0.11.8",
"flatbuffers>=24.3.25",
"flax==0.10.2",
"fsspec[gcs]>=2024.6.0",
"gym-aloha>=0.1.1",
"imageio>=2.36.1",
"jax[cuda12]==0.4.36",
"jaxtyping==0.2.36",
"lerobot",
"ml_collections==1.0.0",
"numpy>=1.26.4",
"numpydantic>=1.6.6",
"opencv-python>=4.10.0.84",
"openpi-client",
"orbax-checkpoint==0.10.2",
"pillow>=11.0.0",
"ruff>=0.7.1",
"s3fs>=2024.9.0",
"sentencepiece>=0.2.0",
"torch>=2.5.1",
"tqdm-loggable>=0.2",
"typing-extensions>=4.12.2",
"tyro>=0.9.4",
"wandb>=0.19.1",
"boto3>=1.35.7",
"types-boto3[boto3,s3]>=1.35.7",
"filelock>=3.16.1",
"beartype>=0.19.0",
]
[project.urls]
Repository = "https://github.com/Physical-Intelligence/openpi"
[dependency-groups]
dev = [
"pytest>=8.3.4",
"ruff>=0.8.3",
"pre-commit>=4.0.1",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
"matplotlib>=3.10.0",
]
[tool.uv.sources]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "66f87365988cb5424435ea03b428426b4ede98cb" }
[tool.uv.workspace]
members = ["packages/*"]
[tool.ruff]
line-length = 120
target-version = "py311"
extend-exclude = ["docker", "third_party"]
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"C4",
"DTZ",
"E4",
"E7",
"E9",
"F",
"FBT",
"FURB",
"I",
"ICN",
"ISC",
"LOG",
"N",
"PD",
"PERF",
"PIE",
"PLC",
"PLE",
"PLR1",
"PLR5",
"PLW",
"PT",
"PTH",
"Q",
"RET",
"RUF",
"SIM",
"SLF",
"T10",
"T20",
"UP",
"W",
]
ignore = [
"F722", # Conflicts with array typing.
"T201", # We use print statements.
"PD008", # Lots of false positives.
]
unfixable = [
"B905", # Fix defaults to strict=False, which is not what we want.
]
[tool.ruff.lint.isort]
force-single-line = true
force-sort-within-sections = true
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
known-third-party = ["wandb"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

0
scripts/__init__.py Normal file
View File

479
scripts/aloha_hd5.py Normal file
View File

@@ -0,0 +1,479 @@
# ruff: noqa
"""
Script courtesy of Raziel90 https://github.com/huggingface/lerobot/pull/586/files
Example usage
python scripts/aloha_hd5.py --raw-path ~/data/ --dataset-repo-id <hf-username>/<dataset-name> --robot-type <aloha-stationary|aloha-mobile> --fps 50 --video-encoding=false --push=false
The data will be saved locally under the path given by the LEROBOT_HOME environment variable. By default this is set to ~/.cache/huggingface/lerobot
If you wish to submit the dataset to the hub, you can do so by setting up the hf cli https://huggingface.co/docs/huggingface_hub/en/guides/cli and setting --push=true
"""
import argparse
import logging
import os
from pathlib import Path
import shutil
import traceback
import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import torch
class AlohaHD5Extractor:
TAGS = ["aloha", "robotics", "hdf5"]
aloha_stationary = "aloha-stationary"
aloha_mobile = "aloha-mobile"
@staticmethod
def get_cameras(hdf5_data: h5py.File):
"""
Extracts the list of RGB camera keys from the given HDF5 data.
Parameters
----------
hdf5_data : h5py.File
The HDF5 file object containing the dataset.
Returns
-------
list of str
A list of keys corresponding to RGB cameras in the dataset.
"""
rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
return rgb_cameras
@staticmethod
def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
"""
Check the format of the given list of HDF5 files.
Parameters
----------
episode_list : list of str or list of Path
List of paths to the HDF5 files to be checked.
image_compressed : bool, optional
Flag indicating whether the images are compressed (default is True).
Raises
------
ValueError
If the episode_list is empty.
If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
If the number of frames in '/action' and '/observations/qpos' keys do not match.
If the number of frames in '/observations/images/{camera}' does not match the number of frames in '/action' and '/observations/qpos'.
If the dimensions of images do not match the expected dimensions based on the image_compressed flag.
If uncompressed images do not have the expected (h, w, c) format.
"""
if not episode_list:
raise ValueError("No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'")
for episode_path in episode_list:
with h5py.File(episode_path, "r") as data:
if not all(key in data for key in ["/action", "/observations/qpos"]):
raise ValueError(
"Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
)
if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
raise ValueError("The '/action' and '/observations/qpos' keys should have both 2 dimensions.")
if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
raise ValueError(
"The '/action' and '/observations/qpos' keys should have the same number of frames."
)
for camera in AlohaHD5Extractor.get_cameras(data):
if num_frames != data[f"/observations/images/{camera}"].shape[0]:
raise ValueError(
f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
)
expected_dims = 2 if image_compressed else 4
if data[f"/observations/images/{camera}"].ndim != expected_dims:
raise ValueError(
f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
)
if not image_compressed:
b, h, w, c = data[f"/observations/images/{camera}"].shape
if not (c < h and c < w):
raise ValueError(f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided.")
@staticmethod
def extract_episode_frames(
episode_path: str | Path, features: dict[str, dict], image_compressed: bool
) -> list[dict[str, torch.Tensor]]:
"""
Extract frames from an episode stored in an HDF5 file.
Parameters
----------
episode_path : str or Path
Path to the HDF5 file containing the episode data.
features : dict of str to dict
Dictionary where keys are feature identifiers and values are dictionaries with feature details.
image_compressed : bool
Flag indicating whether the images are stored in a compressed format.
Returns
-------
list of dict of str to torch.Tensor
List of frames, where each frame is a dictionary mapping feature identifiers to tensors.
"""
frames = []
with h5py.File(episode_path, "r") as file:
for frame_idx in range(file["/action"].shape[0]):
frame = {}
for feature_id in features:
feature_name_hd5 = feature_id.replace(".", "/")
if "images" in feature_id.split("."):
image = (
(file[feature_name_hd5][frame_idx])
if not image_compressed
else cv2.imdecode(file[feature_name_hd5][frame_idx], 1)
)
frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
else:
frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
frames.append(frame)
return frames
@staticmethod
def define_features(
hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
) -> dict[str, dict]:
"""
Define features from an HDF5 file.
Parameters
----------
hdf5_file_path : Path
The path to the HDF5 file.
image_compressed : bool, optional
Whether the images are compressed, by default True.
encode_as_video : bool, optional
Whether to encode images as video or as images, by default True.
Returns
-------
dict[str, dict]
A dictionary where keys are topic names and values are dictionaries
containing feature information such as dtype, shape, and names.
"""
# Initialize lists to store topics and features
topics = []
features = {}
# Open the HDF5 file
with h5py.File(hdf5_file_path, "r") as hdf5_file:
# Collect all dataset names in the HDF5 file
hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
# Iterate over each topic to define its features
for topic in topics:
# If the topic is an image, define it as a video feature
if "images" in topic.split("/"):
sample = hdf5_file[topic][0]
features[topic.replace("/", ".")] = {
"dtype": "video" if encode_as_video else "image",
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
if image_compressed
else sample.shape,
"names": [
"channel",
"height",
"width",
],
}
# Skip compressed length topics
elif "compress_len" in topic.split("/"):
continue
# Otherwise, define it as a regular feature
else:
features[topic.replace("/", ".")] = {
"dtype": str(hdf5_file[topic][0].dtype),
"shape": (topic_shape := hdf5_file[topic][0].shape),
"names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
}
# Return the defined features
return features
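# Illustrative sketch (not part of the source) of the kind of dictionary define_features returns for a
# typical Aloha-style episode. The camera name, image resolution, and 14-dim state/action are assumptions:
#
#   {
#       "observations.images.cam_high": {
#           "dtype": "video",
#           "shape": (3, 480, 640),
#           "names": ["channel", "height", "width"],
#       },
#       "observations.qpos": {"dtype": "float32", "shape": (14,), "names": ["qpos_0", ..., "qpos_13"]},
#       "action": {"dtype": "float32", "shape": (14,), "names": ["action_0", ..., "action_13"]},
#   }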
class DatasetConverter:
"""
A class to convert datasets to Lerobot format.
Parameters
----------
raw_path : Path or str
The path to the raw dataset.
dataset_repo_id : str
The repository ID where the dataset will be stored.
fps : int
Frames per second for the dataset.
robot_type : str, optional
The type of robot, by default "".
encode_as_videos : bool, optional
Whether to encode images as videos, by default True.
image_compressed : bool, optional
Whether the images are compressed, by default True.
image_writer_processes : int, optional
Number of processes for writing images, by default 0.
image_writer_threads : int, optional
Number of threads for writing images, by default 0.
Methods
-------
extract_episode(episode_path, task_description='')
Extracts frames from a single episode and saves it with a description.
extract_episodes(episode_description='')
Extracts frames from all episodes and saves them with a description.
push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
Pushes the dataset to the Hugging Face Hub.
init_lerobot_dataset()
Initializes the Lerobot dataset.
"""
def __init__(
self,
raw_path: Path | str,
dataset_repo_id: str,
fps: int,
robot_type: str = "",
encode_as_videos: bool = True,
image_compressed: bool = True,
image_writer_processes: int = 0,
image_writer_threads: int = 0,
):
self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
self.dataset_repo_id = dataset_repo_id
self.fps = fps
self.robot_type = robot_type
self.image_compressed = image_compressed
self.image_writer_threads = image_writer_threads
self.image_writer_processes = image_writer_processes
self.encode_as_videos = encode_as_videos
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.setLevel(logging.INFO)
# Add console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
console_handler.setFormatter(formatter)
self.logger.addHandler(console_handler)
self.logger.info(f"{'-'*10} Aloha HD5 -> Lerobot Converter {'-'*10}")
self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
self.logger.info(f"FPS: {self.fps}")
self.logger.info(f"Robot type: {self.robot_type}")
self.logger.info(f"Image compressed: {self.image_compressed}")
self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
self.logger.info(f"#writer processes: {self.image_writer_processes}")
self.logger.info(f"#writer threads: {self.image_writer_threads}")
self.episode_list = list(self.raw_path.glob("episode_*.hdf5"))
AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
self.features = AlohaHD5Extractor.define_features(
self.episode_list[0],
image_compressed=self.image_compressed,
encode_as_video=self.encode_as_videos,
)
def extract_episode(self, episode_path, task_description: str = ""):
"""
Extracts frames from an episode and saves them to the dataset.
Parameters
----------
episode_path : str
The path to the episode file.
task_description : str, optional
A description of the task associated with the episode (default is an empty string).
Returns
-------
None
"""
for frame in AlohaHD5Extractor.extract_episode_frames(episode_path, self.features, self.image_compressed):
self.dataset.add_frame(frame)
self.logger.info(f"Saving Episode with Description: {task_description} ...")
self.dataset.save_episode(task=task_description)
def extract_episodes(self, episode_description: str = ""):
"""
Extracts episodes from the episode list and processes them.
Parameters
----------
episode_description : str, optional
A description of the task to be passed to the extract_episode method (default is '').
Raises
------
Exception
If an error occurs during the processing of an episode, it will be caught and printed.
Notes
-----
After processing all episodes, the dataset is consolidated.
"""
for episode_path in self.episode_list:
try:
self.extract_episode(episode_path, task_description=episode_description)
except Exception as e:
print(f"Error processing episode {episode_path}", f"{e}")
traceback.print_exc()
continue
self.dataset.consolidate()
def push_dataset_to_hub(
self,
dataset_tags: list[str] | None = None,
private: bool = False,
push_videos: bool = True,
license: str | None = "apache-2.0",
):
"""
Pushes the dataset to the Hugging Face Hub.
Parameters
----------
dataset_tags : list of str, optional
A list of tags to associate with the dataset on the Hub. Default is None.
private : bool, optional
If True, the dataset will be private. Default is False.
push_videos : bool, optional
If True, videos will be pushed along with the dataset. Default is True.
license : str, optional
The license under which the dataset is released. Default is "apache-2.0".
Returns
-------
None
"""
self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
self.dataset.push_to_hub(
tags=dataset_tags,
license=license,
push_videos=push_videos,
private=private,
)
def init_lerobot_dataset(self):
"""
Initializes the LeRobot dataset.
This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.
Returns
-------
LeRobotDataset
The initialized LeRobot dataset.
"""
# Clean the cache if the dataset already exists
if (LEROBOT_HOME / self.dataset_repo_id).exists():
shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
self.dataset = LeRobotDataset.create(
repo_id=self.dataset_repo_id,
fps=self.fps,
robot_type=self.robot_type,
features=self.features,
image_writer_threads=self.image_writer_threads,
image_writer_processes=self.image_writer_processes,
)
return self.dataset
def str2bool(value):
if isinstance(value, bool):
return value
value = value.lower()
if value in ("yes", "true", "t", "y", "1"):
return True
if value in ("no", "false", "f", "n", "0"):
return False
raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
"""
Convert Aloha HD5 dataset and push to Hugging Face hub.
This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
and optionally uploads the dataset to the Hugging Face hub.
Parameters
----------
--raw-path : Path
Directory containing the raw HDF5 files.
--dataset-repo-id : str
Repository ID where the dataset will be stored.
--fps : int
Frames per second for the dataset.
--robot-type : str, optional
Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
--private : bool, optional
Set to True to make the dataset private. Default is False.
--push : bool, optional
Set to True to push the dataset (including videos) to the hub. Default is True.
--description : str, optional
Description of the dataset. Default is "Aloha recorded dataset.".
--license : str, optional
License for the dataset. Default is "apache-2.0".
--image-compressed : bool, optional
Set to True if the images are compressed. Default is True.
--video-encoding : bool, optional
Set to True to encode images as videos. Default is True.
--nproc : int, optional
Number of image writer processes. Default is 10.
--nthreads : int, optional
Number of image writer threads. Default is 5.
"""
parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
parser.add_argument("--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files.")
parser.add_argument(
"--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
)
parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
parser.add_argument(
"--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
)
parser.add_argument(
"--robot-type",
type=str,
choices=["aloha-stationary", "aloha-mobile"],
default="aloha-stationary",
help="Type of robot.",
)
parser.add_argument("--private", type=str2bool, default=False, help="Set to True to make the dataset private.")
parser.add_argument("--push", type=str2bool, default=True, help="Set to True to push videos to the hub.")
parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
parser.add_argument(
"--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
)
parser.add_argument("--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos.")
parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")
args = parser.parse_args()
print(f"Video encoding: {args.video_encoding}")
converter = DatasetConverter(
raw_path=args.raw_path,
dataset_repo_id=args.dataset_repo_id,
fps=args.fps,
robot_type=args.robot_type,
image_compressed=args.image_compressed,
encode_as_videos=args.video_encoding,
image_writer_processes=args.nproc,
image_writer_threads=args.nthreads,
)
converter.init_lerobot_dataset()
converter.extract_episodes(episode_description=args.description)
if args.push:
converter.push_dataset_to_hub(
dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
)
if __name__ == "__main__":
main()
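# Example invocation (script path, data directory, and repo id are hypothetical; the flags match the
# argparse definitions above):
#
#   python convert_aloha_hd5.py \
#       --raw-path /data/aloha/raw \
#       --dataset-repo-id my-org/aloha-task \
#       --fps 50 \
#       --robot-type aloha-stationary \
#       --push false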

29
scripts/compose.yml Normal file
View File

@@ -0,0 +1,29 @@
# Run with:
# docker compose -f scripts/compose.yml up --build
services:
openpi_server:
image: openpi_server
build:
context: ..
dockerfile: scripts/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
# Populate configured openpi data home to /openpi_assets inside the container.
# Populate aws credential inside the container.
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
- ~/.aws/:/root/.aws/
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
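# Example usage (assumption, not part of the original file): the image's CMD forwards SERVER_ARGS to
# scripts/serve_policy.py, so extra flags can be passed at launch, e.g.:
#   SERVER_ARGS="--port 8000" docker compose -f scripts/compose.yml up --build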

View File

@@ -0,0 +1,67 @@
"""Compute normalization statistics for a config.
This script is used to compute the normalization statistics for a given config. It
will compute the mean and standard deviation of the data in the dataset and save it
to the config metadata directory.
"""
import numpy as np
import tqdm
import tyro
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
def create_dataset(config: _config.TrainConfig) -> tuple[str, _data_loader.Dataset]:
model = config.create_model()
data_config = config.data.create(config.metadata_dir, model)
if data_config.repo_id is None:
raise ValueError("Data config must have a repo_id")
dataset = _data_loader.TransformedDataset(
_data_loader.create_dataset(data_config, model),
[
*data_config.repack_transforms.inputs,
*data_config.data_transforms.inputs,
],
)
return data_config.repo_id, dataset
def main(config_name: str, max_frames: int | None = None):
config = _config.get_config(config_name)
repo_id, dataset = create_dataset(config)
num_frames = len(dataset)
shuffle = False
if max_frames is not None and max_frames < num_frames:
num_frames = max_frames
shuffle = True
data_loader = _data_loader.TorchDataLoader(
dataset,
local_batch_size=1,
num_workers=8,
shuffle=shuffle,
num_batches=num_frames,
)
keys = ["state", "actions"]
stats = {key: normalize.RunningStats() for key in keys}
for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
for key in keys:
values = np.asarray(batch[key][0])
stats[key].update(values.reshape(-1, values.shape[-1]))
norm_stats = {key: s.get_statistics() for key, s in stats.items()}
output_path = config.metadata_dir / repo_id
print(f"Writing stats to: {output_path}")
normalize.save(output_path, norm_stats)
if __name__ == "__main__":
tyro.cli(main)
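# Example invocation (hypothetical script path and config name; assumes tyro exposes the function
# arguments as --config-name / --max-frames flags):
#   uv run scripts/compute_norm_stats.py --config-name debug --max-frames 1000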

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc
# Add the repository to Apt sources:
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker $username
# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service
# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
if [ -f ~/.docker/config.json ]; then
sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi
echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""

View File

@@ -0,0 +1,17 @@
#!/bin/bash
# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker
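# Optional sanity check (not part of the original script; image tag is an example): verify that
# containers can see the GPU, e.g.:
#   sudo docker run --rm --gpus all ubuntu nvidia-smi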

View File

@@ -0,0 +1,34 @@
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
# Build the container:
# docker build . -t openpi_server -f scripts/serve_policy.Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Install the project's dependencies using the lockfile and settings
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
--mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"

243
scripts/serve_policy.py Normal file
View File

@@ -0,0 +1,243 @@
from collections.abc import Sequence
import dataclasses
import enum
import logging
from typing import Any
import tyro
from openpi import transforms
from openpi.models import exported as _exported
from openpi.models import model as _model
from openpi.policies import aloha_policy
from openpi.policies import calvin_policy
from openpi.policies import droid_policy
from openpi.policies import libero_policy
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config
class EnvMode(enum.Enum):
"""Supported environments."""
ALOHA = "aloha"
ALOHA_SIM = "aloha_sim"
DROID = "droid"
CALVIN = "calvin"
LIBERO = "libero"
@dataclasses.dataclass
class Exported:
"""Load an exported checkpoint."""
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
dir: str
# Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
processor: str | None = None
@dataclasses.dataclass
class Checkpoint:
"""Load a policy from a trained checkpoint."""
# Training config name (e.g., "pi0_aloha_sim").
config: str
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
dir: str
@dataclasses.dataclass
class Args:
"""Arguments for the serve_policy script."""
# Environment to serve the policy for.
env: EnvMode = EnvMode.ALOHA_SIM
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
policy: Checkpoint | Exported | None = None
# If provided, overrides the default prompt for the policy.
default_prompt: str | None = None
# Port to serve the policy on.
port: int = 8000
# Record the policy's behavior for debugging.
record: bool = False
def repack_from_env(env: EnvMode) -> transforms.Group:
"""Creates environment specific repack transforms."""
# TODO(ury): Move this to the runtime.
match env:
case EnvMode.ALOHA:
return transforms.Group(
inputs=[aloha_policy.ActInputsRepack()],
outputs=[aloha_policy.ActOutputsRepack()],
)
case EnvMode.ALOHA_SIM:
return transforms.Group(
inputs=[aloha_policy.ActInputsRepack()],
outputs=[aloha_policy.ActOutputsRepack()],
)
case _:
return transforms.Group()
# Default exported models.
DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
EnvMode.ALOHA: Exported(
dir="s3://openpi-assets/exported/pi0_aloha/model",
processor="trossen_biarm_single_base_cam_24dim",
),
EnvMode.ALOHA_SIM: Exported(
dir="s3://openpi-assets/exported/pi0_aloha_sim/model",
processor="huggingface_aloha_sim_transfer_cube",
),
EnvMode.DROID: Exported(
dir="s3://openpi-assets/exported/pi0_droid/model",
processor="openx_droid",
),
EnvMode.CALVIN: Exported(
dir="s3://openpi-assets/exported/pi0_calvin/model",
processor="calvin",
),
EnvMode.LIBERO: Exported(
dir="s3://openpi-assets/exported/pi0_libero/model",
processor="libero",
),
}
def create_default_policy(
env: EnvMode, *, default_prompt: str | None = None, exported: Exported | None = None
) -> _policy.Policy:
model: _model.BaseModel
config: _policy_config.PolicyConfig
default_exported = DEFAULT_EXPORTED[env]
if exported:
checkpoint_dir = exported.dir
processor = exported.processor or default_exported.processor
else:
checkpoint_dir = default_exported.dir
processor = default_exported.processor
assert processor, "Default processor must be always set"
logging.info("Loading model...")
model = _exported.PiModel.from_checkpoint(checkpoint_dir)
def make_policy_config(
input_layers: Sequence[transforms.DataTransformFn],
output_layers: Sequence[transforms.DataTransformFn],
sample_kwargs: dict[str, Any] | None = None,
):
sample_kwargs = sample_kwargs or {"num_steps": 10}
return _policy_config.PolicyConfig(
model=model,
norm_stats=model.norm_stats(processor),
default_prompt=default_prompt,
input_layers=input_layers,
output_layers=output_layers,
sample_kwargs=sample_kwargs,
)
logging.info("Creating policy...")
match env:
case EnvMode.ALOHA:
delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
config = make_policy_config(
input_layers=[
aloha_policy.ActInputsRepack(),
aloha_policy.AlohaInputs(
action_dim=model.action_dim,
delta_action_mask=delta_action_mask,
adapt_to_pi=True,
),
],
output_layers=[
aloha_policy.AlohaOutputs(
delta_action_mask=delta_action_mask,
adapt_to_pi=True,
),
aloha_policy.ActOutputsRepack(),
],
)
case EnvMode.ALOHA_SIM:
config = make_policy_config(
input_layers=[
aloha_policy.ActInputsRepack(),
aloha_policy.AlohaInputs(action_dim=model.action_dim),
],
output_layers=[
aloha_policy.AlohaOutputs(),
aloha_policy.ActOutputsRepack(),
],
)
case EnvMode.DROID:
config = make_policy_config(
input_layers=[
droid_policy.DroidInputs(action_dim=model.action_dim),
],
output_layers=[
droid_policy.DroidOutputs(),
transforms.SubsampleActions(stride=5),
],
)
case EnvMode.CALVIN:
config = make_policy_config(
input_layers=[
calvin_policy.CalvinInputs(action_dim=model.action_dim),
],
output_layers=[
calvin_policy.CalvinOutputs(),
],
)
case EnvMode.LIBERO:
config = make_policy_config(
input_layers=[
libero_policy.LiberoInputs(action_dim=model.action_dim),
],
output_layers=[
libero_policy.LiberoOutputs(),
],
)
case _:
raise ValueError(f"Unknown environment mode: {env}")
return _policy_config.create_policy(config)
def create_policy(args: Args) -> _policy.Policy:
match args.policy:
case Checkpoint():
return _policy_config.create_trained_policy(
_config.get_config(args.policy.config),
args.policy.dir,
repack_transforms=repack_from_env(args.env),
default_prompt=args.default_prompt,
)
case Exported():
return create_default_policy(args.env, default_prompt=args.default_prompt, exported=args.policy)
case None:
return create_default_policy(args.env, default_prompt=args.default_prompt)
def main(args: Args) -> None:
policy = create_policy(args)
# Record the policy's behavior.
if args.record:
policy = _policy.PolicyRecorder(policy, "policy_records")
logging.info("Creating server...")
server = websocket_policy_server.WebsocketPolicyServer(policy=policy, host="0.0.0.0", port=args.port)
logging.info("Serving...")
server.serve_forever()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
main(tyro.cli(Args))
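# Example invocations (hypothetical values; flag spellings and enum casing assume tyro's default
# conversion of the Args fields above):
#   uv run scripts/serve_policy.py --env ALOHA_SIM --port 8000
#   uv run scripts/serve_policy.py --env ALOHA --default-prompt "fold the towel"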

284
scripts/train.py Normal file
View File

@@ -0,0 +1,284 @@
import dataclasses
from functools import partial
import logging
import platform
from typing import Any
import etils.epath as epath
from flax.training import common_utils
import jax
import jax._src.tree_util as private_tree_util
import jax.experimental
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb
import openpi.models.common as _common
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
def init_logging():
"""Custom logging format for better readability."""
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
class CustomFormatter(logging.Formatter):
def format(self, record):
record.levelname = level_mapping.get(record.levelname, record.levelname)
return super().format(record)
formatter = CustomFormatter(
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
datefmt="%H:%M:%S",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
if not enabled:
wandb.init(mode="disabled")
return
ckpt_dir = config.checkpoint_dir
if not ckpt_dir.exists():
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
if resuming:
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
wandb.init(id=run_id, resume="must", project=config.project_name)
else:
wandb.init(
name=config.exp_name,
config=dataclasses.asdict(config),
project=config.project_name,
)
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
if log_code:
wandb.run.log_code(epath.Path(__file__).parent.parent)
def _load_weights_and_validate(weight_loader: _weight_loaders.WeightLoader, params: at.Params) -> at.Params:
"""Runs the weight loader and validates that the params structure, shapes, and dtypes are unchanged."""
new_params = weight_loader.load(jax.tree.map(lambda x: x, params))
if errors := list(private_tree_util.equality_errors(params, new_params)):
raise ValueError(
"Weight loading changed the params structure:\n"
+ (
"\n".join(
f" - {jax.tree_util.keystr(path)} changed from {thing1} to {thing2}, so {explanation}.\n"
for path, thing1, thing2, explanation in errors
)
)
)
def check(kp, x, y):
if (x := jax.ShapeDtypeStruct(x.shape, x.dtype)) != (y := jax.ShapeDtypeStruct(y.shape, y.dtype)):
raise ValueError(
f"Weight loading changed the params structure: expected {y}, got {x} at {jax.tree_util.keystr(kp)}"
)
jax.tree_util.tree_map_with_path(check, params, new_params)
return new_params
@at.typecheck
def init_train_state(
config: _config.TrainConfig,
model: _model.Model,
init_rng: at.KeyArrayLike,
batch: tuple[_common.Observation, _common.Actions],
mesh: jax.sharding.Mesh,
data_sharding: jax.sharding.Sharding,
*,
resume: bool,
) -> tuple[training_utils.TrainState, Any]:
weight_decay_mask = None
freeze_mask = None
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask, freeze_mask)
def init(
rng: at.KeyArrayLike,
data: tuple[_common.Observation, _common.Actions],
params_sharding: jax.sharding.Sharding | None = None,
) -> training_utils.TrainState:
rng, model_rng = jax.random.split(rng)
observation, actions = data
params = model.init_params(model_rng, observation, actions)
# jax.experimental.io_callback raises SPMD partitioning warnings, so we set constraints that
# replicate the params to avoid them. The returned train state will still be sharded, since FSDP
# sharding is specified as the output sharding when jitting this function.
if params_sharding is not None:
params = jax.lax.with_sharding_constraint(params, params_sharding)
params = jax.experimental.io_callback(
partial(_load_weights_and_validate, config.weight_loader),
params,
params,
ordered=True,
)
if params_sharding is not None:
params = jax.lax.with_sharding_constraint(params, params_sharding)
return training_utils.TrainState(
step=0,
params=params,
opt_state=tx.init(params),
tx=tx,
ema_decay=config.ema_decay,
ema_params=None if config.ema_decay is None else params,
)
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
train_state_shape = jax.eval_shape(init, init_rng, batch)
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
if resume:
return train_state_shape, state_sharding
train_state = jax.jit(
init,
in_shardings=(replicated_sharding, data_sharding),
out_shardings=state_sharding,
static_argnums=(2,),
)(init_rng, batch, replicated_sharding)
return train_state, state_sharding
@at.typecheck
def train_step(
rng: at.KeyArrayLike,
state: training_utils.TrainState,
model: _model.Model,
batch: tuple[_common.Observation, _common.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
def loss_fn(params: at.Params, rng: at.KeyArrayLike, observation: _common.Observation, actions: _common.Actions):
chunked_loss = model.compute_loss(rng, observation, actions, params=params, train=True)
return jnp.mean(chunked_loss)
train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
loss, grads = jax.value_and_grad(loss_fn)(state.params, train_rng, observation, actions)
updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
new_params = optax.apply_updates(state.params, updates)
new_state = state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)
if state.ema_decay is not None:
new_state = new_state.replace(
ema_params=jax.tree.map(
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
)
)
kernel_mask = training_utils.mask_from_regex(r".*\['kernel'\]", state.params)
kernel_params = jax.tree.map(lambda p, m: p if m else None, state.params, kernel_mask)
info = {
"loss": loss,
"grad_norm": optax.global_norm(grads), # TODO: do not compute norm for frozen params
"param_norm": optax.global_norm(kernel_params),
}
return new_state, info
def main(config: _config.TrainConfig):
init_logging()
logging.info(f"Running on: {platform.node()}")
if config.batch_size % jax.device_count() != 0:
raise ValueError(
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
)
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
rng = jax.random.key(config.seed)
train_rng, init_rng = jax.random.split(rng)
if jax.device_count() % config.fsdp_devices != 0:
raise ValueError(
f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {config.fsdp_devices}."
)
mesh_shape = (jax.device_count() // config.fsdp_devices, config.fsdp_devices)
mesh = jax.make_mesh(mesh_shape, ("batch", "model"))
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("batch", "model")))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
config.checkpoint_dir,
keep_interval=config.keep_interval,
overwrite=config.overwrite,
resume=config.resume,
)
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
model = config.create_model()
data_loader = _data_loader.create_data_loader(
config,
model,
sharding=data_sharding,
num_workers=config.num_workers,
shuffle=True,
)
data_iter = iter(data_loader)
batch = next(data_iter)
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
train_state, train_state_sharding = init_train_state(
config, model, init_rng, batch, mesh, data_sharding, resume=resuming
)
jax.block_until_ready(train_state)
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
if resuming:
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
ptrain_step = jax.jit(
train_step,
in_shardings=(replicated_sharding, train_state_sharding, None, data_sharding),
out_shardings=(train_state_sharding, replicated_sharding),
donate_argnums=(1,),
)
start_step = int(train_state.step)
pbar = tqdm.tqdm(
range(start_step, config.num_train_steps),
initial=start_step,
total=config.num_train_steps,
dynamic_ncols=True,
)
infos = []
for step in pbar:
train_state, info = ptrain_step(train_rng, train_state, model, batch)
infos.append(info)
if step % config.log_interval == 0:
stacked_infos = common_utils.stack_forest(infos)
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
pbar.write(f"Step {step}: {info_str}")
wandb.log(reduced_info, step=step)
infos = []
batch = next(data_iter)
if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
logging.info("Waiting for checkpoint manager to finish")
checkpoint_manager.wait_until_finished()
if __name__ == "__main__":
main(_config.cli())
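# Example invocation (hypothetical; the exact CLI shape depends on _config.cli(), which is defined
# elsewhere). Assuming it selects a config by name and accepts field overrides, something like:
#   uv run scripts/train.py debug --exp-name my_experiment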

27
scripts/train_test.py Normal file
View File

@@ -0,0 +1,27 @@
import dataclasses
import pathlib
import pytest
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
config = dataclasses.replace(
_config._CONFIGS_DICT[config_name], # noqa: SLF001
batch_size=2,
checkpoint_base_dir=tmp_path / "checkpoint",
exp_name="test",
overwrite=False,
resume=False,
num_train_steps=2,
log_interval=1,
)
train.main(config)
# test resuming
config = dataclasses.replace(config, resume=True, num_train_steps=4)
train.main(config)

0
src/openpi/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,77 @@
import abc
import dataclasses
from typing import TypeAlias
from flax import struct
import flax.linen as nn
import numpy as np
from openpi.shared import array_typing as at
@at.typecheck
@struct.dataclass
class Observation:
"""Holds observations, i.e., inputs to the model."""
# Images, in [-1, 1] float32.
images: dict[str, at.Float[at.Array, "*b h w c"]]
# Image masks, with same keys as images.
image_masks: dict[str, at.Bool[at.Array, "*b"]]
# Low-dimensional robot state.
state: at.Float[at.Array, "*b s"]
# Tokenized prompt.
tokenized_prompt: at.Int[at.Array, "*b l"] | None = None
# Tokenized prompt mask.
tokenized_prompt_mask: at.Int[at.Array, "*b l"] | None = None
@classmethod
def from_dict(cls, data: at.PyTree[at.ArrayLike]) -> "Observation":
"""This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
# Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
# If images are uint8, convert them to [-1, 1] float32.
for key in data["image"]:
if data["image"][key].dtype == np.uint8:
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
return cls(
images=data["image"],
image_masks=data["image_mask"],
state=data["state"],
tokenized_prompt=data.get("tokenized_prompt"),
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
)
def to_dict(self) -> at.PyTree[at.ArrayLike]:
"""Convert the Observation to a nested dict."""
result = dataclasses.asdict(self)
# TODO(ury): This is awkward. Adjust the names to be the same.
result["image"] = result.pop("images")
result["image_mask"] = result.pop("image_masks")
return result
Actions: TypeAlias = at.Float[at.ArrayLike, "*b ah ad"]
class BaseModule(nn.Module, abc.ABC):
@at.typecheck
@abc.abstractmethod
def compute_loss(
self,
obs: Observation,
target_actions: Actions,
*,
timestep: at.Float[at.Array, " b"] | None = None,
) -> at.Float[at.Array, "b ah"]: ...
@at.typecheck
@abc.abstractmethod
def sample_actions(
self,
action_horizon: int,
action_dim: int,
obs: Observation,
**sample_kwargs,
) -> Actions: ...
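# Minimal usage sketch (illustrative only; camera name, image size, and state dimension are arbitrary)
# showing the uint8 -> [-1, 1] float32 conversion performed by Observation.from_dict:
#
#   import numpy as np
#   obs = Observation.from_dict(
#       {
#           "image": {"cam_high": np.zeros((1, 224, 224, 3), dtype=np.uint8)},
#           "image_mask": {"cam_high": np.ones((1,), dtype=bool)},
#           "state": np.zeros((1, 14), dtype=np.float32),
#       }
#   )
#   # obs.images["cam_high"] is float32 in [-1, 1]: x / 255 * 2 - 1 maps 0 -> -1.0 and 255 -> 1.0.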

View File

@@ -0,0 +1,292 @@
"""Functionality to handle internal pi checkpoints.
Used to test internal pi checkpoints and provides utilities to convert them to openpi checkpoints.
"""
import pathlib
from typing import Any
import flax.serialization
import flax.struct as struct
import jax
import jax.export
import jax.numpy as jnp
import orbax.checkpoint as ocp
from typing_extensions import override
from openpi.models import common
from openpi.models import model as _model
from openpi.shared import image_tools
from openpi.shared import normalize as _normalize
import openpi.shared.array_typing as at
import openpi.shared.download as download
def convert_to_openpi(
ckpt_dir: pathlib.Path | str, processor: str, out_dir: pathlib.Path | str, param_path: str = "decoder"
) -> None:
"""Convert a monopi checkpoint to an openpi checkpoint."""
out_dir = pathlib.Path(out_dir)
if out_dir.exists():
raise FileExistsError(f"Output directory already exists: {out_dir}")
out_dir.mkdir(parents=True, exist_ok=True)
# Load params and norm stats.
ckpt_dir = download.maybe_download(str(ckpt_dir))
sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])
params = _load_params(ckpt_dir, sharding=sharding)
norm_stats = _import_norm_stats(ckpt_dir, processor)
for part in param_path.split("/"):
if part not in params:
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
params = params[part]
# Load the monopi model.
# Save params.
ckpt = ocp.StandardCheckpointer()
ckpt.save(out_dir / "params", {"params": params})
ckpt.wait_until_finished()
# Save norm stats.
_normalize.save(out_dir / "assets", norm_stats)
@struct.dataclass
class PiModel(_model.BaseModel):
"""A model loaded from a monopi checkpoint model directory."""
params: at.Params
exported: jax.export.Exported = struct.field(pytree_node=False)
example_spec: Any = struct.field(pytree_node=False)
sample_spec: Any = struct.field(pytree_node=False)
ckpt_dir: pathlib.Path = struct.field(pytree_node=False)
@classmethod
def from_checkpoint(cls, ckpt_dir: pathlib.Path | str) -> "PiModel":
"""Load a model from a monopi model checkpoint directory. Must point at the "model" sub-directory."""
ckpt_dir = download.maybe_download(str(ckpt_dir))
with (ckpt_dir / "graph").open("rb") as f:
exported = jax.export.deserialize(f.read())
input_spec = jax.tree.unflatten(exported.in_tree, exported.in_avals)[0]
params = _load_params(ckpt_dir, input_spec[0])
example_spec = input_spec[2]
sample_spec = input_spec[3]
# Extract the action properties from the output spec.
output_spec = jax.tree.unflatten(exported.out_tree, exported.out_avals)
actions_spec = output_spec["actions"]
action_horizon, action_dim = actions_spec.shape
max_token_len = example_spec["prompt_tokens"].shape[-1]
return cls(
params=params,
exported=exported,
example_spec=example_spec,
sample_spec=sample_spec,
ckpt_dir=ckpt_dir,
action_horizon=action_horizon,
action_dim=action_dim,
max_token_len=max_token_len,
)
@jax.jit
@override
def sample_actions(self, rng: at.KeyArrayLike, observation: common.Observation, **sample_kwargs) -> common.Actions:
if observation.state.ndim == 2 and observation.state.shape[0] != 1:
raise ValueError("Only batch_size=1 is supported.")
# Convert to the example format.
example = _obs_to_example(observation, self.example_spec)
example = _unbatch(example)
# Resize the input images if needed.
def resize_if_needed(key, image):
target_shape = self.example_spec["image"][key].shape
if len(target_shape) == 3 and image.shape != target_shape:
return image_tools.resize_with_pad(image, *target_shape[-3:-1])
return image
example["image"] = {key: resize_if_needed(key, value) for key, value in example["image"].items()}
if set(sample_kwargs) != set(self.sample_spec):
raise ValueError(
f"Sample args {list(sample_kwargs)} do not match the expected args {list(self.sample_spec)}"
)
rng_data = jax.random.key_data(rng)
result = self.exported.call(self.params, rng_data, example, sample_kwargs)
return _make_batch(result)["actions"]
@override
def compute_loss(
self,
rng: at.KeyArrayLike,
observation: common.Observation,
actions: common.Actions,
*,
train: bool = False,
params: at.Params | None = None,
) -> at.Float[at.Array, "*b ah"]:
raise NotImplementedError("Not implemented.")
def fake_obs(self) -> common.Observation:
example = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), self.example_spec)
return _example_to_obs(_make_batch(example))
def norm_stats(self, processor_name: str) -> dict[str, _normalize.NormStats]:
return _import_norm_stats(self.ckpt_dir, processor_name)
def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model:
"""Creates a new model that uses the same parameters but a different module.
Args:
module: The module to use for the model.
param_path: Location of the parameter sub-tree that should be loaded (e.g., decoder).
Can include "/" to support nesting.
Returns:
A new model with the parameters loaded from the checkpoint.
"""
params = self.params
for part in param_path.split("/"):
if part not in params:
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
params = params[part]
return _model.Model(
module=module,
params=params,
action_dim=self.action_dim,
action_horizon=self.action_horizon,
max_token_len=self.max_token_len,
)
def _load_params(
path: pathlib.Path, params_spec: at.PyTree | None = None, sharding: jax.sharding.Sharding | None = None
):
if sharding is None:
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
def to_restore_args(tree):
return jax.tree.map(lambda x: ocp.ArrayRestoreArgs(dtype=x.dtype, sharding=sharding), tree)
with ocp.PyTreeCheckpointer() as ckptr:
if params_spec is None:
params_spec = ckptr.metadata(path)["params"]
item = {"params": params_spec}
return ckptr.restore(
path,
args=ocp.args.PyTreeRestore(
item=item,
restore_args=to_restore_args(item),
# This is needed to read a partial checkpoint.
transforms={},
),
)["params"]
def _obs_to_example(obs: common.Observation, example_spec: dict) -> dict:
def to_uint8(v):
return (255.0 * (v + 1.0) / 2.0).astype(jnp.uint8)
images = {k: to_uint8(v) for k, v in obs.images.items()}
image_masks = {f"{k}_mask": v for k, v in obs.image_masks.items()}
result = {
"image": {**images, **image_masks},
"state": obs.state,
"prompt_tokens": obs.tokenized_prompt,
}
# NOTE(ury): This is used to support the new version with DCT co-training.
if "mask_prompt_input" in example_spec:
allow_action_diffusion_attention = example_spec["allow_action_diffusion_attention"]
mask_ar = example_spec["mask_ar"]
result = {
**result,
"mask_prompt_input": obs.tokenized_prompt_mask,
# NOTE(ury): These values are likely wrong. Put something for now
# to make sure that the model doesn't crash.
"allow_action_diffusion_attention": _make_batch(
jnp.zeros(allow_action_diffusion_attention.shape, allow_action_diffusion_attention.dtype)
),
"mask_ar": _make_batch(jnp.ones(mask_ar.shape, mask_ar.dtype)),
}
else:
result = {
**result,
"mask_input": obs.tokenized_prompt_mask,
}
return result
def _example_to_obs(example: dict) -> common.Observation:
images, image_masks = {}, {}
for k, v in example["image"].items():
if k.endswith("_mask"):
image_masks[k.removesuffix("_mask")] = v
else:
images[k] = v
# NOTE(ury): This is used to support the new version with DCT co-training.
if "mask_prompt_input" in example:
example["mask_input"] = example["mask_prompt_input"]
return common.Observation.from_dict(
{
"image": images,
"image_mask": image_masks,
"state": example["state"],
"tokenized_prompt": example["prompt_tokens"],
"tokenized_prompt_mask": example["mask_input"],
}
)
def _import_norm_stats(ckpt_dir: pathlib.Path | str, processor_name: str) -> dict[str, _normalize.NormStats]:
ckpt_dir = pathlib.Path(ckpt_dir).resolve()
path = ckpt_dir / "processors" / processor_name
if not path.exists():
raise FileNotFoundError(f"Processor {processor_name} not found in {ckpt_dir}")
if not (found_files := list(path.glob("*/norm_stats.msgpack"))):
raise FileNotFoundError(f"norm_stats.msgpack not found in {path}")
outputs = []
for file in sorted(found_files):
with file.open("rb") as f:
norm_stats = flax.serialization.msgpack_restore(f.read())
# This is the new Normalize processor.
if "input_norms" in norm_stats:
actions = norm_stats["output_norms"]["actions"]
outputs.append(_normalize.NormStats(mean=actions["mean"], std=actions["std"]))
state = norm_stats["input_norms"]["state"]
outputs.append(_normalize.NormStats(mean=state["mean"], std=state["std"]))
# This is to support the old NormalizeActions / NormalizeState processor combo.
else:
outputs.append(_normalize.NormStats(mean=norm_stats["mean"], std=norm_stats["std"]))
return {
"actions": outputs[0],
"state": outputs[1],
}
def _make_batch(data: at.PyTree) -> at.PyTree:
return jax.tree.map(lambda x: x[jnp.newaxis, ...], data)
def _unbatch(data: at.PyTree) -> at.PyTree:
return jax.tree.map(lambda x: x[0, ...], data)

View File

@@ -0,0 +1,47 @@
import pathlib
import jax
import jax.numpy as jnp
import openpi.models.exported as exported
import openpi.models.model as _model
import openpi.models.pi0 as pi0
import openpi.training.checkpoints as _checkpoints
def test_sample_actions():
model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
actions = model.sample_actions(jax.random.key(0), model.fake_obs(), num_steps=10)
assert actions.shape == (1, model.action_horizon, model.action_dim)
def test_exported_as_pi0():
pi_model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
model = pi_model.set_module(pi0.Module(), param_path="decoder")
key = jax.random.key(0)
obs = model.fake_obs()
pi_actions = pi_model.sample_actions(key, obs, num_steps=10)
actions = model.sample_actions(key, obs, num_steps=10)
assert pi_actions.shape == (1, model.action_horizon, model.action_dim)
assert actions.shape == (1, model.action_horizon, model.action_dim)
diff = jnp.max(jnp.abs(pi_actions - actions))
assert diff < 10.0
def test_convert_to_openpi(tmp_path: pathlib.Path):
output_dir = tmp_path / "output"
exported.convert_to_openpi(
"s3://openpi-assets/exported/pi0_aloha_sim/model",
"huggingface_aloha_sim_transfer_cube",
output_dir,
)
# Make sure that we can load the params and norm stats.
_ = _model.restore_params(output_dir / "params")
_ = _checkpoints.load_norm_stats(output_dir / "assets")

600
src/openpi/models/gemma.py Normal file
View File

@@ -0,0 +1,600 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""gemma adaptation for Pi, taken from big_vision.
We follow this einsum axis naming convention:
B: batch
T: query length
S: k/v length
N: num query heads
K: num k/v heads
G: num query heads per k/v head
H: head dim
D: d_model ("features")
"""
from collections.abc import Callable, Sequence
import dataclasses
import logging
import math
from typing import Literal
import einops
import flax.linen as nn
import flax.traverse_util as traverse_util
import jax
import jax.numpy as jnp
import openpi.shared.array_typing as at
PALIGEMMA_VOCAB_SIZE = 257_152
@dataclasses.dataclass
class LoRAConfig:
rank: int
alpha: float
dropout: float = 0.0
# https://arxiv.org/pdf/2312.03732
rslora: bool = False
rank_annotation: str = "L"
def __post_init__(self):
if self.rank != int(self.alpha):
logging.warning(
"Rank and alpha are not the same, this will cause accuracy error when merging LoRA params currently."
)
@dataclasses.dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
projection_lora: LoRAConfig | None = None
projection_kv_lora: LoRAConfig | None = None
output_lora: LoRAConfig | None = None
Variant = Literal["dummy", "gemma_300m", "gemma_2b", "gemma_2b_lora"]
def get_config(variant: Variant) -> Config:
"""Returns config for specified gemma variant."""
if variant == "dummy":
return Config(
width=64,
depth=4,
mlp_dim=128,
num_heads=8,
num_kv_heads=1,
head_dim=16,
)
if variant == "gemma_300m":
# 311M params
return Config(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
if variant == "gemma_2b":
return Config(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
if variant == "gemma_2b_lora":
return Config(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
projection_lora=LoRAConfig(rank=64, alpha=64.0),
projection_kv_lora=LoRAConfig(rank=64, alpha=64.0),
output_lora=LoRAConfig(rank=64, alpha=64.0),
)
raise ValueError(f"Unknown variant: {variant}")
@at.typecheck
class Einsum(nn.Module):
shape: tuple[int, ...]
init_fn: nn.initializers.Initializer
@nn.compact
def __call__(self, eqn, x):
dtype = x.dtype # original dtype, could be half-precision
w = self.param("w", self.init_fn, self.shape).astype(dtype)
return jnp.einsum(eqn, x, w)
_LORA_A_KEY = "lora_a"
_LORA_B_KEY = "lora_b"
@at.typecheck
class LoRAEinsum(nn.Module):
base: Einsum
lora_config: LoRAConfig
merge_eqn: str
lora_a_init_fn: nn.initializers.Initializer
lora_b_init_fn: nn.initializers.Initializer
def setup(self):
nn.share_scope(self, self.base)
@nn.compact
def __call__(self, eqn, x, *, deterministic=True):
orig_x = x
eqn_lora_a, eqn_lora_b = self._get_lora_eqn(eqn, self.merge_eqn)
if self.lora_config.dropout > 0.0:
x = nn.Dropout(rate=self.lora_config.dropout, deterministic=deterministic)(x)
lora_a_shape, lora_b_shape = self._parse_shape(self.merge_eqn)
lora_a = self.param(_LORA_A_KEY, self.lora_a_init_fn, lora_a_shape).astype(x.dtype)
lora_b = self.param(_LORA_B_KEY, self.lora_b_init_fn, lora_b_shape).astype(x.dtype)
lora_a = jnp.einsum(eqn_lora_a, x, lora_a)
lora_b = jnp.einsum(eqn_lora_b, lora_a, lora_b)
# TODO: scaling_value should ideally be a self.variable; however, the base model currently doesn't
# allow any auxiliary variables.
scaling_value = (
self.lora_config.alpha / self.lora_config.rank
if not self.lora_config.rslora
else self.lora_config.alpha / math.sqrt(self.lora_config.rank)
)
return self.base(eqn, orig_x) + lora_b * scaling_value
def _get_lora_eqn(self, eqn: str, lora_merge_eqn: str) -> tuple[str, str]:
"""Figure out lora_a and lora_b eqn from eqn and lora_merge_eqn.
input:
eqn: x,w->y
lora_merge_eqn: lora_a,lora_b->w
output:
lora_a_eqn: x,lora_a->?
lora_b_eqn: ?,lora_b->y
"""
(x_repr, w_repr), y_repr = _parse_einops_eqn(eqn)
(lora_a_repr, lora_b_repr), w_repr_p = _parse_einops_eqn(lora_merge_eqn)
assert len(w_repr) == len(self.base.shape), f"w_repr={w_repr}, shape={self.base.shape}"
assert w_repr == w_repr_p, f"w_repr={w_repr}, w_repr_p={w_repr_p} should be the same."
# figure out x,lora_a's output annotation by using y and lora_b
# the way to do this is to:
# 1. remove the common prefix and suffix from lora_b and y
# 2. then the ? will be (common prefix) (stripped y) (stripped lora_b)
# the equation will look like:
# [(prefix) (stripped y) (lora b)], [(prefix) (lora b) (suffix)] -> [(prefix) (y) (suffix)]
prefix, _, y_repr_stripped, lora_b_repr_stripped = self._remove_common_prefix_suffix(y_repr, lora_b_repr)
lora_intermediate_repr = prefix + y_repr_stripped + lora_b_repr_stripped
eqn_lora_a_lhs = ", ".join([x_repr, lora_a_repr])
eqn_lora_b_lhs = ", ".join([lora_intermediate_repr, lora_b_repr])
return eqn_lora_a_lhs + " -> " + lora_intermediate_repr, eqn_lora_b_lhs + " -> " + y_repr
def _remove_common_prefix_suffix(self, str1, str2):
# Get the common prefix
prefix = ""
for i in range(min(len(str1), len(str2))):
if str1[i] == str2[i]:
prefix += str1[i]
else:
break
# Get the common suffix
suffix = ""
for i in range(1, min(len(str1), len(str2)) + 1):
if str1[-i] == str2[-i]:
suffix = str1[-i] + suffix
else:
break
# Use explicit end indices so an empty common suffix doesn't truncate the result to "".
return prefix, suffix, str1[len(prefix) : len(str1) - len(suffix)], str2[len(prefix) : len(str2) - len(suffix)]
def _parse_shape(self, lora_merge_eqn: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
(lora_lhs_part_0, lora_lhs_part_1), lora_rhs = _parse_einops_eqn(lora_merge_eqn)
ann_to_dim = dict(zip(lora_rhs, self.base.shape, strict=True))
ann_to_dim[self.lora_config.rank_annotation] = self.lora_config.rank
return tuple(ann_to_dim[ann] for ann in lora_lhs_part_0), tuple(ann_to_dim[ann] for ann in lora_lhs_part_1)
def merge_lora_params(lora_params: at.PyTree, get_lora_transform_eqn: Callable[[str], str]) -> at.PyTree:
params = lora_params["params"]
flattened_params = traverse_util.flatten_dict(params, sep="/")
merged_params = {}
for k in flattened_params:
if _LORA_A_KEY not in k:
continue
lora_b_key = k.replace(_LORA_A_KEY, _LORA_B_KEY)
orig_w_key = k.replace(_LORA_A_KEY, "w")
assert lora_b_key in flattened_params
assert orig_w_key in flattened_params
lora_merge = jnp.einsum(get_lora_transform_eqn(k), flattened_params[k], flattened_params[lora_b_key])
# TODO: Currently we don't handle the LoRA scaling value here because the base model doesn't support
# auxiliary variables.
merged_params[orig_w_key] = flattened_params[orig_w_key] + lora_merge
for k in flattened_params:
if _LORA_A_KEY in k or _LORA_B_KEY in k:
continue
if k not in merged_params:
merged_params[k] = flattened_params[k]
return {"params": traverse_util.unflatten_dict(merged_params, sep="/")}
def _parse_einops_eqn(eqn: str) -> tuple[tuple[str, str], str]:
lhs, rhs = eqn.split("->")
lhs_parts = lhs.split(",")
assert len(lhs_parts) == 2
def strip_space(s):
return s.replace(" ", "")
lhs_parts[0] = strip_space(lhs_parts[0])
lhs_parts[1] = strip_space(lhs_parts[1])
rhs = strip_space(rhs)
return ((lhs_parts[0], lhs_parts[1]), rhs)
@at.typecheck
class RMSNorm(nn.Module):
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
normed_inputs = normed_inputs * (
1 + scale
) # scale by learned parameter in float32 (matches Flax implementation)
return normed_inputs.astype(dtype) # return in original dtype
@at.typecheck
class Embedder(nn.Module):
"""Embedder module."""
vocab_size: int
embed_dim: int
def setup(self):
self.input_embedding_table = self.param(
"input_embedding",
nn.initializers.normal(),
(self.vocab_size, self.embed_dim),
)
def encode(self, x):
x = self.input_embedding_table[(x,)]
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
return x
def decode(self, x):
return jnp.dot(x, self.input_embedding_table.T)
@at.typecheck
class Attention(nn.Module):
"""Attention module."""
configs: Sequence[Config]
@nn.compact
def __call__(self, xs, positions, attn_mask, decode: bool): # noqa: FBT001
# all experts must share the same head dim, num heads, and num kv heads for self-attention to work
assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
qkvs = []
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is None:
continue
if config.num_kv_heads == config.num_heads:
qkv_einsum = Einsum(
shape=(3, config.num_heads, config.width, config.head_dim),
name=_name("qkv_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
)
if config.projection_lora is not None:
qkv_einsum = LoRAEinsum(
qkv_einsum,
config.projection_lora,
merge_eqn="3KDL,3KLKH->3KDH",
lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
)
qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
else:
q_einsum = Einsum(
shape=(config.num_heads, config.width, config.head_dim),
name=_name("q_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
)
if config.projection_lora is not None:
q_einsum = LoRAEinsum(
q_einsum,
config.projection_lora,
merge_eqn="NDL,NLNH->NDH",
lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 2)),
)
q = q_einsum("BTD,NDH->BTNH", x)
kv_einsum = Einsum(
shape=(2, config.num_kv_heads, config.width, config.head_dim),
name=_name("kv_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
)
if config.projection_kv_lora is not None:
kv_einsum = LoRAEinsum(
kv_einsum,
config.projection_kv_lora,
merge_eqn="2KDL,2KLKH->2KDH",
lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
)
k, v = kv_einsum("BSD,2KDH->2BSKH", x)
qkvs.append((q, k, v))
q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
q = _apply_rope(q, positions=positions)
q *= self.configs[0].head_dim ** -0.5
k = _apply_rope(k, positions=positions)
# should still be half-precision here (if input was half-precision)
assert q.dtype == k.dtype == v.dtype == dtype
if decode:
if not self.has_variable("cache", "k_cache"):
# initial prefill
self.put_variable("cache", "k_cache", k)
self.put_variable("cache", "v_cache", v)
else:
# decoding
k = jnp.concatenate([self.get_variable("cache", "k_cache"), k], axis=1)
v = jnp.concatenate([self.get_variable("cache", "v_cache"), v], axis=1)
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
raise ValueError(
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
)
# big_neg = jnp.finfo(logits.dtype).min
big_neg = -2.3819763e38 # See gemma/modules.py
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
out = []
start = 0
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is not None:
end = start + x.shape[1]
out_einsum = Einsum(
shape=(config.num_heads, config.head_dim, config.width),
name=_name("attn_vec_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
)
if config.projection_lora is not None:
out_einsum = LoRAEinsum(
out_einsum,
config.projection_lora,
merge_eqn="NHNL,NLD->NHD",
lora_a_init_fn=nn.initializers.lecun_normal(in_axis=(-4, -3), out_axis=(-2, -1)),
lora_b_init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
)
out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
start = end
else:
out.append(None)
return out
@at.typecheck
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
w_gating = self.param(
"gating_einsum",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
(2, self.features, self.hidden_dim),
).astype(dtype)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
"linear",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
(self.hidden_dim, self.features),
).astype(dtype)
outputs = jnp.dot(activations, w_linear)
assert outputs.dtype == dtype
return outputs
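# In equation form: FeedForward(x) = (gelu(x @ w_gating[0]) * (x @ w_gating[1])) @ w_linear, i.e. Gemma's
# GeGLU feed-forward, evaluated in the input's (possibly half-precision) dtype.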
@at.typecheck
class Block(nn.Module):
"""Transformer block."""
configs: Sequence[Config]
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = ()
@nn.compact
def __call__(self, xs, unused_scan_arg, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
attn = Attention(configs=self.configs, name="attn")
pre_attn = []
for i, x in enumerate(xs):
if x is not None:
x = RMSNorm(name=_name("pre_attention_norm", i))(x) # noqa: PLW2901
pre_attn.append(x)
post_attn = attn(pre_attn, positions, attn_mask, decode)
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
xs = jax.tree.map(lambda x, y: x + y, xs, post_attn)
out = []
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is not None:
x = RMSNorm(name=_name("pre_ffw_norm", i))(x) # noqa: PLW2901
x = FeedForward( # noqa: PLW2901
features=config.width,
hidden_dim=config.mlp_dim,
name=_name("mlp", i),
)(x)
out.append(x)
out = jax.tree.map(lambda x: drop(x, deterministic), out)
xs = jax.tree.map(lambda x, y: x + y, xs, out)
return xs, unused_scan_arg
@at.typecheck
class Module(nn.Module):
"""Transformer model, supporting a mixture of different weights for different tokens."""
configs: Sequence[Config] # list of configs, one for each expert
embed_dtype: str
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
@nn.compact
@at.typecheck
def __call__(
self,
*,
tokens: at.Int[at.Array, "b t"] | None,
# list of token arrays, one for each expert, or None if that expert should not be run
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None] | None,
positions: at.Int[at.Array, "b t"] | None = None,
mask: at.Bool[at.Array, "b t s"] | None = None,
decode: bool = False,
deterministic: bool = True,
) -> at.Float[at.Array, "b t d"] | Sequence[at.Float[at.Array, "b _t _d"] | None]:
# all experts must have the same depth
assert all(config.depth == self.configs[0].depth for config in self.configs)
# embedder for first expert only
embedder = Embedder(
vocab_size=PALIGEMMA_VOCAB_SIZE,
embed_dim=self.configs[0].width,
name="embedder",
)
if tokens is not None:
# embed only
assert embedded is None, "Cannot pass both tokens and embedded"
return embedder.encode(tokens).astype(self.embed_dtype)
assert embedded is not None
assert positions is not None
assert mask is not None
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
mask = jnp.asarray(mask)[:, None, :, :]
block_cls = nn.remat(
Block,
prevent_cse=False,
static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
policy=jax.checkpoint_policies.nothing_saveable,
)
block = nn.scan(
block_cls,
# cache has axis 1 since we want leading dimension to be batch size.
variable_axes={"params": 0, "cache": 1},
split_rngs={"params": True, "dropout": True},
in_axes=nn.broadcast,
length=self.configs[0].depth,
)(
parent=self.scope.push("layers"),
configs=self.configs,
dropout=self.dropout,
dropout_bdims=self.dropout_bdims,
)
embedded, _ = block(embedded, (), positions, mask, decode, deterministic)
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
return [RMSNorm(name=_name("final_norm", i))(e) if e is not None else e for i, e in enumerate(embedded)]
def _apply_rope(x, *, positions, max_wavelength=10_000):
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
assert radians.dtype == jnp.float32
# radians.shape = [...,L,1,d=D/2]
sin, cos = jnp.sin(radians), jnp.cos(radians)
x1, x2 = jnp.split(x, 2, axis=-1)
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
assert res.dtype == jnp.float32
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
# here.
return res.astype(x.dtype)
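# A minimal shape check (sketch): RoPE preserves the input's shape and dtype.
#   >>> x = jnp.zeros((2, 16, 8, 64), jnp.bfloat16)
#   >>> positions = jnp.broadcast_to(jnp.arange(16), (2, 16))
#   >>> _apply_rope(x, positions=positions).shape
#   (2, 16, 8, 64)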
def _name(name, i):
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
# and the action expert.
if i == 0:
return name
return f"{name}_{i}"

View File

@@ -0,0 +1,97 @@
import chex
import flax.linen as nn
import jax
import pytest
import openpi.models.gemma as gemma
def get_annotation_to_dim_size() -> dict[str, int]:
return {
"B": 8,
"T": 13,
"S": 7,
"N": 4,
"M": 4,
"K": 2,
"H": 48,
"D": 64,
}
def eqn_to_shape(eqn: str, annotation_to_dim_size: dict[str, int]) -> tuple[tuple[int, ...], ...]:
(lhs_part_0, lhs_part_1), _ = gemma._parse_einops_eqn(eqn) # noqa: SLF001
return tuple(int(ann) if ann.isdigit() else annotation_to_dim_size[ann] for ann in lhs_part_0), tuple(
int(ann) if ann.isdigit() else annotation_to_dim_size[ann] for ann in lhs_part_1
)
@pytest.mark.parametrize(
("eqn", "lora_annotation"),
[
("BSD,3KDH->3BSKH", "3KDL,3KLKH->3KDH"),
("BTD,NDH->BTNH", "NDL,NLNH->NDH"),
("BSD,2KDH->2BSKH", "2KDL,2KLKH->2KDH"),
("BTNH,NHD->BTD", "NHNL,NLD->NHD"),
],
)
def test_lora_einsum_equivalent_to_original(eqn: str, lora_annotation: str):
annotation_to_dim_size = get_annotation_to_dim_size()
x_shape, w_shape = eqn_to_shape(eqn, annotation_to_dim_size)
einsum = gemma.Einsum(shape=w_shape, name="einsum", init_fn=nn.initializers.lecun_normal())
lora_einsum = gemma.LoRAEinsum(
einsum,
gemma.LoRAConfig(rank=4, alpha=4.0),
lora_annotation,
nn.initializers.zeros_init(),
nn.initializers.zeros_init(),
)
x = jax.random.normal(jax.random.key(0), x_shape)
def module_call(instance, x):
return instance(eqn, x)
einsum_variables = einsum.init(jax.random.key(0), x, method=module_call)
lora_einsum_variables = lora_einsum.init(jax.random.key(0), x, method=module_call)
# Copy over the weights from the original einsum to the lora einsum since the initialization order is
# not the same.
lora_einsum_variables["params"]["w"] = einsum_variables["params"]["w"]
y = einsum.apply(einsum_variables, x, rngs={}, method=module_call)
y_lora = lora_einsum.apply(lora_einsum_variables, x, rngs={}, method=module_call)
chex.assert_trees_all_close(y, y_lora)
@pytest.mark.parametrize(
("eqn", "lora_annotation"),
[
("BSD,3KDH->3BSKH", "3KDL,3KLKH->3KDH"),
("BTD,NDH->BTNH", "NDL,NLNH->NDH"),
("BSD,2KDH->2BSKH", "2KDL,2KLKH->2KDH"),
("BTNH,NHD->BTD", "NHNL,NLD->NHD"),
],
)
def test_lora_einsum_param_merge_works(eqn: str, lora_annotation: str):
annotation_to_dim_size = get_annotation_to_dim_size()
x_shape, w_shape = eqn_to_shape(eqn, annotation_to_dim_size)
einsum = gemma.Einsum(shape=w_shape, name="einsum", init_fn=nn.initializers.lecun_normal())
lora_einsum = gemma.LoRAEinsum(
einsum,
gemma.LoRAConfig(rank=4, alpha=4.0),
lora_annotation,
nn.initializers.lecun_normal(),
nn.initializers.lecun_normal(),
)
x = jax.random.uniform(jax.random.key(0), x_shape)
def module_call(instance, x):
return instance(eqn, x)
lora_einsum_variables = lora_einsum.init(jax.random.key(0), x, method=module_call)
einsum_variables = gemma.merge_lora_params(lora_einsum_variables, lambda x: lora_annotation)
y = einsum.apply(einsum_variables, x, rngs={}, method=module_call)
y_lora = lora_einsum.apply(lora_einsum_variables, x, rngs={}, method=module_call)
chex.assert_trees_all_close(y, y_lora, atol=0.001)

260
src/openpi/models/model.py Normal file
View File

@@ -0,0 +1,260 @@
import abc
from collections.abc import Sequence
import dataclasses
import logging
import pathlib
import augmax
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from typing_extensions import override
from openpi.models import common
from openpi.shared import image_tools
import openpi.shared.array_typing as at
logger = logging.getLogger("openpi")
# The model always expects these images
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
# This may need to change if we release a small model.
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation(
rng: at.KeyArrayLike,
observation: common.Observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> common.Observation:
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for augmax.
image = image / 2.0 + 0.5
transforms = []
if "wrist" not in key:
height, width = image.shape[1:3]
transforms += [
augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
augmax.Resize(width, height),
augmax.Rotate((-5, 5)),
]
transforms += [
augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
]
sub_rngs = jax.random.split(rng, image.shape[0])
image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
# Back to [-1, 1].
image = image * 2.0 - 1.0
out_images[key] = image
# obtain mask
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
else:
out_masks[key] = jnp.asarray(observation.image_masks[key])
return common.Observation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
)
@struct.dataclass
class BaseModel(abc.ABC):
# Action space dimension.
action_dim: int = struct.field(pytree_node=False)
# Action sequence length.
action_horizon: int = struct.field(pytree_node=False)
# Tokenized prompt maximum length.
max_token_len: int = struct.field(pytree_node=False)
@abc.abstractmethod
def compute_loss(
self,
rng: at.KeyArrayLike,
observation: common.Observation,
actions: common.Actions,
*,
train: bool = False,
params: at.Params | None = None,
) -> at.Float[at.Array, "*b ah"]: ...
@abc.abstractmethod
def sample_actions(
self,
rng: at.KeyArrayLike,
observation: common.Observation,
**sample_kwargs,
) -> common.Actions: ...
@struct.dataclass
class Model(BaseModel):
module: common.BaseModule = struct.field(pytree_node=False)
params: at.Params | None = None
def init_params(self, rng: at.KeyArrayLike, observation: common.Observation, actions: common.Actions) -> at.Params:
"""Initialize and return the parameters by tracing the module's `compute_loss` function."""
preprocess_rng, init_rng = jax.random.split(rng)
obs = preprocess_observation(preprocess_rng, observation)
return self.module.init(init_rng, obs, actions, method=self.module.compute_loss)["params"]
@at.typecheck
@override
def compute_loss(
self,
rng: at.KeyArrayLike,
observation: common.Observation,
actions: common.Actions,
params: at.Params | None = None,
*,
train: bool = False,
) -> at.Float[at.Array, ""]:
if params is None:
if self.params is None:
raise ValueError(
"No parameters found. Either bind the model to parameters using `set_params` or provide params directly."
)
params = self.params
loss_rng, preprocess_rng = jax.random.split(rng)
obs = preprocess_observation(preprocess_rng, observation, train=train)
loss_args = (obs, actions)
return jnp.mean(
self.module.apply({"params": params}, *loss_args, rngs={"loss": loss_rng}, method=self.module.compute_loss) # type: ignore
)
@jax.jit
@at.typecheck
@override
def sample_actions(
self,
rng: at.KeyArrayLike,
observation: common.Observation,
**sample_kwargs,
) -> common.Actions:
if self.params is None:
raise ValueError(
"No parameters found. Bind the model to parameters using `set_params` before calling `sample_actions`."
)
preprocess_rng, sample_rng = jax.random.split(rng)
obs = preprocess_observation(preprocess_rng, observation)
sample_args = (self.action_horizon, self.action_dim, obs)
actions, _ = self.module.apply(
{"params": self.params},
*sample_args,
rngs={"sample": sample_rng},
method=self.module.sample_actions,
mutable=["cache"],
**sample_kwargs,
)
return actions
def set_params(self, params: at.Params) -> "Model":
"""Returns a copy of the model bound to `params`."""
return dataclasses.replace(self, params=params)
def fake_obs(self, batch_size: int = 1) -> common.Observation:
observation_spec, _ = create_inputs_spec(self, batch_size=batch_size)
return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), observation_spec)
def fake_act(self, batch_size: int = 1) -> common.Actions:
_, action_spec = create_inputs_spec(self, batch_size=batch_size)
return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), action_spec)
def restore_params(
params_path: pathlib.Path | str,
*,
dtype: jnp.dtype | None = None,
sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
"""Restores unstructured params PyTree from a checkpoint. This works with checkpoints saved with `save_state` during
openpi training (see `training/checkpoints.py`) as well as pre-trained checkpoints released for openpi.
"""
params_path = pathlib.Path(params_path).resolve()
if not params_path.exists():
raise FileNotFoundError(f"Model params not found at: {params_path}")
restore_type = np.ndarray if sharding is None else jax.Array
with ocp.PyTreeCheckpointer() as ckptr:
metadata = ckptr.metadata(params_path)
# Use EMA params if they exist, otherwise regular params. See `training.utils.TrainState`.
params_name = "ema_params" if metadata.get("ema_params") is not None else "params"
item = {params_name: metadata[params_name]}
return ckptr.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree.map(
lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
),
transforms={}, # required to load a partial PyTree (e.g., only params from a full TrainState)
),
)[params_name]
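# A usage sketch (the checkpoint path is illustrative): restore released or trained parameters and bind them to
# a model instance via `set_params`.
#   >>> params = restore_params("./checkpoints/pi0_aloha_sim/model")
#   >>> model = model.set_params(params)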
def create_inputs_spec(model: Model, *, batch_size: int = 1) -> tuple[common.Observation, at.Float[at.Array, "ah ad"]]:
image_spec = jax.ShapeDtypeStruct([batch_size, 224, 224, 3], jnp.float32)
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
with at.disable_typechecking():
observation_spec = common.Observation(
images={
"base_0_rgb": image_spec,
"left_wrist_0_rgb": image_spec,
"right_wrist_0_rgb": image_spec,
},
image_masks={
"base_0_rgb": image_mask_spec,
"left_wrist_0_rgb": image_mask_spec,
"right_wrist_0_rgb": image_mask_spec,
},
state=jax.ShapeDtypeStruct([batch_size, model.action_dim], jnp.float32),
tokenized_prompt=jax.ShapeDtypeStruct([batch_size, model.max_token_len], jnp.int32),
tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, model.max_token_len], jnp.int32),
)
action_spec = jax.ShapeDtypeStruct([batch_size, model.action_horizon, model.action_dim], jnp.float32)
return observation_spec, action_spec

View File

@@ -0,0 +1,47 @@
import jax
import jax.numpy as jnp
from openpi.models import model as _model
from openpi.models import pi0
from openpi.shared import download
def make_from_spec(spec: jax.ShapeDtypeStruct):
return jnp.zeros(shape=spec.shape, dtype=spec.dtype)
def create_pi0_model():
return _model.Model(module=pi0.Module(), action_dim=24, action_horizon=50, max_token_len=48)
def test_model():
model = create_pi0_model()
batch_size = 2
obs, act = model.fake_obs(batch_size), model.fake_act(batch_size)
rng = jax.random.key(0)
model = model.set_params(model.init_params(rng, obs, act))
loss = model.compute_loss(rng, obs, act)
assert loss.shape == ()
actions = model.sample_actions(rng, obs, num_steps=10)
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
def test_model_restore():
model = create_pi0_model()
batch_size = 2
obs, act = model.fake_obs(batch_size), model.fake_act(batch_size)
params = _model.restore_params(download.maybe_download("s3://openpi-assets/exported/pi0_aloha_sim/model"))
model = model.set_params(params)
rng = jax.random.key(0)
loss = model.compute_loss(rng, obs, act)
assert loss.shape == ()
actions = model.sample_actions(rng, obs, num_steps=10)
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)

306
src/openpi/models/pi0.py Normal file
View File

@@ -0,0 +1,306 @@
import logging
from typing import Literal
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from typing_extensions import override
from openpi.models import common
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at
logger = logging.getLogger("openpi")
def make_attn_mask(input_mask, mask_ar):
"""Copied from big_vision.
Tokens can attend to valid input tokens which have a cumulative mask_ar
smaller than or equal to theirs. This way `mask_ar` int[B, N] can be used to
set up several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
cumsum = jnp.cumsum(mask_ar, axis=1)
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
return jnp.logical_and(attn_mask, valid_mask)
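# A concrete sketch of the prefix-lm case (output shown schematically): with one fully-valid sequence and
# mask_ar = [0, 0, 1, 1], the first two tokens attend bidirectionally to each other and the last two are
# causal over everything before them.
#   >>> make_attn_mask(jnp.ones((1, 4), dtype=bool), jnp.array([[0, 0, 1, 1]]))[0].astype(jnp.int32)
#   [[1 1 0 0]
#    [1 1 0 0]
#    [1 1 1 0]
#    [1 1 1 1]]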
@at.typecheck
def posemb_sincos(
pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if embedding_dim % 2 != 0:
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
period = min_period * (max_period / min_period) ** fraction
sinusoid_input = jnp.einsum(
"i,j->ij",
pos,
1.0 / period * 2 * jnp.pi,
precision=jax.lax.Precision.HIGHEST,
)
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
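# Worked example: for embedding_dim = 4, the output per scalar position p is
#   [sin(2*pi*p / min_period), sin(2*pi*p / max_period), cos(2*pi*p / min_period), cos(2*pi*p / max_period)],
# i.e. the periods are spaced geometrically from min_period to max_period across the embedding.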
class Module(common.BaseModule):
"""Pi0 module (transfusion-style decoder-only flow matching)."""
dtype: str = "bfloat16"
paligemma_variant: _gemma.Variant = "gemma_2b"
action_expert_variant: _gemma.Variant = "gemma_300m"
@at.typecheck
@override
def compute_loss(
self,
obs: common.Observation,
target_actions: common.Actions,
*,
timestep: at.Float[at.Array, " b"] | None = None,
) -> at.Float[at.Array, "b ah"]:
batch_size = target_actions.shape[0]
noise = jax.random.normal(self.make_rng("loss"), target_actions.shape)
if timestep is None:
timestep = jax.random.beta(self.make_rng("loss"), 1.5, 1, (batch_size,)) * 0.999 + 0.001
time_expanded = timestep[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * target_actions
u_t = noise - target_actions
pred = self.forward(obs, x_t, timestep, mode="train")
return jnp.mean(jnp.square(pred - u_t), axis=2)
@at.typecheck
@override
def sample_actions(
self,
action_horizon: int,
action_dim: int,
obs: common.Observation,
*,
noise: at.Float[at.Array, "b ah ad"] | None = None,
num_steps: int | at.Int[at.Array, ""] = 10,
) -> common.Actions:
# note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
# distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
dt = -1.0 / num_steps
batch_size = obs.state.shape[0]
if noise is None:
noise = jax.random.normal(self.make_rng("sample"), (batch_size, action_horizon, action_dim))
# first fill KV cache (in-place)
self.forward(obs, None, None, mode="fill_cache")
@at.typecheck
def sample_step(
module: Module,
carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
) -> tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]]:
x_t, time = carry
time_batched = einops.repeat(time, "-> b", b=batch_size)
v_t = module.forward(obs, x_t, time_batched, mode="decode")
# Euler step
x_tilde = x_t + dt * v_t
return x_tilde, time + dt
@at.typecheck
def cond_fn(
module: Module,
carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
) -> at.Bool[at.Array, ""]:
x_t, time = carry
# robust to floating-point error
return time >= -dt / 2
time = jnp.array(1.0, dtype=jnp.float32)
x_0, _ = nn.while_loop(cond_fn, sample_step, self, (noise, time))
return x_0
@nn.compact
@at.typecheck
def forward(
self,
obs: common.Observation,
noisy_actions: at.Float[at.Array, "b ah ad"] | None,
timestep: at.Float[at.Array, " b"] | None,
mode: Literal["train", "fill_cache", "decode"],
):
"""Main forward pass of the transformer. It operates in 3 modes:
1. mode="train": This is full forward pass, used during training.
2. mode="fill_cache": This is used to compute the KV cache for the prefix (image + language inputs).
3. mode="decode": This is used to perform a flow matching integration step; it uses the KV cache computed in the
fill_cache mode.
"""
paligemma_scope = self.scope.push("PaliGemma")
llm_scope = paligemma_scope.push("llm")
img_scope = paligemma_scope.push("img")
paligemma_config = _gemma.get_config(self.paligemma_variant)
action_expert_config = _gemma.get_config(self.action_expert_variant)
gemma = _gemma.Module(
configs=[paligemma_config, action_expert_config],
embed_dtype=self.dtype,
parent=llm_scope,
)
siglip = _siglip.Module(
num_classes=paligemma_config.width,
variant="So400m/14",
pool_type="none",
scan=True,
dtype_mm=self.dtype,
parent=img_scope,
)
batch_size = obs.state.shape[0]
input_mask: list[at.Bool[at.Array, "b s"]] = []
ar_mask: list[int] = []
if mode in ["train", "fill_cache"]:
prefix_tokens: list[at.Float[at.Array, "b s emb"]] = []
# embed images
for name in obs.images:
image_tokens, _ = siglip(obs.images[name], train=False)
prefix_tokens.append(image_tokens)
input_mask.append(
einops.repeat(
obs.image_masks[name],
"b -> b s",
s=image_tokens.shape[1],
)
)
# image tokens attend to each other
ar_mask += [0] * image_tokens.shape[1]
# add language (aka tokenized inputs)
if obs.tokenized_prompt is not None:
# run gemma in embed-only mode
tokenized_inputs = gemma(tokens=obs.tokenized_prompt, embedded=None)
prefix_tokens.append(tokenized_inputs)
input_mask.append(obs.tokenized_prompt_mask)
# full attention between image and language inputs
ar_mask += [0] * tokenized_inputs.shape[1]
prefix_tokens = jnp.concatenate(prefix_tokens, axis=1)
prefix_len = prefix_tokens.shape[1]
if mode in ["train", "decode"]:
assert noisy_actions is not None
suffix_tokens: list[at.Float[at.Array, "b s emb"]] = []
# add a single state token
state_token = nn.Dense(action_expert_config.width, name="state_proj")(obs.state)
suffix_tokens.append(state_token[:, None, :])
input_mask.append(jnp.ones((batch_size, 1), dtype=jnp.bool_))
# image/language inputs do not attend to state or actions
ar_mask += [1]
action_horizon = noisy_actions.shape[1]
# embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = posemb_sincos(timestep, action_expert_config.width, min_period=4e-3, max_period=4.0)
# mix timestep + action information using an MLP
action_tokens = nn.Dense(action_expert_config.width, name="action_in_proj")(noisy_actions)
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=action_horizon)
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
action_time_tokens = nn.Dense(action_expert_config.width, name="action_time_mlp_in")(action_time_tokens)
action_time_tokens = nn.swish(action_time_tokens)
action_time_tokens = nn.Dense(action_expert_config.width, name="action_time_mlp_out")(action_time_tokens)
# add to input tokens
suffix_tokens.append(action_time_tokens)
input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
# image/language/state inputs do not attend to action tokens
ar_mask += [1] + ([0] * (action_horizon - 1))
suffix_tokens = jnp.concatenate(suffix_tokens, axis=1)
suffix_len = suffix_tokens.shape[1]
if mode == "train":
# due to prefix-lm decoding, it is very important that the prefix cannot attend to the suffix
assert ar_mask[prefix_len] == 1
# create attention mask (shared between prefix and suffix)
input_mask = jnp.concatenate(input_mask, axis=1)
ar_mask = np.array(ar_mask, dtype=np.int32)
ar_mask = einops.repeat(ar_mask, "s -> b s", b=batch_size)
attn_mask = make_attn_mask(input_mask, ar_mask)
if mode in ["train", "decode"]:
out_proj = nn.Dense(noisy_actions.shape[-1], name="action_out_proj")
if mode == "train":
# full forward pass on prefix + suffix at once
positions = jnp.cumsum(input_mask, axis=1) - 1
_, out = gemma(
tokens=None,
embedded=[prefix_tokens, suffix_tokens],
mask=attn_mask,
positions=positions,
decode=False,
)
return out_proj(out[:, -action_horizon:])
if mode == "fill_cache":
# fill the KV cache using the prefix tokens. this mutates the "cache" variable in place.
self.put_variable("cache", "prefix_mask", input_mask.astype(bool))
positions = jnp.cumsum(input_mask, axis=-1) - 1
gemma(
tokens=None,
embedded=[prefix_tokens, None],
positions=positions,
mask=attn_mask,
decode=True,
)
return None
if mode == "decode":
# decode using the existing KV cache
prefix_len = gemma.variables["cache"]["layers"]["attn"]["k_cache"].shape[2]
# `prefix_mask` is shape (b, prefix_len); it is broadcast below to (b, suffix_len, prefix_len) to indicate how
# the suffix tokens can attend to the prefix tokens
prefix_mask = self.get_variable("cache", "prefix_mask")
assert prefix_mask.shape == (batch_size, prefix_len)
prefix_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_len)
# `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
# generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
combined_mask = jnp.concatenate([prefix_mask, attn_mask], axis=-1)
assert combined_mask.shape == (
batch_size,
suffix_len,
prefix_len + suffix_len,
)
# `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
positions = (
jnp.sum(self.get_variable("cache", "prefix_mask"), axis=-1)[:, None]
+ jnp.cumsum(input_mask, axis=-1)
- 1
)
unused, out = gemma(
tokens=None,
embedded=[None, suffix_tokens],
mask=combined_mask,
positions=positions,
decode=True,
)
assert unused is None
return out_proj(out[:, -action_horizon:])
raise ValueError(f"Invalid mode: {mode}")

View File

@@ -0,0 +1,191 @@
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from typing_extensions import override
from openpi.models import common
import openpi.models.transformer as _transformer
import openpi.models.vit as _vit
from openpi.shared import array_typing as at
@at.typecheck
def posemb_sincos(
pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if embedding_dim % 2 != 0:
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
period = min_period * (max_period / min_period) ** fraction
sinusoid_input = jnp.einsum(
"i,j->ij",
pos,
1.0 / period * 2 * jnp.pi,
precision=jax.lax.Precision.HIGHEST,
)
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
class ViTEncoder(nn.Module):
"""ViT encoder from the Google vision_transformer codebase."""
dtype: str = "bfloat16"
@nn.compact
@at.typecheck
def __call__(self, image: at.Float[at.Array, "b h w c"]) -> at.Float[at.Array, "b seq emb"]:
vit = _vit.VisionTransformer(
name="VisionTransformer",
dtype=self.dtype,
# Removes class token.
num_classes=None,
classifier="unpooled",
# R26+ViT-S_32 config.
patches=ml_collections.ConfigDict({"size": (1, 1)}),
transformer=ml_collections.ConfigDict({"mlp_dim": 1536, "num_heads": 6, "num_layers": 12}),
hidden_size=384,
resnet=ml_collections.ConfigDict({"num_layers": (2, 2, 2, 2), "width_factor": 1}),
)
# VisionTransformer expects images in [0, 1] range.
image = (image + 1) / 2
return vit(image, train=False)
class Encoder(nn.Module):
"""Transformer encoder that combines ViTEncoders for each image, plus state information."""
variant: _transformer.Variant = "small"
dtype: str = "bfloat16"
@nn.compact
@at.typecheck
def __call__(self, obs: common.Observation) -> _transformer.TokenSequence:
transformer, embed_dim = _transformer.get_variant(self.variant, dtype=self.dtype)
image_tokens: list[_transformer.TokenSequence] = []
for name in obs.images:
zimg = ViTEncoder(name=f"backbone_{name}", dtype=self.dtype)(obs.images[name])
zimg = nn.Dense(embed_dim, name=f"proj_{name}")(zimg)
posemb = self.param(f"posemb_image_{name}", nn.initializers.normal(0.02), (embed_dim,))
image_tokens.append(
_transformer.TokenSequence(
tokens=zimg,
pos=jnp.broadcast_to(posemb, zimg.shape),
mask=jnp.broadcast_to(obs.image_masks[name][:, None], zimg.shape[:-1]),
)
)
state_token = _transformer.TokenSequence(
tokens=nn.Dense(embed_dim, name="state_proj")(obs.state)[:, None],
pos=self.param("posemb_state", nn.initializers.normal(0.02), (embed_dim,))[None],
)
input_tokens = _transformer.TokenSequence.concatenate(*image_tokens, state_token)
return transformer(input_tokens)
class Decoder(nn.Module):
variant: _transformer.Variant = "small"
dtype: str = "bfloat16"
@nn.compact
@at.typecheck
def __call__(
self,
noisy_actions: at.Float[at.Array, "b ah ad"],
timestep: at.Float[at.Array, " b"],
cond_tokens: _transformer.TokenSequence,
) -> at.Float[at.Array, "b ah ad"]:
transformer, embed_dim = _transformer.get_variant(self.variant, dtype=self.dtype)
tokens = _transformer.TokenSequence(
# project actions to embedding dimension
tokens=nn.Dense(embed_dim, name="in_proj")(noisy_actions),
# use learned positional embedding for actions
pos=self.param("posemb_actions", nn.initializers.normal(0.02), (noisy_actions.shape[1], embed_dim)),
)
# embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = posemb_sincos(timestep, embed_dim, min_period=4e-3, max_period=4.0)
# time MLP
time_emb = nn.Dense(embed_dim, name="time_mlp_in")(time_emb)
time_emb = nn.swish(time_emb)
time_emb = nn.Dense(embed_dim, name="time_mlp_out")(time_emb)
output_tokens = transformer(tokens, xattn_cond=cond_tokens, adaln_cond=time_emb)
return nn.Dense(noisy_actions.shape[-1], name="out_proj")(output_tokens.tokens)
class Module(common.BaseModule):
encoder: Encoder = Encoder()
decoder: Decoder = Decoder()
@at.typecheck
@override
def compute_loss(
self,
obs: common.Observation,
target_actions: common.Actions,
*,
timestep: at.Float[at.Array, " b"] | None = None,
) -> at.Float[at.Array, "b ah"]:
batch_size = target_actions.shape[0]
noise = jax.random.normal(self.make_rng("loss"), target_actions.shape)
if timestep is None:
timestep = jax.random.beta(self.make_rng("loss"), 1.5, 1, (batch_size,)) * 0.999 + 0.001
time_expanded = timestep[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * target_actions
u_t = noise - target_actions
pred = self.decoder(x_t, timestep, self.encoder(obs))
return jnp.mean(jnp.square(pred - u_t), axis=2)
@at.typecheck
@override
def sample_actions(
self,
action_horizon: int,
action_dim: int,
obs: common.Observation,
*,
noise: at.Float[at.Array, "b ah ad"] | None = None,
num_steps: int = 10,
) -> common.Actions:
dt = -1.0 / num_steps
batch_size = obs.state.shape[0]
if noise is None:
noise = jax.random.normal(self.make_rng("sample"), (batch_size, action_horizon, action_dim))
cond_tokens = self.encoder(obs)
@at.typecheck
def sample_step(
module: Module,
carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
) -> tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]]:
x_t, time = carry
time_batched = einops.repeat(time, "-> b", b=batch_size)
v_t = module.decoder(x_t, time_batched, cond_tokens)
# Euler step
x_tilde = x_t + dt * v_t
return x_tilde, time + dt
@at.typecheck
def cond_fn(
module: Module,
carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
) -> at.Bool[at.Array, ""]:
x_t, time = carry
# robust to floating-point error
return time >= -dt / 2
time = jnp.array(1.0, dtype=jnp.float32)
x_0, _ = nn.while_loop(cond_fn, sample_step, self, (noise, time))
return x_0

View File

@@ -0,0 +1,82 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet implementation copied from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_resnet.py."""
from collections.abc import Callable, Sequence
from typing import TypeVar
from flax import linen as nn
import jax.numpy as jnp
T = TypeVar("T")
def weight_standardize(w, axis, eps):
"""Subtracts mean and divides by standard deviation."""
w = w - jnp.mean(w, axis=axis)
return w / (jnp.std(w, axis=axis) + eps)
class StdConv(nn.Conv):
"""Convolution with weight standardization."""
def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
param = super().param(name, init_fn, *init_args)
if name == "kernel":
param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5)
return param
class ResidualUnit(nn.Module):
"""Bottleneck ResNet block."""
features: int
strides: Sequence[int] = (1, 1)
@nn.compact
def __call__(self, x):
needs_projection = x.shape[-1] != self.features * 4 or self.strides != (1, 1)
residual = x
if needs_projection:
residual = StdConv(
features=self.features * 4, kernel_size=(1, 1), strides=self.strides, use_bias=False, name="conv_proj"
)(residual)
residual = nn.GroupNorm(name="gn_proj")(residual)
y = StdConv(features=self.features, kernel_size=(1, 1), use_bias=False, name="conv1")(x)
y = nn.GroupNorm(name="gn1")(y)
y = nn.relu(y)
y = StdConv(features=self.features, kernel_size=(3, 3), strides=self.strides, use_bias=False, name="conv2")(y)
y = nn.GroupNorm(name="gn2")(y)
y = nn.relu(y)
y = StdConv(features=self.features * 4, kernel_size=(1, 1), use_bias=False, name="conv3")(y)
y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y)
return nn.relu(residual + y)
class ResNetStage(nn.Module):
"""A ResNet stage."""
block_size: Sequence[int]
nout: int
first_stride: Sequence[int]
@nn.compact
def __call__(self, x):
x = ResidualUnit(self.nout, strides=self.first_stride, name="unit1")(x)
for i in range(1, self.block_size):
x = ResidualUnit(self.nout, strides=(1, 1), name=f"unit{i + 1}")(x)
return x

372
src/openpi/models/siglip.py Normal file
View File

@@ -0,0 +1,372 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
from collections.abc import Sequence
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
"""Follows the MoCo v3 logic."""
y, x = jnp.mgrid[:h, :w]
assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
omega = jnp.arange(width // 4) / (width // 4 - 1)
omega = 1.0 / (temperature**omega)
y = jnp.einsum("m,d->md", y.flatten(), omega)
x = jnp.einsum("m,d->md", x.flatten(), omega)
pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
return jnp.asarray(pe, dtype)[None, :, :]
def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
if typ == "learn":
return self.param(
name,
nn.initializers.normal(stddev=1 / np.sqrt(width)),
(1, np.prod(seqshape), width),
dtype,
)
if typ == "sincos2d":
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
raise ValueError(f"Unknown posemb type: {typ}")
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: int | None = None # Defaults to 4x input dim
dropout: float = 0.0
dtype_mm: str = "float32"
@nn.compact
def __call__(self, x, deterministic=True): # noqa: FBT002
"""Applies Transformer MlpBlock module."""
inits = {
"kernel_init": nn.initializers.xavier_uniform(),
"bias_init": nn.initializers.normal(stddev=1e-6),
}
_, _, d = x.shape # n,l,d
x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout)(x, deterministic)
return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
class Encoder1DBlock(nn.Module):
"""Single transformer encoder block (MHSA + MLP)."""
mlp_dim: int | None = None # Defaults to 4x input dim
num_heads: int = 12
dropout: float = 0.0
dtype_mm: str = "float32"
@nn.compact
def __call__(self, x, deterministic=True): # noqa: FBT002
out = {}
x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
y = out["sa"] = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
kernel_init=nn.initializers.xavier_uniform(),
deterministic=deterministic,
dtype=self.dtype_mm,
)(y, y)
y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
y = nn.Dropout(rate=self.dropout)(y, deterministic)
x = out["+sa"] = x + y
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
y = out["mlp"] = MlpBlock(
mlp_dim=self.mlp_dim,
dropout=self.dropout,
dtype_mm=self.dtype_mm,
)(y, deterministic)
y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
y = nn.Dropout(rate=self.dropout)(y, deterministic)
x = out["+mlp"] = x + y
x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
return x, out
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
depth: int
mlp_dim: int | None = None # Defaults to 4x input dim
num_heads: int = 12
dropout: float = 0.0
scan: bool = False
remat_policy: str = "nothing_saveable"
dtype_mm: str = "float32"
@nn.compact
def __call__(self, x, deterministic=True): # noqa: FBT002
out = {}
if self.scan:
block = nn.remat(
Encoder1DBlock,
prevent_cse=False,
static_argnums=(2,), # 0=self, 2=deterministic
policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
)
x, scan_out = nn.scan(
block,
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
in_axes=nn.broadcast,
length=self.depth,
)(
name="encoderblock",
dtype_mm=self.dtype_mm,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout=self.dropout,
)(x, deterministic)
for lyr in range(self.depth):
out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
else:
# Input Encoder
for lyr in range(self.depth):
block_cur = Encoder1DBlock(
name=f"encoderblock_{lyr}",
dtype_mm=self.dtype_mm,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout=self.dropout,
)
x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
out["pre_ln"] = x # Alias for last block, but without the number in it.
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out
class MAPHead(nn.Module):
"""Multihead Attention Pooling."""
mlp_dim: int | None = None # Defaults to 4x input dim
num_heads: int = 12
dtype_mm: str = "float32"
@nn.compact
def __call__(self, x):
n, _, d = x.shape # n,l,d
probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
probe = jnp.tile(probe, [n, 1, 1])
x = nn.MultiHeadDotProductAttention(
num_heads=self.num_heads,
dtype=self.dtype_mm,
kernel_init=nn.initializers.xavier_uniform(),
)(probe, x)
# TODO: dropout on head?
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype_mm)(y)
return x[:, 0]
class _Module(nn.Module):
"""ViT model."""
num_classes: int | None = None
patch_size: Sequence[int] = (16, 16)
width: int = 768
depth: int = 12
mlp_dim: int | None = None # Defaults to 4x input dim
num_heads: int = 12
posemb: str = "learn" # Can also be "sincos2d"
rep_size: int | bool = False
dropout: float = 0.0
pool_type: str = "gap" # Can also be "map" or "tok"
head_zeroinit: bool = True
scan: bool = False
# or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
remat_policy: str = "nothing_saveable"
dtype_mm: str = "float32"
@nn.compact
def __call__(self, image, *, train=False):
out = {}
# Kevin edit: do patch extraction and posemb in float32,
# because I feel like it's a bit safer.
image = jnp.asarray(image, jnp.float32)
# Patch extraction
x = out["stem"] = nn.Conv(
self.width,
self.patch_size,
strides=self.patch_size,
padding="VALID",
name="embedding",
dtype=jnp.float32,
)(image)
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
# Add posemb before adding extra token.
x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)
if self.pool_type == "tok":
cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
n, _, c = x.shape # n,l,d
x = nn.Dropout(rate=self.dropout)(x, not train)
# Kevin edit: now cast back to dtype_mm (potentially half precision)
x = x.astype(self.dtype_mm)
x, out["encoder"] = Encoder(
depth=self.depth,
mlp_dim=self.mlp_dim,
num_heads=self.num_heads,
dropout=self.dropout,
scan=self.scan,
remat_policy=self.remat_policy,
dtype_mm=self.dtype_mm,
name="Transformer",
)(x, deterministic=not train)
encoded = out["encoded"] = x
if self.pool_type == "map":
x = out["head_input"] = MAPHead(
num_heads=self.num_heads,
mlp_dim=self.mlp_dim,
dtype=self.dtype_mm,
)(x)
elif self.pool_type == "gap":
x = out["head_input"] = jnp.mean(x, axis=1)
elif self.pool_type == "0":
x = out["head_input"] = x[:, 0]
elif self.pool_type == "tok":
x = out["head_input"] = x[:, 0]
encoded = encoded[:, 1:]
elif self.pool_type == "none":
pass
else:
raise ValueError(f"Unknown pool type: '{self.pool_type}'")
x_2d = jnp.reshape(encoded, [n, h, w, -1])
if self.rep_size:
rep_size = self.width if self.rep_size is True else self.rep_size
hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
# NOTE: In the past we did not include tanh in pre_logits.
# For few-shot, it should not matter much, as it whitens anyways.
x_2d = nn.tanh(hid(x_2d))
x = nn.tanh(hid(x))
out["pre_logits_2d"] = x_2d
out["pre_logits"] = x
if self.num_classes:
kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
x_2d = out["logits_2d"] = head(x_2d)
x = out["logits"] = head(x)
return x, out
def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802
"""Factory function, because linen really don't like what I'm doing!"""
return _Module(num_classes, **{**decode_variant(variant), **kw})
def decode_variant(variant):
"""Converts a string like "B" or "B/32" into a params dict."""
if variant is None:
return {}
v, patch = variant, {}
if "/" in variant:
v, patch = variant.split("/")
patch = {"patch_size": (int(patch), int(patch))}
return {
# pylint:disable=line-too-long
# Reference: Table 2 of https://arxiv.org/abs/2106.04560.
"width": {
"mu": 32,
"Ti": 192,
"S": 384,
"M": 512,
"B": 768,
"L": 1024,
"So400m": 1152,
"H": 1280,
"g": 1408,
"g-opt": 1536,
"G": 1664,
"G-opt": 1536,
"e": 1792,
}[v],
"depth": {
"mu": 1,
"Ti": 12,
"S": 12,
"M": 12,
"B": 12,
"L": 24,
"So400m": 27,
"H": 32,
"g": 40,
"g-opt": 40,
"G": 48,
"G-opt": 48,
"e": 56,
}[v],
"mlp_dim": {
"mu": 128,
"Ti": 768,
"S": 1536,
"M": 2048,
"B": 3072,
"L": 4096,
"So400m": 4304,
"H": 5120,
"g": 6144,
"g-opt": 6144,
"G": 8192,
"G-opt": 8192,
"e": 15360,
}[v],
"num_heads": {
"mu": 2,
"Ti": 3,
"S": 6,
"M": 8,
"B": 12,
"L": 16,
"So400m": 16,
"H": 16,
"g": 16,
"g-opt": 16,
"G": 16,
"G-opt": 16,
"e": 16,
}[v],
# pylint:enable=line-too-long
**patch,
}
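# For example, decode_variant("So400m/14") yields width=1152, depth=27, mlp_dim=4304, num_heads=16, and
# patch_size=(14, 14), which is the variant used for the Pi0 image encoder.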

View File

@@ -0,0 +1,51 @@
import abc
import numpy as np
import sentencepiece
from typing_extensions import override
import openpi.shared.download as download
class Tokenizer(abc.ABC):
@abc.abstractmethod
def tokenize(self, batch: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Tokenize a batch of prompts.
Args:
batch: A batch of text prompts to tokenize.
Returns:
A tuple containing the tokenized prompts and the corresponding masks.
"""
class PaligemmaTokenizer(Tokenizer):
def __init__(self, max_len: int = 48):
self._max_len = max_len
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
with path.open("rb") as f:
self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
@override
def tokenize(self, batch: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
batch_tokens, batch_masks = [], []
for text in batch:
cleaned_text = text.lower().strip().replace("_", " ").replace("\n", " ")
# tokenize "\n" separately as the "start of answer" token
tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
tokens_len = len(tokens)
if tokens_len < self._max_len:
padding = [0] * (self._max_len - tokens_len)
mask = [1] * tokens_len + padding
tokens = tokens + padding
else:
tokens = tokens[: self._max_len]
mask = [1] * self._max_len
batch_tokens.append(tokens)
batch_masks.append(mask)
return np.array(batch_tokens), np.array(batch_masks)

View File

@@ -0,0 +1,9 @@
from openpi.models import tokenizer as _tokenizer
def test_tokenize():
tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
tokens, masks = tokenizer.tokenize(["Hello, world!", "This is a test"])
assert tokens.shape == (2, 10)
assert masks.shape == (2, 10)

View File

@@ -0,0 +1,440 @@
from collections.abc import Callable
import enum
import functools as ft
import logging
from typing import Literal
import einops
from flax import struct
import flax.linen as nn
from flax.linen import dtypes
import jax.ad_checkpoint
import jax.numpy as jnp
import openpi.shared.array_typing as at
logger = logging.getLogger(__name__)
AFTER_ATTN_CHECKPOINT_NAME = "after_attn"
AFTER_XATTN_CHECKPOINT_NAME = "after_xattn"
QKV_CHECKPOINT_NAME = "qkv"
def _custom_dot_product_attention(
query,
key,
value,
bias=None,
mask=None,
broadcast_dropout: bool = True, # noqa
dropout_rng=None,
dropout_rate: float = 0.0,
deterministic: bool = False, # noqa
dtype=None,
precision=None,
module=None,
):
"""Mostly copied from nn.dot_product_attention, except for adding checkpointing logic, and enforcing float32 logits
for stability.
"""
assert module is None
assert dropout_rate == 0.0
assert dropout_rng is None
assert bias is None
query, key, value = dtypes.promote_dtype(query, key, value, dtype=dtype)
# save post-projection query, key, value for backward pass
query = jax.ad_checkpoint.checkpoint_name(query, QKV_CHECKPOINT_NAME)
key = jax.ad_checkpoint.checkpoint_name(key, QKV_CHECKPOINT_NAME)
value = jax.ad_checkpoint.checkpoint_name(value, QKV_CHECKPOINT_NAME)
dtype = query.dtype
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
# calculate attention matrix
depth = query.shape[-1]
query = query / jnp.sqrt(depth).astype(dtype)
assert query.dtype == dtype
# calculate logits in float32 for stability
logits = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision, preferred_element_type=jnp.float32)
# apply attention mask
if mask is not None:
big_neg = jnp.finfo(jnp.float32).min
logits = jnp.where(mask, logits, big_neg)
# normalize the attention weights and cast back to the original dtype (if not float32)
attn_weights = jax.nn.softmax(logits).astype(dtype)
# return weighted sum over values for each query position
out = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value, precision=precision)
assert out.dtype == dtype
return out
@at.typecheck
@struct.dataclass
class TokenSequence:
"""Holds a sequence of tokens alongside positional information."""
tokens: at.Float[at.ArrayLike, "b seq emb"]
# pos may or may not have a batch dimension
pos: at.Float[at.Array, "b seq emb"] | at.Float[at.Array, "seq emb"]
# optional masking information
mask: at.Bool[at.Array, "b seq"] | None = None
def __len__(self):
return self.tokens.shape[1]
@property
def emb_dim(self):
return self.tokens.shape[-1]
@staticmethod
def concatenate(*sequences: "TokenSequence") -> "TokenSequence":
"""Concatenates multiple sequences along the sequence dimension."""
tokens = jnp.concatenate([seq.tokens for seq in sequences], axis=1)
# if any sequence's pos has a batch dimension, broadcast the others to have one
if any(seq.pos.ndim == 3 for seq in sequences):
batch_size = next(seq.pos.shape[0] for seq in sequences if seq.pos.ndim == 3)
pos = jnp.concatenate(
[
seq.pos if seq.pos.ndim == 3 else jnp.broadcast_to(seq.pos, (batch_size, *seq.pos.shape))
for seq in sequences
],
axis=1,
)
else:
pos = jnp.concatenate([seq.pos for seq in sequences], axis=0)
# if any sequence has a mask, create True masks for the others
if any(seq.mask is not None for seq in sequences):
mask = jnp.concatenate(
[
seq.mask
if seq.mask is not None
else jnp.ones((seq.tokens.shape[0], seq.tokens.shape[1]), dtype=jnp.bool_)
for seq in sequences
],
axis=1,
)
else:
mask = None
return TokenSequence(tokens=tokens, pos=pos, mask=mask)
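# A construction sketch (shapes illustrative): sequences with batched and unbatched `pos` can be concatenated;
# the unbatched positions are broadcast to the batch.
#   >>> a = TokenSequence(tokens=jnp.zeros((2, 1, 8)), pos=jnp.zeros((2, 1, 8)))
#   >>> b = TokenSequence(tokens=jnp.zeros((2, 2, 8)), pos=jnp.zeros((2, 8)))
#   >>> len(TokenSequence.concatenate(a, b))
#   3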
class PosembStrategy(enum.Enum):
"""Controls how positional embeddings are incorporated into the transformer. Configured separately for the
input sequence and the cross-attention sequence. Note that for cross-attention, ADD_AT_ATTN and
ADD_AT_BEGINNING are very similar, since the key and value token sequences are the same for every
attention operation. The only difference is that ADD_AT_ATTN adds the positional embeddings to the key
sequence only, while ADD_AT_BEGINNING adds them to both the key and value sequences.
NONE:
Ignore positional embeddings.
ADD_AT_BEGINNING:
Adds the positional embeddings to the token sequence at the beginning of the transformer call.
ADD_AT_ATTN:
Adds the positional embeddings to the query and key (but not value) sequences at every attention
operation.
"""
NONE = enum.auto()
ADD_AT_BEGINNING = enum.auto()
ADD_AT_ATTN = enum.auto()
class AttentionBlock(nn.Module):
"""Implements either self-attention (if q == kv) or cross-attention (if q != kv)."""
num_heads: int
normalize_qk: bool = True
@nn.compact
@at.typecheck
def __call__(
self,
*,
q: at.Float[at.Array, "b q_seq q_emb"],
kv: at.Float[at.Array, "b kv_seq kv_emb"],
q_pos: at.Float[at.Array, "*bq q_seq q_emb"],
kv_pos: at.Float[at.Array, "*bkv kv_seq kv_emb"],
mask: at.Bool[at.Array, "b q_seq kv_seq"] | None = None,
dtype: at.DTypeLike,
) -> at.Float[at.Array, "b q_seq q_emb"]:
# broadcast mask to have a head dimension
if mask is not None:
mask = einops.repeat(mask, "b q_seq kv_seq -> b n q_seq kv_seq", n=self.num_heads)
# we add posembs to queries and keys, but not values
q = q + q_pos
k = kv + kv_pos
v = kv
return nn.MultiHeadAttention(
num_heads=self.num_heads,
normalize_qk=self.normalize_qk,
use_bias=False,
kernel_init=nn.initializers.lecun_normal(),
attention_fn=_custom_dot_product_attention,
dtype=dtype,
)(q, k, v, mask=mask)
class MLPBlock(nn.Module):
dim: int
@nn.compact
@at.typecheck
def __call__(self, x: at.Float[at.Array, "b seq emb"], *, dtype: at.DTypeLike) -> at.Float[at.Array, "b seq emb"]:
embed_dim = x.shape[-1]
# SwiGLU MLP.
# fuse the first 2 matmuls into one in case it's more efficient
out = nn.DenseGeneral((2, self.dim), use_bias=False, kernel_init=nn.initializers.lecun_normal(), dtype=dtype)(x)
gating, hidden = einops.rearrange(out, "b seq n emb -> n b seq emb")
return nn.Dense(embed_dim, use_bias=False, kernel_init=nn.initializers.lecun_normal(), dtype=dtype)(
nn.swish(gating) * hidden
)
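# In equation form: MLPBlock(x) = (swish(x @ W_gate) * (x @ W_hidden)) @ W_out, a bias-free SwiGLU
# feed-forward where W_gate and W_hidden come from the single fused DenseGeneral above.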
class AdaLNGeneral(nn.Module):
"""Generalized LayerNorm module, optionally adaptive based on conditioning information.
If `cond` is None, applies standard LayerNorm with learned scale and bias. If `cond` is not None, applies
adaptive LayerNorm (AdaLN):
>>> out = LayerNorm(x) * (1 + scale) + shift
Where `scale` and `shift` are learned from conditioning information and initialized to always be 0 (so
that the output is initially equal to LayerNorm(x)), and LayerNorm here is the version without learned
parameters.
If `fn` is not None, this module applies normalization, `fn`, and then a residual connection. For example,
with `cond == None`:
>>> out = x + fn(LayerNorm(x))
With `cond != None`, this becomes AdaLNZero (from the DiT paper):
>>> out = x + gate * fn(LayerNorm(x) * (1 + scale) + shift)
where `gate`, `scale`, and `shift` are once again initialized to always be 0, so the output is initially
equal to the input.
"""
fn: Callable | None = None
@nn.compact
@at.typecheck
def __call__(
self,
x: at.Float[at.Array, "b seq emb"],
cond: at.Float[at.Array, "b cond_emb"] | at.Float[at.Array, "b seq cond_emb"] | None = None,
*,
dtype: at.DTypeLike,
) -> at.Float[at.Array, "b seq emb"]:
if cond is None:
if self.fn is None:
return nn.LayerNorm(dtype=dtype)(x)
return x + self.fn(nn.LayerNorm(dtype=dtype)(x))
# number of learned AdaLN vectors
n = 2 if self.fn is None else 3
adaln = nn.DenseGeneral(
features=(n, x.shape[-1]),
kernel_init=nn.zeros,
dtype=dtype,
)(nn.swish(cond))
if cond.ndim == 2:
adaln = einops.rearrange(adaln, "b n emb -> n b 1 emb")
elif cond.ndim == 3:
adaln = einops.rearrange(adaln, "b seq n emb -> n b seq emb")
else:
raise ValueError(f"Invalid number of dimensions for cond: {cond.ndim}")
modulated = nn.LayerNorm(use_bias=False, use_scale=False, dtype=dtype)(x) * (1 + adaln[0]) + adaln[1]
if self.fn is None:
return modulated
return x + adaln[2] * self.fn(modulated)
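A small sketch of the zero-init property described in the docstring: because the conditioning projection is zero-initialized, the adaptive path starts out identical to a parameter-free LayerNorm (shown here with `fn=None`).
import jax
import jax.numpy as jnp

norm = AdaLNGeneral()  # fn=None -> (adaptive) LayerNorm only, no residual branch
x = jnp.arange(32, dtype=jnp.float32).reshape(1, 4, 8)
cond = jnp.ones((1, 16))
params = norm.init(jax.random.key(0), x, cond, dtype=jnp.float32)
out = norm.apply(params, x, cond, dtype=jnp.float32)
# At initialization scale == shift == 0, so `out` equals LayerNorm(x) with no learned scale/bias.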
class TransformerBlock(nn.Module):
"""Transformer block (no attention mask) with optional AdaLN and cross-attention conditioning."""
attn: AttentionBlock
mlp: MLPBlock
@nn.compact
@at.typecheck
def __call__(
self,
x: TokenSequence,
xattn_cond: TokenSequence | None = None,
adaln_cond: at.Float[at.Array, "b adaln_emb"] | at.Float[at.Array, "b seq adaln_emb"] | None = None,
self_attn_mask: at.Bool[at.Array, "b seq seq"] | None = None,
*,
dtype: at.DTypeLike,
) -> TokenSequence:
# if x.mask is not None, apply it to the self-attention mask
if x.mask is not None:
if self_attn_mask is None:
self_attn_mask = jnp.ones((x.tokens.shape[0], x.tokens.shape[1], x.tokens.shape[1]), dtype=jnp.bool_)
# take the outer product of x.mask with itself to form a full (b, seq, seq) attention mask and then combine
# it with the existing attention mask
self_attn_mask = jnp.logical_and(self_attn_mask, jnp.logical_and(x.mask[:, None, :], x.mask[:, :, None]))
def self_attn_fn(y):
return self.attn.copy(name="self_attn")(
q=y, kv=y, q_pos=x.pos, kv_pos=x.pos, mask=self_attn_mask, dtype=dtype
)
# self-attention
tokens = AdaLNGeneral(self_attn_fn)(x.tokens, adaln_cond, dtype=dtype)
tokens = jax.ad_checkpoint.checkpoint_name(tokens, AFTER_ATTN_CHECKPOINT_NAME)
# cross-attention
if xattn_cond is not None:
# if xattn_cond.mask is not None, generate a cross-attention mask
if xattn_cond.mask is not None:
xattn_mask = einops.repeat(xattn_cond.mask, "b xseq -> b seq xseq", seq=x.tokens.shape[1])
else:
xattn_mask = None
def xattn_fn(y):
return self.attn.copy(name="cross_attn")(
q=y, kv=xattn_cond.tokens, q_pos=x.pos, kv_pos=xattn_cond.pos, mask=xattn_mask, dtype=dtype
)
tokens = AdaLNGeneral(xattn_fn)(tokens, adaln_cond, dtype=dtype)
tokens = jax.ad_checkpoint.checkpoint_name(tokens, AFTER_XATTN_CHECKPOINT_NAME)
# mlp
tokens = AdaLNGeneral(ft.partial(self.mlp, dtype=dtype))(tokens, adaln_cond, dtype=dtype)
return x.replace(tokens=tokens)
class Transformer(nn.Module):
"""Transformer stack with optional AdaLN and cross-attention conditioning.
AdaLN conditioning is a single vector per example (or optionally one vector per token). Cross-attention conditioning is a sequence of vectors, where the
sequence length may be different from the input sequence length. The input, adaln conditioning, and cross-
attention conditioning may all have different embedding dimensions.
"""
num_layers: int
transformer_block: TransformerBlock
self_attn_posemb_strategy: PosembStrategy = PosembStrategy.ADD_AT_BEGINNING
xattn_posemb_strategy: PosembStrategy = PosembStrategy.NONE
dtype: str = "bfloat16"
@nn.compact
@at.typecheck
def __call__(
self,
x: TokenSequence,
xattn_cond: TokenSequence | None = None,
adaln_cond: at.Float[at.Array, "b adaln_emb"] | at.Float[at.Array, "b seq adaln_emb"] | None = None,
self_attn_mask: at.Bool[at.Array, "b seq seq"] | None = None,
) -> TokenSequence:
orig_pos = x.pos # save because we always want to include it in the output sequence
# the transformer block always adds positional embeddings at attention time, so for strategies other than
# ADD_AT_ATTN we zero them out here to make that addition a no-op
if self.self_attn_posemb_strategy == PosembStrategy.ADD_AT_BEGINNING:
x = x.replace(tokens=x.tokens + x.pos)
if self.self_attn_posemb_strategy != PosembStrategy.ADD_AT_ATTN:
x = x.replace(pos=jnp.zeros_like(x.pos, dtype=self.dtype))
x = x.replace(tokens=x.tokens.astype(self.dtype))
if xattn_cond is not None:
if self.xattn_posemb_strategy == PosembStrategy.ADD_AT_BEGINNING:
xattn_cond = xattn_cond.replace(tokens=xattn_cond.tokens + xattn_cond.pos)
if self.xattn_posemb_strategy != PosembStrategy.ADD_AT_ATTN:
xattn_cond = xattn_cond.replace(pos=jnp.zeros_like(xattn_cond.pos, dtype=self.dtype))
xattn_cond = xattn_cond.replace(tokens=xattn_cond.tokens.astype(self.dtype))
if adaln_cond is not None:
adaln_cond = adaln_cond.astype(self.dtype)
def block_call(module, x):
return module(x, xattn_cond, adaln_cond, self_attn_mask, dtype=self.dtype), None
# Enables rematerialization (aka gradient checkpointing). This configuration saves only the post-projection
# query, key, and value tensors, as well as the activations after the full attention and cross-attention blocks.
# This is based on seqax.
block_call_remat = nn.remat(
block_call,
policy=jax.checkpoint_policies.save_only_these_names(
(AFTER_ATTN_CHECKPOINT_NAME, AFTER_XATTN_CHECKPOINT_NAME, QKV_CHECKPOINT_NAME)
),
)
# scanning over layers significantly reduces compilation time
x, _ = nn.scan(
block_call_remat,
length=self.num_layers,
variable_axes={"params": 0}, # create new parameters for each iteration
split_rngs={"params": True},
)(self.transformer_block, x)
x = x.replace(tokens=AdaLNGeneral(name="final_norm")(x.tokens, adaln_cond, dtype=self.dtype))
# restore original posemb for downstream use
return x.replace(pos=orig_pos)
Variant = Literal["dummy", "tiny", "small", "base", "large"]
def get_variant(variant: Variant, **kwargs) -> tuple[Transformer, int]:
if variant == "dummy":
return Transformer(
num_layers=2,
transformer_block=TransformerBlock(
attn=AttentionBlock(num_heads=2),
mlp=MLPBlock(dim=4),
),
**kwargs,
), 4
if variant == "tiny":
return Transformer(
num_layers=4,
transformer_block=TransformerBlock(
attn=AttentionBlock(num_heads=2),
mlp=MLPBlock(dim=512),
),
**kwargs,
), 128
if variant == "small":
return Transformer(
num_layers=12,
transformer_block=TransformerBlock(
attn=AttentionBlock(num_heads=6),
mlp=MLPBlock(dim=1536),
),
**kwargs,
), 384
if variant == "base":
return Transformer(
num_layers=12,
transformer_block=TransformerBlock(
attn=AttentionBlock(num_heads=12),
mlp=MLPBlock(dim=3072),
),
**kwargs,
), 768
if variant == "large":
return Transformer(
num_layers=24,
transformer_block=TransformerBlock(
attn=AttentionBlock(num_heads=16),
mlp=MLPBlock(dim=4096),
),
**kwargs,
), 1024
raise ValueError(f"Invalid transformer variant: {variant}")

307
src/openpi/models/vit.py Normal file
View File

@@ -0,0 +1,307 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
from collections.abc import Callable
from typing import Any
import flax.linen as nn
import jax
import jax.numpy as jnp
from openpi.models import resnet as models_resnet
Array = Any
PRNGKey = Any
Shape = tuple[int, ...]
Dtype = Any
class IdentityLayer(nn.Module):
"""Identity layer, convenient for giving a name to an array."""
@nn.compact
def __call__(self, x):
return x
class AddPositionEmbs(nn.Module):
"""Adds learned positional embeddings to the inputs.
Attributes:
posemb_init: positional embedding initializer.
"""
posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
param_dtype: Dtype = jnp.float32
@nn.compact
def __call__(self, inputs):
"""Applies the AddPositionEmbs module.
Args:
inputs: Inputs to the layer.
Returns:
Output tensor with shape `(bs, timesteps, in_dim)`.
"""
# inputs.shape is (batch_size, seq_len, emb_dim).
assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
return inputs + pe
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block."""
mlp_dim: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
out_dim: int | None = None
dropout_rate: float = 0.1
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)
@nn.compact
def __call__(self, inputs, *, deterministic):
"""Applies Transformer MlpBlock module."""
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
x = nn.Dense(
features=self.mlp_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)( # pytype: disable=wrong-arg-types
inputs
)
x = nn.gelu(x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
output = nn.Dense(
features=actual_out_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init,
)( # pytype: disable=wrong-arg-types
x
)
return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
class Encoder1DBlock(nn.Module):
"""Transformer encoder layer.
Attributes:
mlp_dim: dimension of the mlp on top of attention block.
num_heads: number of heads in nn.MultiHeadDotProductAttention.
dtype: the dtype of the computation (default: float32).
dropout_rate: dropout rate.
attention_dropout_rate: dropout rate for attention heads.
"""
mlp_dim: int
num_heads: int
dtype: Dtype = jnp.float32
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
@nn.compact
def __call__(self, inputs, deterministic):
"""Applies Encoder1DBlock module.
Args:
inputs: Inputs to the layer.
deterministic: Dropout will not be applied when set to true.
Returns:
output after transformer encoder block.
"""
# Attention block.
assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
x = nn.LayerNorm(dtype=self.dtype)(inputs)
x = nn.MultiHeadDotProductAttention(
dtype=self.dtype,
kernel_init=nn.initializers.xavier_uniform(),
broadcast_dropout=False,
deterministic=deterministic,
dropout_rate=self.attention_dropout_rate,
num_heads=self.num_heads,
# Force fp32 for the softmax for numerical stability; surprisingly, this is not the default.
force_fp32_for_softmax=True,
)(x, x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
x = x + inputs
# MLP block.
y = nn.LayerNorm(dtype=self.dtype)(x)
y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
y, deterministic=deterministic
)
return x + y, None
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation.
Attributes:
num_layers: number of layers
mlp_dim: dimension of the mlp on top of attention block
num_heads: Number of heads in nn.MultiHeadDotProductAttention
dropout_rate: dropout rate.
attention_dropout_rate: dropout rate in self attention.
"""
dtype: jax.typing.DTypeLike
num_layers: int
mlp_dim: int
num_heads: int
dropout_rate: float = 0.1
attention_dropout_rate: float = 0.1
add_position_embedding: bool = True
@nn.compact
def __call__(self, x, *, train):
"""Applies Transformer model on the inputs.
Args:
x: Inputs to the layer.
train: Set to `True` when training.
Returns:
output of a transformer encoder.
"""
assert x.ndim == 3 # (batch, len, emb)
if self.add_position_embedding:
x = AddPositionEmbs(
posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
name="posembed_input",
)(x)
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
x = x.astype(self.dtype)
# Input Encoder
block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
x, _ = nn.scan(
block,
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
in_axes=nn.broadcast,
length=self.num_layers,
)(
name="encoderblock",
mlp_dim=self.mlp_dim,
dropout_rate=self.dropout_rate,
attention_dropout_rate=self.attention_dropout_rate,
dtype=self.dtype,
num_heads=self.num_heads,
)(x, not train)
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
class VisionTransformer(nn.Module):
"""VisionTransformer."""
dtype: jax.typing.DTypeLike
num_classes: int
patches: Any
transformer: Any
hidden_size: int
resnet: Any | None = None
representation_size: int | None = None
classifier: str = "token"
head_bias_init: float = 0.0
encoder: type[nn.Module] = Encoder
model_name: str | None = None
@nn.compact
def __call__(self, inputs, *, train):
x = inputs
# (Possibly partial) ResNet root.
if self.resnet is not None:
width = int(64 * self.resnet.width_factor)
# Root block.
x = models_resnet.StdConv(
features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
)(x)
x = nn.GroupNorm(name="gn_root")(x)
x = nn.relu(x)
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
# ResNet stages.
if self.resnet.num_layers:
x = models_resnet.ResNetStage(
block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
)(x)
for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
x = models_resnet.ResNetStage(
block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
)(x)
n, h, w, c = x.shape
# We can merge s2d+emb into a single conv; it's the same.
x = nn.Conv(
features=self.hidden_size,
kernel_size=self.patches.size,
strides=self.patches.size,
padding="VALID",
name="embedding",
)(x)
# Here, x is a grid of embeddings.
# (Possibly partial) Transformer.
if self.transformer is not None:
n, h, w, c = x.shape
x = jnp.reshape(x, [n, h * w, c])
# If we want to add a class token, add it here.
if self.classifier in ["token", "token_unpooled"]:
cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
cls = jnp.tile(cls, [n, 1, 1])
x = jnp.concatenate([cls, x], axis=1)
x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)
if self.classifier == "token":
x = x[:, 0]
elif self.classifier == "gap":
x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
elif self.classifier in ["unpooled", "token_unpooled"]:
pass
else:
raise ValueError(f"Invalid classifier={self.classifier}")
if self.representation_size is not None:
x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
x = nn.tanh(x)
else:
x = IdentityLayer(name="pre_logits")(x)
if self.num_classes:
x = nn.Dense(
features=self.num_classes,
name="head",
kernel_init=nn.initializers.zeros,
bias_init=nn.initializers.constant(self.head_bias_init),
)(x)
return x
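A hedged usage sketch (illustrative values, not a configuration used elsewhere in this repo): when `resnet` is None only `patches.size` is read, so a simple namespace can stand in for the usual config object.
import types
import jax
import jax.numpy as jnp

model = VisionTransformer(
    dtype=jnp.bfloat16,
    num_classes=10,
    patches=types.SimpleNamespace(size=(16, 16)),
    transformer={"num_layers": 2, "mlp_dim": 128, "num_heads": 2},
    hidden_size=64,
)
images = jnp.zeros((1, 224, 224, 3))
params = model.init(jax.random.key(0), images, train=False)
logits = model.apply(params, images, train=False)  # shape (1, 10)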

View File

@@ -0,0 +1,253 @@
from collections.abc import Sequence
import einops
import numpy as np
from openpi import transforms
def make_aloha_example() -> dict:
return {
"qpos": np.ones((14,)),
"image": np.random.rand(4, 3, 480, 640).astype(np.float32),
}
class ActInputsRepack(transforms.DataTransformFn):
def __call__(self, data: dict) -> dict:
# images is [..., num_cams, channel, height, width] of type uint8.
# number of cameras (num_cams) depends on the environment.
images = np.asarray(data["image"])
num_cams = images.shape[-4]
if num_cams == 4:
cam_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
elif num_cams == 1:
cam_names = ["cam_high"]
else:
raise ValueError(f"Expected 1 or 4 cameras, got {num_cams}")
# `images` have shape [..., cam_idx, channel, height, width].
image_splits = [np.squeeze(x, axis=-4) for x in np.split(images, num_cams, axis=-4)]
images_dict = dict(zip(cam_names, image_splits, strict=True))
return {
"images": images_dict,
"state": data["qpos"],
}
class ActOutputsRepack(transforms.DataTransformFn):
def __call__(self, data: dict) -> dict:
return {"qpos": data["actions"]}
class AlohaInputs(transforms.DataTransformFn):
"""Inputs for the Aloha policy.
Expected inputs:
- images: dict[name, img] where img is [..., channel, height, width]. name must be in EXPECTED_CAMERAS.
- state: [..., 14]
- actions: [..., action_horizon, action_dim]
Args:
action_dim: The dimension of the action space.
delta_action_mask: A boolean mask for the action dimensions. If None, absolute actions are used.
adapt_to_pi: If true, will adapt the joint and gripper values to match the pi runtime.
"""
EXPECTED_CAMERAS = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None, adapt_to_pi: bool = False):
self._action_dim = action_dim
self._delta_action_mask = delta_action_mask
self._adapt_to_pi = adapt_to_pi
def __call__(self, data: dict) -> dict:
data = _decode_aloha(data, adapt_to_pi=self._adapt_to_pi)
# Get the state. We are padding from 14 to the model action dim.
state = transforms.pad_to_dim(data["state"], self._action_dim)
in_images = data["images"]
if set(in_images) - set(self.EXPECTED_CAMERAS):
raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")
# Assume that base image always exists.
base_image = in_images["cam_high"]
batch_size = base_image.shape[:-3]
images = {
"base_0_rgb": base_image,
}
image_masks = {
"base_0_rgb": np.ones(batch_size, dtype=np.bool_),
}
# Add the extra images.
extra_image_names = {
"left_wrist_0_rgb": "cam_left_wrist",
"right_wrist_0_rgb": "cam_right_wrist",
}
for dest, source in extra_image_names.items():
if source in in_images:
images[dest] = in_images[source]
image_masks[dest] = np.ones(batch_size, dtype=np.bool_)
else:
images[dest] = np.zeros_like(base_image)
image_masks[dest] = np.zeros(batch_size, dtype=np.bool_)
inputs = {
"image": images,
"image_mask": image_masks,
"state": state,
}
# Actions are only available during training.
if "actions" in data:
actions = np.asarray(data["actions"])
actions = _encode_actions_inv(actions, adapt_to_pi=self._adapt_to_pi)
if self._delta_action_mask is not None:
mask = np.asarray(self._delta_action_mask[:14])
actions = actions - np.expand_dims(np.where(mask, state[..., :14], 0), axis=-2)
inputs["actions"] = transforms.pad_to_dim(actions, self._action_dim)
if "prompt" in data:
inputs["prompt"] = data["prompt"]
return inputs
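A hedged sketch of the input pipeline for a single (unbatched) example; `action_dim=32` is illustrative and `transforms.pad_to_dim` is assumed to pad the trailing dimension.
example = make_aloha_example()
repacked = ActInputsRepack()(example)
inputs = AlohaInputs(action_dim=32)(repacked)
# inputs["state"].shape == (32,); each image is converted to (480, 640, 3) uint8 (channel-last);
# all three views used by the model (base and both wrists) get valid image masks.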
class AlohaOutputs(transforms.DataTransformFn):
"""Outputs for the Aloha policy.
Args:
delta_action_mask: A boolean mask for the action dimensions. If None, absolute actions are used.
adapt_to_pi: If true, will adapt the joint and gripper values to match the pi runtime.
"""
def __init__(self, *, delta_action_mask: Sequence[bool] | None = None, adapt_to_pi: bool = False):
self._delta_action_mask = delta_action_mask
self._adapt_to_pi = adapt_to_pi
def __call__(self, data: dict) -> dict:
# Only return the first 14 dims.
actions = np.asarray(data["actions"][..., :14])
# Apply the delta action mask.
if self._delta_action_mask is not None:
state = np.asarray(data["state"][..., :14])
mask = np.asarray(self._delta_action_mask[:14])
actions = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)
return {"actions": _encode_actions(actions, adapt_to_pi=self._adapt_to_pi)}
def joint_flip_mask() -> np.ndarray:
"""Used to convert between aloha and pi joint angles."""
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return np.arcsin(np.clip(value, -1.0, 1.0))
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
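As a quick worked example of the gripper mapping (just arithmetic on the constants above):
value = gripper_from_angular(0.5)
# unnormalize(0.5, 0.4, 1.5) = 0.95, then normalize(0.95, -0.6213, 1.4910) ~= 0.744
assert abs(gripper_from_angular_inv(value) - 0.5) < 1e-6  # exact inverse up to float error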
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
# state is [left_arm_joint_angles, right_arm_joint_angles, left_arm_gripper, right_arm_gripper]
# dim sizes: [6, 1, 6, 1]
state = np.asarray(data["state"])
state = _decode_state(state, adapt_to_pi=adapt_to_pi)
def convert_image(img):
img = np.asarray(img)
# Convert to uint8 if using float images.
if np.issubdtype(img.dtype, np.floating):
img = (255 * img).astype(np.uint8)
# Convert from [..., channel, height, width] to [..., height, width, channel].
return einops.rearrange(img, "... c h w -> ... h w c")
images = data["images"]
images_dict = {name: convert_image(img) for name, img in images.items()}
data["images"] = images_dict
data["state"] = state
return data
def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
# Flip the joints.
state = joint_flip_mask() * state
# Reverse the gripper transformation that is being applied by the Aloha runtime.
state[..., 6] = gripper_to_angular(state[..., 6])
state[..., 13] = gripper_to_angular(state[..., 13])
return state
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
# Flip the joints.
actions = joint_flip_mask() * actions
actions[..., 6] = gripper_from_angular(actions[..., 6])
actions[..., 13] = gripper_from_angular(actions[..., 13])
return actions
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
if adapt_to_pi:
actions = joint_flip_mask() * actions
actions[..., 6] = gripper_from_angular_inv(actions[..., 6])
actions[..., 13] = gripper_from_angular_inv(actions[..., 13])
return actions

View File

@@ -0,0 +1,38 @@
import jax.numpy as jnp
from openpi import transforms
class CalvinInputs(transforms.DataTransformFn):
def __init__(self, action_dim: int):
self._action_dim = action_dim
def __call__(self, data: dict) -> dict:
state = transforms.pad_to_dim(data["observation/state"], self._action_dim)
inputs = {
"state": state,
"image": {
"rgb_static": data["observation/rgb_static"],
"rgb_gripper": data["observation/rgb_gripper"],
},
"image_mask": {
"rgb_static": jnp.ones(1, dtype=jnp.bool_),
"rgb_gripper": jnp.ones(1, dtype=jnp.bool_),
},
}
if "prompt" in data:
inputs["prompt"] = data["prompt"]
return inputs
class CalvinOutputs(transforms.DataTransformFn):
def __init__(self):
pass
def __call__(self, data: dict) -> dict:
# Only return the first 15 dims.
actions = jnp.asarray(data["actions"][..., :15])
return {"actions": actions}

View File

@@ -0,0 +1,53 @@
from collections.abc import Sequence
import numpy as np
from openpi import transforms
class DroidInputs(transforms.DataTransformFn):
def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None):
self._action_dim = action_dim
self._delta_action_mask = delta_action_mask
def __call__(self, data: dict) -> dict:
state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]], axis=1)
state = transforms.pad_to_dim(state, self._action_dim)
base_image = data["observation/exterior_image_1_left"]
inputs = {
"state": state,
"image": {
"base_0_rgb": data["observation/exterior_image_1_left"],
"left_wrist_0_rgb": data["observation/wrist_image_left"],
"right_wrist_0_rgb": np.zeros_like(base_image),
},
"image_mask": {
"base_0_rgb": np.ones(1, dtype=np.bool_),
"left_wrist_0_rgb": np.ones(1, dtype=np.bool_),
"right_wrist_0_rgb": np.zeros(1, dtype=np.bool_),
},
}
if "prompt" in data:
inputs["prompt"] = data["prompt"]
return inputs
class DroidOutputs(transforms.DataTransformFn):
def __init__(self, *, delta_action_mask: Sequence[bool] | None = None):
self._delta_action_mask = delta_action_mask
def __call__(self, data: dict) -> dict:
# Only return the first 8 dims.
actions = np.asarray(data["actions"][..., :8])
# Apply the delta action mask.
if self._delta_action_mask is not None:
state = np.asarray(data["state"][..., :8])
mask = np.asarray(self._delta_action_mask[:8])
actions = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)
return {"actions": actions}

View File

@@ -0,0 +1,35 @@
import jax.numpy as jnp
from openpi import transforms
class LiberoInputs(transforms.DataTransformFn):
def __init__(self, action_dim: int):
self._action_dim = action_dim
def __call__(self, data: dict) -> dict:
state = transforms.pad_to_dim(data["observation/state"], self._action_dim)
inputs = {
"state": state,
"image": {
"image": data["observation/image"],
"wrist_image": data["observation/wrist_image"],
},
"image_mask": {
"image": jnp.ones(1, dtype=jnp.bool_),
"wrist_image": jnp.ones(1, dtype=jnp.bool_),
},
}
if "prompt" in data:
inputs["prompt"] = data["prompt"]
return inputs
class LiberoOutputs(transforms.DataTransformFn):
def __call__(self, data: dict) -> dict:
# Only return the first 8 dims.
actions = jnp.asarray(data["actions"][..., :8])
return {"actions": actions}

View File

@@ -0,0 +1,87 @@
from collections.abc import Sequence
import logging
import pathlib
from typing import Any, TypeAlias
import flax
import flax.traverse_util
import jax
import jax.numpy as jnp
import numpy as np
from openpi_client import base_policy as _base_policy
from typing_extensions import override
from openpi import transforms as _transforms
from openpi.models import common
from openpi.models import model as _model
from openpi.shared import array_typing as at
BasePolicy: TypeAlias = _base_policy.BasePolicy
class Policy(BasePolicy):
def __init__(
self,
model: _model.BaseModel,
*,
rng: at.KeyArrayLike | None = None,
transforms: Sequence[_transforms.DataTransformFn] = (),
output_transforms: Sequence[_transforms.DataTransformFn] = (),
sample_kwargs: dict[str, Any] | None = None,
):
self._model = model
self._input_transform = _transforms.CompositeTransform(transforms)
self._output_transform = _transforms.CompositeTransform(output_transforms)
self._rng = rng if rng is not None else jax.random.key(0)
self._sample_kwargs = sample_kwargs or {"num_steps": 10}
@override
def infer(self, obs: dict) -> dict: # type: ignore[misc]
inputs = self._input_transform(_make_batch(obs))
inputs = jax.tree_util.tree_map(lambda x: jnp.asarray(x), inputs)
self._rng, sample_rng = jax.random.split(self._rng)
outputs = {
"state": inputs["state"],
"actions": self._model.sample_actions(
sample_rng, common.Observation.from_dict(inputs), **self._sample_kwargs
),
}
outputs = self._output_transform(outputs)
return _unbatch(jax.device_get(outputs))
class PolicyRecorder(_base_policy.BasePolicy):
"""Records the policy's behavior to disk."""
def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
self._policy = policy
logging.info(f"Dumping policy records to: {record_dir}")
self._record_dir = pathlib.Path(record_dir)
self._record_dir.mkdir(parents=True, exist_ok=True)
self._record_step = 0
@override
def infer(self, obs: dict) -> dict: # type: ignore[misc]
results = self._policy.infer(obs)
data = {"inputs": obs, "outputs": results}
data = flax.traverse_util.flatten_dict(data, sep="/")
output_path = self._record_dir / f"step_{self._record_step}"
self._record_step += 1
np.save(output_path, np.asarray(data))
return results
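Each record is a 0-d object array wrapping the flattened {"inputs/...", "outputs/..."} dict, so reading it back requires `allow_pickle` (the path below is hypothetical):
record = np.load("policy_records/step_0.npy", allow_pickle=True).item()
action_keys = [k for k in record if k.startswith("outputs/")]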
def _make_batch(data: at.PyTree[np.ndarray]) -> at.PyTree[np.ndarray]:
def _transform(x: np.ndarray) -> np.ndarray:
return np.asarray(x)[np.newaxis, ...]
return jax.tree_util.tree_map(_transform, data)
def _unbatch(data: at.PyTree[np.ndarray]) -> at.PyTree[np.ndarray]:
return jax.tree_util.tree_map(lambda x: np.asarray(x[0, ...]), data)

View File

@@ -0,0 +1,123 @@
from collections.abc import Sequence
import dataclasses
import logging
import pathlib
from typing import Any
import jax.numpy as jnp
from openpi.models import tokenizer
import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
import openpi.transforms as transforms
@dataclasses.dataclass
class PolicyConfig:
model: _model.BaseModel
norm_stats: dict[str, transforms.NormStats]
input_layers: Sequence[transforms.DataTransformFn]
output_layers: Sequence[transforms.DataTransformFn]
default_prompt: str | None = None
sample_kwargs: dict[str, Any] | None = None
def create_policy(config: PolicyConfig) -> _policy.Policy:
"""Creates a default pi0 policy."""
return _policy.Policy(
config.model,
transforms=[
*config.input_layers,
transforms.Normalize(config.norm_stats),
transforms.TokenizePrompt(
tokenizer.PaligemmaTokenizer(config.model.max_token_len), default_prompt=config.default_prompt
),
],
output_transforms=[
transforms.Unnormalize(config.norm_stats),
*config.output_layers,
],
sample_kwargs=config.sample_kwargs,
)
def create_trained_policy(
train_config: _config.TrainConfig,
checkpoint_dir: pathlib.Path | str,
*,
repack_transforms: transforms.Group | None = None,
sample_kwargs: dict[str, Any] | None = None,
default_prompt: str | None = None,
norm_stats: dict[str, transforms.NormStats] | None = None,
) -> _policy.Policy:
"""Create a policy from a trained checkpoint.
Args:
train_config: The training config to use to create the model.
checkpoint_dir: The directory to load the model from.
repack_transforms: Optional transforms that will be applied before any other transforms.
sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
kwargs will be used.
default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
data if it doesn't already exist.
norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
from the checkpoint directory.
"""
repack_transforms = repack_transforms or transforms.Group()
checkpoint_dir = download.maybe_download(str(checkpoint_dir))
logging.info("Loading model...")
model = train_config.create_model()
model = model.set_params(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))
data_config = train_config.data.create(train_config.metadata_dir, model)
if norm_stats is None:
# Load the norm stats from the checkpoint instead of the metadata dir, to make sure that the
# policy uses the same normalization stats as the original training process.
norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets")
return _policy.Policy(
model,
transforms=[
*repack_transforms.inputs,
transforms.InjectDefaultPrompt(default_prompt),
*data_config.data_transforms.inputs,
transforms.Normalize(norm_stats),
*data_config.model_transforms.inputs,
],
output_transforms=[
*data_config.model_transforms.outputs,
transforms.Unnormalize(norm_stats),
*data_config.data_transforms.outputs,
*repack_transforms.outputs,
],
sample_kwargs=sample_kwargs,
)
def make_bool_mask(*dims: int) -> tuple[bool, ...]:
"""Make a boolean mask for the given dimensions.
Example:
make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
make_bool_mask(2, 0, 2) == (True, True, True, True)
Args:
dims: The dimensions to make the mask for.
Returns:
A tuple of booleans.
"""
result = []
for dim in dims:
if dim > 0:
result.extend([True] * (dim))
else:
result.extend([False] * (-dim))
return tuple(result)

View File

@@ -0,0 +1,17 @@
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config
def test_make_bool_mask():
assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)
def test_create_trained_policy():
policy = _policy_config.create_trained_policy(
_config.get_config("debug"),
"s3://openpi-assets/checkpoints/pi0_base",
# The base checkpoint doesn't have norm stats.
norm_stats={},
)
assert policy is not None

View File

@@ -0,0 +1,55 @@
from openpi_client import action_chunk_broker
from openpi.models import exported as _exported
from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
def create_policy_config() -> _policy_config.PolicyConfig:
model = _exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
return _policy_config.PolicyConfig(
model=model,
norm_stats=model.norm_stats("huggingface_aloha_sim_transfer_cube"),
input_layers=[
aloha_policy.ActInputsRepack(),
aloha_policy.AlohaInputs(
action_dim=model.action_dim,
delta_action_mask=None,
adapt_to_pi=False,
),
],
output_layers=[
aloha_policy.AlohaOutputs(
delta_action_mask=None,
adapt_to_pi=False,
),
aloha_policy.ActOutputsRepack(),
],
)
def test_infer():
config = create_policy_config()
policy = _policy_config.create_policy(config)
example = aloha_policy.make_aloha_example()
outputs = policy.infer(example)
assert outputs["qpos"].shape == (config.model.action_horizon, 14)
def test_broker():
config = create_policy_config()
policy = _policy_config.create_policy(config)
broker = action_chunk_broker.ActionChunkBroker(
policy,
# Only execute the first half of the chunk.
action_horizon=config.model.action_horizon // 2,
)
example = aloha_policy.make_aloha_example()
for _ in range(config.model.action_horizon):
outputs = broker.infer(example)
assert outputs["qpos"].shape == (14,)

0
src/openpi/py.typed Normal file
View File

View File

@@ -0,0 +1,55 @@
import asyncio
import logging
import traceback
from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server
import websockets.frames
class WebsocketPolicyServer:
"""Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.
Currently only implements the `load` and `infer` methods.
TODO: Implement the other methods.
"""
def __init__(self, policy: _base_policy.BasePolicy, host: str = "0.0.0.0", port: int = 8000) -> None:
self._policy = policy
self._host = host
self._port = port
logging.getLogger("websockets.server").setLevel(logging.INFO)
def serve_forever(self) -> None:
asyncio.run(self.run())
async def run(self):
async with websockets.asyncio.server.serve(
self._handler,
self._host,
self._port,
compression=None,
max_size=None,
) as server:
await server.serve_forever()
async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
logging.info(f"Connection from {websocket.remote_address} opened")
packer = msgpack_numpy.Packer()
while True:
try:
obs = msgpack_numpy.unpackb(await websocket.recv())
action = self._policy.infer(obs)
await websocket.send(packer.pack(action))
except websockets.ConnectionClosed:
logging.info(f"Connection from {websocket.remote_address} closed")
break
except Exception:
await websocket.send(traceback.format_exc())
await websocket.close(
code=websockets.frames.CloseCode.INTERNAL_ERROR,
reason="Internal server error. Traceback included in previous frame.",
)
raise
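For reference, a minimal synchronous client sketch matching this recv/send loop; the real client lives in websocket_client_policy.py (not shown here), the URI is an assumption, and error handling is omitted.
import websockets.sync.client
from openpi_client import msgpack_numpy

def infer_once(obs: dict, uri: str = "ws://localhost:8000") -> dict:
    packer = msgpack_numpy.Packer()
    with websockets.sync.client.connect(uri) as ws:
        ws.send(packer.pack(obs))
        return msgpack_numpy.unpackb(ws.recv())  # on server errors the reply is a plain traceback string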

Some files were not shown because too many files have changed in this diff.