Initial commit
3  .dockerignore  Normal file
@@ -0,0 +1,3 @@
.venv
checkpoints
data

16  .github/CODEOWNERS  vendored  Normal file
@@ -0,0 +1,16 @@
# The CODEOWNERS file defines individuals or teams that are automatically requested for
# review when someone opens a pull request that modifies certain code. When a draft pull
# request is marked as ready for review, code owners are automatically notified.
#
# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
#
# This is a comment.
# Each line is a file pattern followed by one or more owners.

# Global owners.
* @jimmyt857 @Michael-Equi @uzhilinsky

src/openpi/models/ @kvablack @uzhilinsky
src/openpi/training/ @kvablack @uzhilinsky

scripts/ @jimmyt857 @kvablack @uzhilinsky

17  .github/workflows/pre-commit.yml  vendored  Normal file
@@ -0,0 +1,17 @@
name: pre-commit
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - "*"
jobs:
  pre-commit:
    runs-on: ubuntu-latest
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v3
      - uses: pre-commit/action@v3.0.1

26  .github/workflows/test.yml  vendored  Normal file
@@ -0,0 +1,26 @@
name: Test
on:
  pull_request:
    branches:
      - "*"

jobs:
  run_tests:
    name: Run Tests
    runs-on: ubuntu-latest
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3

      - name: Set up Python
        run: uv python install

      - name: Install the project
        run: uv sync --all-extras --dev

      - name: Run tests
        run: uv run pytest src scripts

168  .gitignore  vendored  Normal file
@@ -0,0 +1,168 @@
# Data directories.
assets/
checkpoints/
data/
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

9  .gitmodules  vendored  Normal file
@@ -0,0 +1,9 @@
[submodule "third_party/aloha"]
	path = third_party/aloha
	url = git@github.com:Physical-Intelligence/aloha.git
[submodule "third_party/calvin"]
	path = third_party/calvin
	url = git@github.com:mees/calvin.git
[submodule "third_party/libero"]
	path = third_party/libero
	url = git@github.com:Lifelong-Robot-Learning/LIBERO.git

16  .pre-commit-config.yaml  Normal file
@@ -0,0 +1,16 @@
exclude: third_party/

repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    # uv version.
    rev: 0.5.9
    hooks:
      - id: uv-lock
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.7.1
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      - id: ruff-format

1  .python-version  Normal file
@@ -0,0 +1 @@
3.11

11  .vscode/settings.json  vendored  Normal file
@@ -0,0 +1,11 @@
{
    "[python]": {
        "editor.defaultFormatter": "charliermarsh.ruff",
        "editor.formatOnSave": true,
    },
    "python.testing.pytestArgs": [
        "src"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}

201  LICENSE  Normal file
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

116  README.md  Normal file
@@ -0,0 +1,116 @@
# openpi

openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).

Currently, it is focused on the `pi0` model described in [this blog post](https://www.physicalintelligence.company/blog/pi0).

## Setup

When cloning this repo, make sure to update submodules:

```bash
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git

# Or if you already cloned the repo:
git submodule update --init --recursive
```

### Using uv

We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up.

Once uv is installed, run the following to set up the environment:

```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
```

NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.

### Docker Setup

All of the examples in this repo come with instructions for running both natively and with Docker. Although not required, we recommend the Docker option: it simplifies software installation, produces a more stable environment, and, for the examples that depend on ROS, lets you avoid installing ROS on (and cluttering) your host machine.

Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU, you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/install_docker_ubuntu22.sh` and `scripts/install_nvidia_container_toolkit.sh`.

During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.

### Downloading checkpoints

By default, checkpoints are downloaded from `s3://openpi-assets` and cached in `~/.cache/openpi` when needed. You can override the download path by setting the `OPENPI_DATA_HOME` environment variable.
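
For example, to relocate the cache from Python, a minimal sketch (the `/mnt/storage/openpi` path is a placeholder; the variable must be set before openpi resolves its download paths):

```python
import os

# Placeholder cache location; any writable directory works. Set this before
# importing openpi modules so downloads resolve against the override.
os.environ["OPENPI_DATA_HOME"] = "/mnt/storage/openpi"
```
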
## Running Training

Training configs are defined in [src/openpi/training/config.py](src/openpi/training/config.py) and the training script is in [scripts/train.py](scripts/train.py).

Each registered config is available as a command line argument to `scripts/train.py`. To find all available command line arguments for your config, run `uv run scripts/train.py <config-name> --help`, or look at the `TrainConfig` class in [src/openpi/training/config.py](src/openpi/training/config.py).

For example, to train with the `pi0_aloha_sim` config, run the following:

(one time only) Compute the norm stats for the training data:

```bash
uv run scripts/compute_norm_stats.py --config-name pi0_aloha_sim
```

Run training:

```bash
uv run scripts/train.py pi0_aloha_sim --exp-name=my_experiment --overwrite
```

The `pi0_aloha_sim` config is optimized for training on a single H100 GPU. By default, JAX pre-allocates 75% of available GPU memory. We set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` to allow JAX to use up to 90% of GPU memory, which enables training with larger batch sizes while maintaining stability.

The training script automatically utilizes all available GPUs on a single node. Currently, distributed training across multiple nodes is not supported.

## Running examples

We provide example integrations with several robotics platforms. See the README in each example for more details:

- [ALOHA Sim](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)
- [CALVIN](examples/calvin)
- [LIBERO](examples/libero)

## Running the openpi server

The server can be configured to serve openpi policies in the following ways:

- Serve a default policy for the given environment.
- Serve a trained policy from a checkpoint.
- Serve an exported model.

### Serve the default policy for the LIBERO environment

```bash
uv run scripts/serve_policy.py --env LIBERO --default_prompt "my task"
```

### Serve a trained policy from an openpi checkpoint

This option allows serving a model that was trained using the openpi training code.

```bash
uv run scripts/serve_policy.py --env ALOHA_SIM policy:checkpoint --policy.config=pi0_aloha_sim --policy.dir=checkpoints/pi0_aloha_sim/exp_name/10000
```

The training config is used to determine which data transformations should be applied to the runtime data before it is fed into the model. The norm stats, which are used to normalize the transformed data, are loaded from the checkpoint directory.
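
For reference when writing your own client, below is a minimal sketch of the client side, modeled on `examples/aloha_real/main.py` (the host, port, and `action_horizon` values are illustrative):

```python
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy

# Connect to a running serve_policy.py instance.
client = _websocket_client_policy.WebsocketClientPolicy(host="0.0.0.0", port=8000)

# One round-trip to the server returns a chunk of actions, which the broker
# replays over the next `action_horizon` control steps.
policy = action_chunk_broker.ActionChunkBroker(policy=client, action_horizon=25)
```

This is the same pair of objects that `examples/aloha_real/main.py` wraps in a `PolicyAgent` and hands to the runtime.
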
### Serve an exported model

There are also a number of checkpoints available as exported JAX graphs, which we trained using our internal training code. These can be served using the following command:

```bash
uv run scripts/serve_policy.py --env ALOHA policy:exported --policy.dir=s3://openpi-assets/exported/pi0_aloha/model --policy.processor=trossen_biarm_single_base_cam_24dim
```

In this case, the data transformations are taken from the default policy, and the processor name determines which norm stats are used to normalize the transformed data.

### Running with Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM --default_prompt 'my task'"
docker compose -f scripts/compose.yml up --build
```

70  examples/aloha_real/Dockerfile  Normal file
@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.

# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        cmake \
        curl \
        libffi-dev \
        python3-rosdep \
        python3-rosinstall \
        python3-rosinstall-generator \
        whiptail \
        git \
        wget \
        openssh-client \
        ros-noetic-cv-bridge \
        ros-noetic-usb-cam \
        ros-noetic-realsense2-camera \
        keyboard-configuration

WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install Python 3.10 because this ROS image comes with 3.8.
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]

73  examples/aloha_real/README.md  Normal file
@@ -0,0 +1,73 @@
# Run Aloha (Real Robot)

This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha).

## Prerequisites

This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.

1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use the serial numbers of your cameras.

## With Docker

```bash
export SERVER_ARGS="--env ALOHA --default_prompt='toast out of toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
python examples/aloha_real/main.py
```

Terminal window 2:

```bash
roslaunch --wait aloha ros_nodes.launch
```

Terminal window 3:

```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='toast out of toaster'
```

## Model Guide

The Pi0 base model is an out-of-the-box model for general tasks. You can find more details in the [technical report](https://www.physicalintelligence.company/download/pi0.pdf).

While we strongly recommend fine-tuning the model on your own data to adapt it to particular tasks, it may be possible to prompt the model to attempt some tasks that were in the pre-training data. For example, below is a video of the model attempting the "toast out of toaster" task.

<p align="center">
  <img src="https://github.com/Physical-Intelligence/openpi/blob/main/examples/aloha_real/toast.gif" alt="toast out of toaster"/>
</p>

## Training on your own Aloha dataset

OpenPI supports training on data collected in the default ALOHA HDF5 format. To do so, you must first convert the data to the Hugging Face format; we include `scripts/aloha_hd5.py` to help you do this. Once the dataset is converted, add a new `TrainConfig` to `src/openpi/training/config.py` and replace the repo id with the id assigned to your dataset during conversion.

```python
TrainConfig(
    name=<your-config-name>,
    data=LeRobotAlohaDataConfig(
        repo_id=<your-repo-id>,
        delta_action_mask=[True] * 6 + [False] + [True] * 6 + [False],
    ),
),
```

Run the training script:

```bash
uv run scripts/train.py <your-config-name>
```

63  examples/aloha_real/compose.yml  Normal file
@@ -0,0 +1,63 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

71  examples/aloha_real/constants.py  Normal file
@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa

### Task parameters

### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]

# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844

# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213

############################ Helper functions ############################

MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
    MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
    PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))

MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
    MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
    PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
    lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))

MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)

MASTER_POS2JOINT = (
    lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
    + MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
    lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
    + PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
    (x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)

MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
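
Note: the MASTER2PUPPET helpers above are plain compositions of the normalize/unnormalize pairs. A small sanity-check sketch (run from the repo root; the input is an arbitrary master gripper joint reading):

```python
from examples.aloha_real import constants

# Arbitrary value within [MASTER_GRIPPER_JOINT_CLOSE, MASTER_GRIPPER_JOINT_OPEN].
master_joint = 0.0

# MASTER2PUPPET_JOINT_FN normalizes against the master joint limits, then
# unnormalizes against the puppet joint limits.
normalized = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_joint)
assert constants.MASTER2PUPPET_JOINT_FN(master_joint) == constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(normalized)
```
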
52  examples/aloha_real/env.py  Normal file
@@ -0,0 +1,52 @@
import einops
import numpy as np
from openpi_client.runtime import environment as _environment
from typing_extensions import override

from examples.aloha_real import real_env as _real_env


class AlohaRealEnvironment(_environment.Environment):
    """An environment for an Aloha robot on real hardware."""

    def __init__(self, render_height: int = 480, render_width: int = 640) -> None:
        self._env = _real_env.make_real_env(init_node=True)
        self._render_height = render_height
        self._render_width = render_width

        self._ts = None

    @override
    def reset(self) -> None:
        self._ts = self._env.reset()

    @override
    def done(self) -> bool:
        return False

    @override
    def get_observation(self) -> dict:
        if self._ts is None:
            raise RuntimeError("Timestep is not set. Call reset() first.")

        obs = self._ts.observation
        for k in list(obs["images"].keys()):
            if "_depth" in k:
                del obs["images"][k]

        images = []
        for cam_name in obs["images"]:
            curr_image = obs["images"][cam_name]
            curr_image = einops.rearrange(curr_image, "h w c -> c h w")
            images.append(curr_image)
        stacked_images = np.stack(images, axis=0).astype(np.uint8)

        # TODO: Consider removing these transformations.
        return {
            "qpos": obs["qpos"],
            "image": stacked_images,
        }

    @override
    def apply_action(self, action: dict) -> None:
        self._ts = self._env.step(action["qpos"])

42  examples/aloha_real/main.py  Normal file
@@ -0,0 +1,42 @@
import dataclasses
import logging

from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro

from examples.aloha_real import env as _env


@dataclasses.dataclass
class Args:
    host: str = "0.0.0.0"
    port: int = 8000

    action_horizon: int = 25


def main(args: Args) -> None:
    runtime = _runtime.Runtime(
        environment=_env.AlohaRealEnvironment(),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)

167  examples/aloha_real/real_env.py  Normal file
@@ -0,0 +1,167 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time

import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np

from examples.aloha_real import constants
from examples.aloha_real import robot_utils


class RealEnv:
    """
    Environment for real robot bi-manual manipulation.

    Action space: [left_arm_qpos (6),            # absolute joint position
                   left_gripper_positions (1),   # normalized gripper position (0: close, 1: open)
                   right_arm_qpos (6),           # absolute joint position
                   right_gripper_positions (1)]  # normalized gripper position (0: close, 1: open)

    Observation space: {"qpos": Concat[left_arm_qpos (6),          # absolute joint position
                                       left_gripper_position (1),  # normalized gripper position (0: close, 1: open)
                                       right_arm_qpos (6),         # absolute joint position
                                       right_gripper_qpos (1)],    # normalized gripper position (0: close, 1: open)
                        "qvel": Concat[left_arm_qvel (6),          # absolute joint velocity (rad)
                                       left_gripper_velocity (1),  # normalized gripper velocity (pos: opening, neg: closing)
                                       right_arm_qvel (6),         # absolute joint velocity (rad)
                                       right_gripper_qvel (1)],    # normalized gripper velocity (pos: opening, neg: closing)
                        "images": {"cam_high": (480x640x3),        # h, w, c, dtype='uint8'
                                   "cam_low": (480x640x3),         # h, w, c, dtype='uint8'
                                   "cam_left_wrist": (480x640x3),  # h, w, c, dtype='uint8'
                                   "cam_right_wrist": (480x640x3)}}  # h, w, c, dtype='uint8'
    """

    def __init__(self, init_node, *, setup_robots: bool = True):
        self.puppet_bot_left = InterbotixManipulatorXS(
            robot_model="vx300s",
            group_name="arm",
            gripper_name="gripper",
            robot_name="puppet_left",
            init_node=init_node,
        )
        self.puppet_bot_right = InterbotixManipulatorXS(
            robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
        )
        if setup_robots:
            self.setup_robots()

        self.recorder_left = robot_utils.Recorder("left", init_node=False)
        self.recorder_right = robot_utils.Recorder("right", init_node=False)
        self.image_recorder = robot_utils.ImageRecorder(init_node=False)
        self.gripper_command = JointSingleCommand(name="gripper")

    def setup_robots(self):
        robot_utils.setup_puppet_bot(self.puppet_bot_left)
        robot_utils.setup_puppet_bot(self.puppet_bot_right)

    def get_qpos(self):
        left_qpos_raw = self.recorder_left.qpos
        right_qpos_raw = self.recorder_right.qpos
        left_arm_qpos = left_qpos_raw[:6]
        right_arm_qpos = right_qpos_raw[:6]
        left_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
        ]  # this is position not joint
        right_gripper_qpos = [
            constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
        ]  # this is position not joint
        return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])

    def get_qvel(self):
        left_qvel_raw = self.recorder_left.qvel
        right_qvel_raw = self.recorder_right.qvel
        left_arm_qvel = left_qvel_raw[:6]
        right_arm_qvel = right_qvel_raw[:6]
        left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
        right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
        return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])

    def get_effort(self):
        left_effort_raw = self.recorder_left.effort
        right_effort_raw = self.recorder_right.effort
        left_robot_effort = left_effort_raw[:7]
        right_robot_effort = right_effort_raw[:7]
        return np.concatenate([left_robot_effort, right_robot_effort])

    def get_images(self):
        return self.image_recorder.get_images()

    def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
        left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
        self.gripper_command.cmd = left_gripper_desired_joint
        self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)

        right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
            right_gripper_desired_pos_normalized
        )
        self.gripper_command.cmd = right_gripper_desired_joint
        self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)

    def _reset_joints(self):
        # reset_position = START_ARM_POSE[:6]
        reset_position = [0, -1.5, 1.5, 0, 0, 0]
        robot_utils.move_arms(
            [self.puppet_bot_left, self.puppet_bot_right], [reset_position, reset_position], move_time=1
        )

    def _reset_gripper(self):
        """Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
        )
        robot_utils.move_grippers(
            [self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
        )

    def get_observation(self):
        obs = collections.OrderedDict()
        obs["qpos"] = self.get_qpos()
        obs["qvel"] = self.get_qvel()
        obs["effort"] = self.get_effort()
        obs["images"] = self.get_images()
        return obs

    def get_reward(self):
        return 0

    def reset(self, *, fake=False):
        if not fake:
            # Reboot puppet robot gripper motors
            self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
            self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
            self._reset_joints()
            self._reset_gripper()
        return dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )

    def step(self, action):
        state_len = int(len(action) / 2)
        left_action = action[:state_len]
        right_action = action[state_len:]
        self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
        self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
        self.set_gripper_pose(left_action[-1], right_action[-1])
        time.sleep(constants.DT)
        return dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
        )


def get_action(master_bot_left, master_bot_right):
    action = np.zeros(14)  # 6 joint + 1 gripper, for two arms
    # Arm actions
    action[:6] = master_bot_left.dxl.joint_states.position[:6]
    action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
    # Gripper actions
    action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
    action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])

    return action


def make_real_env(init_node, *, setup_robots: bool = True) -> RealEnv:
    return RealEnv(init_node, setup_robots=setup_robots)
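
Note: per the `RealEnv` docstring, an action is a flat 14-dim vector that `step()` splits in half, one 7-dim block per arm. A minimal sketch of assembling one (the joint targets are illustrative):

```python
import numpy as np

# 14-dim action: per arm, 6 absolute joint positions followed by 1 normalized
# gripper position (0: close, 1: open); the left arm occupies the first half.
action = np.zeros(14)
action[:6] = [0, -0.96, 1.16, 0, -0.3, 0]  # left arm joint targets (illustrative)
action[6] = 1.0                            # left gripper fully open
action[7:13] = action[:6]                  # mirror the pose on the right arm
action[13] = 1.0                           # right gripper fully open
# RealEnv.step(action) slices this at len(action) // 2 into left/right halves.
```
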
18  examples/aloha_real/requirements.in  Normal file
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets

156  examples/aloha_real/requirements.txt  Normal file
@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
catkin-pkg==1.0.0
    # via rospkg
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
contourpy==1.1.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
distro==1.9.0
    # via rospkg
dm-control==1.0.23
    # via -r examples/aloha_real/requirements.in
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
docutils==0.20.1
    # via catkin-pkg
einops==0.8.0
    # via -r examples/aloha_real/requirements.in
etils==1.3.0
    # via mujoco
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
h5py==3.11.0
    # via -r examples/aloha_real/requirements.in
idna==3.10
    # via requests
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.7.5
    # via -r examples/aloha_real/requirements.in
mdurl==0.1.2
    # via markdown-it-py
modern-robotics==1.1.1
    # via -r examples/aloha_real/requirements.in
msgpack==1.1.0
    # via -r examples/aloha_real/requirements.in
mujoco==3.2.3
    # via dm-control
numpy==1.24.4
    # via
    #   -r examples/aloha_real/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   h5py
    #   labmaze
    #   matplotlib
    #   modern-robotics
    #   mujoco
    #   opencv-python
    #   pyquaternion
    #   scipy
opencv-python==4.10.0.84
    # via -r examples/aloha_real/requirements.in
packaging==24.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
pexpect==4.9.0
    # via -r examples/aloha_real/requirements.in
pillow==10.4.0
    # via
    #   -r examples/aloha_real/requirements.in
    #   matplotlib
protobuf==5.29.1
    # via dm-control
ptyprocess==0.7.0
    # via pexpect
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.1.4
    # via
    #   catkin-pkg
    #   dm-control
    #   matplotlib
pyquaternion==0.9.9
    # via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
    # via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
    # via
    #   catkin-pkg
    #   matplotlib
pyyaml==6.0.2
    # via
    #   -r examples/aloha_real/requirements.in
    #   rospkg
requests==2.32.3
    # via
    #   -r examples/aloha_real/requirements.in
    #   dm-control
rich==13.9.4
    # via tyro
rospkg==1.5.1
    # via -r examples/aloha_real/requirements.in
scipy==1.10.1
    # via dm-control
setuptools==75.3.0
    # via
    #   catkin-pkg
    #   dm-control
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_real/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_real/requirements.in
zipp==3.20.2
    # via etils

275  examples/aloha_real/robot_utils.py  Normal file
@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time

from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState

from examples.aloha_real import constants


class ImageRecorder:
    def __init__(self, init_node=True, is_debug=False):
        self.is_debug = is_debug
        self.bridge = CvBridge()
        self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]

        if init_node:
            rospy.init_node("image_recorder", anonymous=True)
        for cam_name in self.camera_names:
            setattr(self, f"{cam_name}_rgb_image", None)
            setattr(self, f"{cam_name}_depth_image", None)
            setattr(self, f"{cam_name}_timestamp", 0.0)
            if cam_name == "cam_high":
                callback_func = self.image_cb_cam_high
            elif cam_name == "cam_low":
                callback_func = self.image_cb_cam_low
            elif cam_name == "cam_left_wrist":
                callback_func = self.image_cb_cam_left_wrist
            elif cam_name == "cam_right_wrist":
                callback_func = self.image_cb_cam_right_wrist
            else:
                raise NotImplementedError
            rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
            if self.is_debug:
                setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))

        self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
        time.sleep(0.5)

    def image_cb(self, cam_name, data):
        setattr(
            self,
            f"{cam_name}_rgb_image",
            self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
        )
        # setattr(
        #     self,
        #     f"{cam_name}_depth_image",
        #     self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
        # )
        setattr(
            self,
            f"{cam_name}_timestamp",
            data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
        )
        # setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
        # setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
        # cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
        if self.is_debug:
            getattr(self, f"{cam_name}_timestamps").append(
                data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
            )

    def image_cb_cam_high(self, data):
        cam_name = "cam_high"
        return self.image_cb(cam_name, data)

    def image_cb_cam_low(self, data):
        cam_name = "cam_low"
        return self.image_cb(cam_name, data)

    def image_cb_cam_left_wrist(self, data):
        cam_name = "cam_left_wrist"
        return self.image_cb(cam_name, data)

    def image_cb_cam_right_wrist(self, data):
        cam_name = "cam_right_wrist"
        return self.image_cb(cam_name, data)

    def get_images(self):
        image_dict = {}
        for cam_name in self.camera_names:
            while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
                time.sleep(0.00001)
            rgb_image = getattr(self, f"{cam_name}_rgb_image")
            depth_image = getattr(self, f"{cam_name}_depth_image")
            self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
            image_dict[cam_name] = rgb_image
            image_dict[f"{cam_name}_depth"] = depth_image
        return image_dict

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        for cam_name in self.camera_names:
            image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
            print(f"{cam_name} {image_freq=:.2f}")
        print()


class Recorder:
    def __init__(self, side, init_node=True, is_debug=False):
        self.secs = None
        self.nsecs = None
        self.qpos = None
        self.effort = None
        self.arm_command = None
        self.gripper_command = None
        self.is_debug = is_debug

        if init_node:
            rospy.init_node("recorder", anonymous=True)
        rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_group",
            JointGroupCommand,
            self.puppet_arm_commands_cb,
        )
        rospy.Subscriber(
            f"/puppet_{side}/commands/joint_single",
            JointSingleCommand,
            self.puppet_gripper_commands_cb,
        )
        if self.is_debug:
            self.joint_timestamps = deque(maxlen=50)
            self.arm_command_timestamps = deque(maxlen=50)
            self.gripper_command_timestamps = deque(maxlen=50)
        time.sleep(0.1)

    def puppet_state_cb(self, data):
        self.qpos = data.position
        self.qvel = data.velocity
        self.effort = data.effort
        self.data = data
        if self.is_debug:
            self.joint_timestamps.append(time.time())

    def puppet_arm_commands_cb(self, data):
        self.arm_command = data.cmd
        if self.is_debug:
            self.arm_command_timestamps.append(time.time())

    def puppet_gripper_commands_cb(self, data):
        self.gripper_command = data.cmd
        if self.is_debug:
            self.gripper_command_timestamps.append(time.time())

    def print_diagnostics(self):
        def dt_helper(l):
            l = np.array(l)
            diff = l[1:] - l[:-1]
            return np.mean(diff)

        joint_freq = 1 / dt_helper(self.joint_timestamps)
        arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
        gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)

        print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")


def get_arm_joint_positions(bot):
    return bot.arm.core.joint_states.position[:6]


def get_arm_gripper_positions(bot):
    return bot.gripper.core.joint_states.position[6]


def move_arms(bot_list, target_pose_list, move_time=1):
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]
    for t in range(num_steps):
        for bot_id, bot in enumerate(bot_list):
            bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
        time.sleep(constants.DT)


def move_grippers(bot_list, target_pose_list, move_time):
    print(f"Moving grippers to {target_pose_list=}")
    gripper_command = JointSingleCommand(name="gripper")
    num_steps = int(move_time / constants.DT)
    curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
    traj_list = [
        np.linspace(curr_pose, target_pose, num_steps)
        for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
    ]

    with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
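For orientation, a minimal usage sketch of the two recorder classes above. This is hypothetical (it is not part of the commit) and assumes a running ROS master with the ALOHA camera and joint-state topics publishing:

```python
# Hypothetical usage -- assumes a ROS master and the ALOHA topics are live.
image_recorder = ImageRecorder(init_node=True, is_debug=True)
left_recorder = Recorder("left", init_node=False)

images = image_recorder.get_images()  # {"cam_high": ..., "cam_high_depth": ..., ...}
qpos = left_recorder.qpos             # latest /puppet_left/joint_states positions

image_recorder.print_diagnostics()    # per-camera frame rates (is_debug=True only)
```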
BIN
examples/aloha_real/toast.gif
Normal file
Binary file not shown.
After Width: | Height: | Size: 23 MiB
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
from matplotlib.image import AxesImage
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoDisplay(_subscriber.Subscriber):
    """Displays video frames."""

    def __init__(self) -> None:
        self._ax: plt.Axes | None = None
        # Note: the original annotated this as `plt.Image`, which does not exist
        # in matplotlib.pyplot; the image artist type is matplotlib.image.AxesImage.
        self._plt_img: AxesImage | None = None

    @override
    def on_episode_start(self) -> None:
        plt.ion()
        self._ax = plt.subplot()
        self._plt_img = None

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        assert self._ax is not None

        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]

        if self._plt_img is None:
            self._plt_img = self._ax.imshow(im)
        else:
            self._plt_img.set_data(im)
        plt.pause(0.001)

    @override
    def on_episode_end(self) -> None:
        plt.ioff()
        plt.close()
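A possible wiring sketch (hypothetical, not part of the commit): VideoDisplay follows the same subscriber protocol that examples/aloha_sim/main.py uses for VideoSaver, so it could be added to a runtime's subscriber list in the same way:

```python
# Hypothetical wiring: a live-preview subscriber for the runtime
# (same pattern as the subscribers list in examples/aloha_sim/main.py).
from examples.aloha_real.video_display import VideoDisplay

subscribers = [VideoDisplay()]
```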
41
examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.

# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash

FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev
ENV MUJOCO_GL=egl

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
36
examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
# Run Aloha Sim

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```

Note: If you are seeing EGL errors, you may need to install the following dependencies:

```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```
39
examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,39 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
  runtime:
    image: aloha_sim
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_sim/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
56
examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
import gym_aloha  # noqa: F401
import gymnasium
import numpy as np
from openpi_client.runtime import environment as _environment
from typing_extensions import override


class AlohaSimEnvironment(_environment.Environment):
    """An environment for an Aloha robot in simulation."""

    def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
        np.random.seed(seed)
        self._rng = np.random.default_rng(seed)

        self._gym = gymnasium.make(task, obs_type=obs_type)

        self._last_obs = None
        self._done = True
        self._episode_reward = 0.0

    @override
    def reset(self) -> None:
        gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = False
        self._episode_reward = 0.0

    @override
    def done(self) -> bool:
        return self._done

    @override
    def get_observation(self) -> dict:
        if self._last_obs is None:
            raise RuntimeError("Observation is not set. Call reset() first.")

        return self._last_obs  # type: ignore

    @override
    def apply_action(self, action: dict) -> None:
        gym_obs, reward, terminated, truncated, info = self._gym.step(action["qpos"])
        self._last_obs = self._convert_observation(gym_obs)  # type: ignore
        self._done = terminated or truncated
        self._episode_reward = max(self._episode_reward, reward)

    def _convert_observation(self, gym_obs: dict) -> dict:
        # Convert axis order from [H, W, C] --> [C, H, W]
        img = np.transpose(gym_obs["pixels"]["top"], (2, 0, 1))

        # Add multi-camera dimension, to match the way real aloha provides images as [cam_idx, C, H, W].
        imgs = np.expand_dims(img, axis=0)

        return {
            "qpos": gym_obs["agent_pos"],
            "image": imgs,
        }
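As a quick illustration of the axis conversion in `_convert_observation` (shapes only; the frame contents here are made up):

```python
import numpy as np

# A fake 640x480 RGB frame, as gym-aloha would return it: [H, W, C].
frame = np.zeros((480, 640, 3), dtype=np.uint8)

img = np.transpose(frame, (2, 0, 1))  # [C, H, W] -> (3, 480, 640)
imgs = np.expand_dims(img, axis=0)    # [cam_idx, C, H, W] -> (1, 3, 480, 640)
print(imgs.shape)                     # (1, 3, 480, 640)
```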
55
examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib

import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro


@dataclasses.dataclass
class Args:
    out_path: pathlib.Path = pathlib.Path("out.mp4")

    task: str = "gym_aloha/AlohaTransferCube-v0"
    seed: int = 0

    action_horizon: int = 10

    host: str = "0.0.0.0"
    port: int = 8000

    display: bool = False


def main(args: Args) -> None:
    runtime = _runtime.Runtime(
        environment=_env.AlohaSimEnvironment(
            task=args.task,
            seed=args.seed,
        ),
        agent=_policy_agent.PolicyAgent(
            policy=action_chunk_broker.ActionChunkBroker(
                policy=_websocket_client_policy.WebsocketClientPolicy(
                    host=args.host,
                    port=args.port,
                ),
                action_horizon=args.action_horizon,
            )
        ),
        subscribers=[
            _saver.VideoSaver(args.out_path),
        ],
        max_hz=50,
    )

    runtime.run()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    tyro.cli(main)
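Since `tyro.cli(main)` exposes the `Args` dataclass on the command line, the defaults can be overridden with nested flags. A sketch -- the flag spellings are inferred from the dataclass, following the `--args.*` pattern used in the CALVIN README, and are not verified here:

```bash
MUJOCO_GL=egl python examples/aloha_sim/main.py \
    --args.task=gym_aloha/AlohaTransferCube-v0 \
    --args.out_path=data/aloha_sim/out.mp4 \
    --args.action_horizon=10
```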
8
examples/aloha_sim/requirements.in
Normal file
@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy
typing-extensions
tyro
websockets
132
examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
    # via
    #   dm-control
    #   dm-env
    #   labmaze
    #   mujoco
certifi==2024.8.30
    # via requests
charset-normalizer==3.4.0
    # via requests
cloudpickle==3.1.0
    # via gymnasium
contourpy==1.3.1
    # via matplotlib
cycler==0.12.1
    # via matplotlib
dm-control==1.0.14
    # via gym-aloha
dm-env==1.6
    # via dm-control
dm-tree==0.1.8
    # via
    #   dm-control
    #   dm-env
docstring-parser==0.16
    # via tyro
farama-notifications==0.0.4
    # via gymnasium
fonttools==4.55.2
    # via matplotlib
glfw==2.8.0
    # via
    #   dm-control
    #   mujoco
gym-aloha==0.1.1
    # via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
    # via gym-aloha
idna==3.10
    # via requests
imageio==2.36.1
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gym-aloha
imageio-ffmpeg==0.5.1
    # via imageio
kiwisolver==1.4.7
    # via matplotlib
labmaze==1.0.6
    # via dm-control
lxml==5.3.0
    # via dm-control
markdown-it-py==3.0.0
    # via rich
matplotlib==3.9.3
    # via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
    # via markdown-it-py
msgpack==1.1.0
    # via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
    # via
    #   dm-control
    #   gym-aloha
numpy==1.26.4
    # via
    #   -r examples/aloha_sim/requirements.in
    #   contourpy
    #   dm-control
    #   dm-env
    #   gymnasium
    #   imageio
    #   labmaze
    #   matplotlib
    #   mujoco
    #   scipy
packaging==24.2
    # via matplotlib
pillow==11.0.0
    # via
    #   imageio
    #   matplotlib
protobuf==5.29.1
    # via dm-control
psutil==6.1.0
    # via imageio
pygments==2.18.0
    # via rich
pyopengl==3.1.7
    # via
    #   dm-control
    #   mujoco
pyparsing==3.2.0
    # via
    #   dm-control
    #   matplotlib
python-dateutil==2.9.0.post0
    # via matplotlib
requests==2.32.3
    # via dm-control
rich==13.9.4
    # via tyro
scipy==1.14.1
    # via dm-control
setuptools==75.6.0
    # via
    #   dm-control
    #   imageio-ffmpeg
    #   labmaze
shtab==1.7.1
    # via tyro
six==1.17.0
    # via python-dateutil
tqdm==4.67.1
    # via dm-control
typeguard==4.4.1
    # via tyro
typing-extensions==4.12.2
    # via
    #   -r examples/aloha_sim/requirements.in
    #   gymnasium
    #   rich
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
    # via requests
websockets==14.1
    # via -r examples/aloha_sim/requirements.in
35
examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,35 @@
import logging
import pathlib

import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override


class VideoSaver(_subscriber.Subscriber):
    """Saves episode data."""

    def __init__(self, out_path: pathlib.Path, subsample: int = 1) -> None:
        self._out_path = out_path
        self._images: list[np.ndarray] = []
        self._subsample = subsample

    @override
    def on_episode_start(self) -> None:
        self._images = []

    @override
    def on_step(self, observation: dict, action: dict) -> None:
        im = observation["image"][0]  # [C, H, W]
        im = np.transpose(im, (1, 2, 0))  # [H, W, C]
        self._images.append(im)

    @override
    def on_episode_end(self) -> None:
        logging.info(f"Saving video to {self._out_path}")
        imageio.mimwrite(
            self._out_path,
            [np.asarray(x) for x in self._images[:: self._subsample]],
            fps=50 // max(1, self._subsample),
        )
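One note on the fps arithmetic above: the runtime steps at 50 Hz (`max_hz=50` in main.py), so keeping every Nth frame and writing at `50 // N` fps keeps playback roughly real time. An illustrative instance:

```python
import pathlib

# With subsample=5, a 500-step (~10 s at 50 Hz) episode keeps 100 frames,
# written at 50 // 5 = 10 fps -- still about 10 s of video.
saver = VideoSaver(pathlib.Path("out.mp4"), subsample=5)
```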
65
examples/calvin/Dockerfile
Normal file
@@ -0,0 +1,65 @@
# THIS DOCKERFILE DOES NOT YET WORK
# Dockerfile for the CALVIN benchmark.

# Build the container:
# docker build . -t calvin -f examples/calvin/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app --privileged --gpus all calvin /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    make \
    g++ \
    git \
    wget \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    unzip \
    ffmpeg

# Install miniconda
ENV CONDA_DIR=/opt/conda
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
    /bin/bash ~/miniconda.sh -b -p $CONDA_DIR
ENV PATH=$CONDA_DIR/bin:$PATH

# Submodules don't work with calvin because it internally parses git metadata.
# So we have to clone it directly.
RUN git clone --recurse-submodules https://github.com/mees/calvin.git /root/calvin

RUN conda create -n calvin python=3.8
RUN source /opt/conda/bin/activate calvin && \
    pip install setuptools==57.5.0 && \
    cd /root/calvin && \
    ./install.sh && \
    pip install \
    imageio[ffmpeg] \
    moviepy \
    numpy==1.23.0 \
    tqdm \
    tyro \
    websockets \
    msgpack

ENV PYTHONPATH=/app:/app/packages/openpi-client/src

# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
RUN mkdir -p /datasets && cd /datasets && \
    wget http://calvin.cs.uni-freiburg.de/dataset/calvin_debug_dataset.zip && \
    unzip calvin_debug_dataset.zip && \
    rm calvin_debug_dataset.zip

WORKDIR /app
CMD ["/bin/bash", "-c", "source /opt/conda/bin/activate calvin && python examples/calvin/main.py"]
47
examples/calvin/README.md
Normal file
@@ -0,0 +1,47 @@
# CALVIN Benchmark

This example runs the CALVIN benchmark: https://github.com/mees/calvin

## With Docker

```bash
export SERVER_ARGS="--env CALVIN"
docker compose -f examples/calvin/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
cd $OPENPI_ROOT
conda create -n calvin python=3.8
conda activate calvin

git clone --recurse-submodules https://github.com/mees/calvin.git
cd calvin
pip install setuptools==57.5.0
./install.sh

pip install imageio[ffmpeg] moviepy numpy==1.23.0 tqdm tyro websockets msgpack
export PYTHONPATH=$PYTHONPATH:$OPENPI_ROOT/packages/openpi-client/src

# Download CALVIN dataset, see https://github.com/mees/calvin/blob/main/dataset/download_data.sh
export CALVIN_DATASETS_DIR=~/datasets
export CALVIN_DATASET=calvin_debug_dataset
mkdir -p $CALVIN_DATASETS_DIR && cd $CALVIN_DATASETS_DIR
wget http://calvin.cs.uni-freiburg.de/dataset/$CALVIN_DATASET.zip
unzip $CALVIN_DATASET.zip
rm $CALVIN_DATASET.zip

# Run the simulation (main.py expects the path to the extracted dataset itself)
cd $OPENPI_ROOT
python examples/calvin/main.py --args.calvin_data_path=$CALVIN_DATASETS_DIR/$CALVIN_DATASET
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env CALVIN
```
46
examples/calvin/compose.yml
Normal file
@@ -0,0 +1,46 @@
# Run with:
# docker compose -f examples/calvin/compose.yml up --build
services:
  runtime:
    image: calvin
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/calvin/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
175
examples/calvin/main.py
Normal file
@@ -0,0 +1,175 @@
"""Runs a model in a CALVIN simulation environment."""
|
||||
|
||||
import collections
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
from calvin_agent.evaluation.multistep_sequences import get_sequences
|
||||
from calvin_agent.evaluation.utils import get_env_state_for_initial_condition
|
||||
import calvin_env
|
||||
from calvin_env.envs.play_table_env import get_env
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
DATE_TIME = time.strftime("%Y_%m_%d-%H_%M_%S")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
#################################################################################################################
|
||||
# Model server parameters
|
||||
#################################################################################################################
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
replan_steps: int = 5
|
||||
|
||||
#################################################################################################################
|
||||
# CALVIN environment-specific parameters
|
||||
#################################################################################################################
|
||||
calvin_data_path: str = "/datasets/calvin_debug_dataset" # Path to CALVIN dataset for loading validation tasks
|
||||
max_subtask_steps: int = 360 # Max number of steps per subtask
|
||||
num_trials: int = 1000 # Number of rollouts per task
|
||||
|
||||
#################################################################################################################
|
||||
# Utils
|
||||
#################################################################################################################
|
||||
video_out_path: str = "data/calvin/videos" # Path to save videos
|
||||
num_save_videos: int = 5 # Number of videos to be logged per task
|
||||
video_temp_subsample: int = 5 # Temporal subsampling to make videos shorter
|
||||
|
||||
seed: int = 7 # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
# Set random seed
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Initialize CALVIN environment
|
||||
env = get_env(pathlib.Path(args.calvin_data_path) / "validation", show_gui=False)
|
||||
|
||||
# Get CALVIN eval task set
|
||||
task_definitions, task_instructions, task_reward = _get_calvin_tasks_and_reward(args.num_trials)
|
||||
|
||||
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
|
||||
|
||||
# Start evaluation.
|
||||
episode_solved_subtasks = []
|
||||
per_subtask_success = collections.defaultdict(list)
|
||||
for i, (initial_state, task_sequence) in enumerate(tqdm.tqdm(task_definitions)):
|
||||
logging.info(f"Starting episode {i+1}...")
|
||||
logging.info(f"Task sequence: {task_sequence}")
|
||||
|
||||
# Reset env to initial position for task
|
||||
robot_obs, scene_obs = get_env_state_for_initial_condition(initial_state)
|
||||
env.reset(robot_obs=robot_obs, scene_obs=scene_obs)
|
||||
|
||||
rollout_images = []
|
||||
solved_subtasks = 0
|
||||
for subtask in task_sequence:
|
||||
start_info = env.get_info()
|
||||
action_plan = collections.deque()
|
||||
|
||||
obs = env.get_obs()
|
||||
done = False
|
||||
for _ in range(args.max_subtask_steps):
|
||||
img = obs["rgb_obs"]["rgb_static"]
|
||||
wrist_img = obs["rgb_obs"]["rgb_gripper"]
|
||||
rollout_images.append(img.transpose(2, 0, 1))
|
||||
|
||||
if not action_plan:
|
||||
# Finished executing previous action chunk -- compute new chunk
|
||||
# Prepare observations dict
|
||||
element = {
|
||||
"observation/rgb_static": img,
|
||||
"observation/rgb_gripper": wrist_img,
|
||||
"observation/state": obs["robot_obs"],
|
||||
"prompt": str(task_instructions[subtask][0]),
|
||||
}
|
||||
|
||||
# Query model to get action
|
||||
action_chunk = client.infer(element)["actions"]
|
||||
assert (
|
||||
len(action_chunk) >= args.replan_steps
|
||||
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
|
||||
action_plan.extend(action_chunk[: args.replan_steps])
|
||||
|
||||
action = action_plan.popleft()
|
||||
|
||||
# Round gripper action since env expects gripper_action in (-1, 1)
|
||||
action[-1] = 1 if action[-1] > 0 else -1
|
||||
|
||||
# Step environment
|
||||
obs, _, _, current_info = env.step(action)
|
||||
|
||||
# check if current step solves a task
|
||||
current_task_info = task_reward.get_task_info_for_set(start_info, current_info, {subtask})
|
||||
if len(current_task_info) > 0:
|
||||
done = True
|
||||
solved_subtasks += 1
|
||||
break
|
||||
|
||||
per_subtask_success[subtask].append(int(done))
|
||||
if not done:
|
||||
# Subtask execution failed --> stop episode
|
||||
break
|
||||
|
||||
episode_solved_subtasks.append(solved_subtasks)
|
||||
if len(episode_solved_subtasks) < args.num_save_videos:
|
||||
# Save rollout video.
|
||||
idx = len(episode_solved_subtasks)
|
||||
imageio.mimwrite(
|
||||
pathlib.Path(args.video_out_path) / f"rollout_{idx}.mp4",
|
||||
[np.asarray(x) for x in rollout_images[:: args.video_temp_subsample]],
|
||||
fps=50 // args.video_temp_subsample,
|
||||
)
|
||||
|
||||
# Print current performance after each episode
|
||||
logging.info(f"Solved subtasks: {solved_subtasks}")
|
||||
_calvin_print_performance(episode_solved_subtasks, per_subtask_success)
|
||||
|
||||
# Log final performance
|
||||
logging.info(f"results/avg_num_subtasks: : {np.mean(episode_solved_subtasks)}")
|
||||
for i in range(1, 6):
|
||||
# Compute fraction of episodes that have *at least* i successful subtasks
|
||||
logging.info(
|
||||
f"results/avg_success_len_{i}: {np.sum(episode_solved_subtasks >= i) / len(episode_solved_subtasks)}"
|
||||
)
|
||||
for key in per_subtask_success:
|
||||
logging.info(f"results/avg_success__{key}: {np.mean(per_subtask_success[key])}")
|
||||
|
||||
|
||||
def _get_calvin_tasks_and_reward(num_sequences):
|
||||
conf_dir = pathlib.Path(calvin_env.__file__).absolute().parents[2] / "calvin_models" / "conf"
|
||||
task_cfg = OmegaConf.load(conf_dir / "callbacks/rollout/tasks/new_playtable_tasks.yaml")
|
||||
task_oracle = hydra.utils.instantiate(task_cfg)
|
||||
val_annotations = OmegaConf.load(conf_dir / "annotations/new_playtable_validation.yaml")
|
||||
eval_sequences = get_sequences(num_sequences)
|
||||
return eval_sequences, val_annotations, task_oracle
|
||||
|
||||
|
||||
def _calvin_print_performance(episode_solved_subtasks, per_subtask_success):
|
||||
# Compute avg success rate per task length
|
||||
logging.info("#####################################################")
|
||||
logging.info(f"Avg solved subtasks: {np.mean(episode_solved_subtasks)}\n")
|
||||
|
||||
logging.info("Per sequence_length avg success:")
|
||||
for i in range(1, 6):
|
||||
# Compute fraction of episodes that have *at least* i successful subtasks
|
||||
logging.info(f"{i}: {np.sum(np.array(episode_solved_subtasks) >= i) / len(episode_solved_subtasks) * 100}%")
|
||||
|
||||
logging.info("\n Per subtask avg success:")
|
||||
for key in per_subtask_success:
|
||||
logging.info(f"{key}: \t\t\t {np.mean(per_subtask_success[key]) * 100}%")
|
||||
logging.info("#####################################################")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(main)
|
||||
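To make the `avg_success_len_i` metric above concrete, a small worked example (episode counts invented for illustration):

```python
import numpy as np

# Suppose three episodes solved 3, 1, and 5 subtasks respectively.
episode_solved_subtasks = np.array([3, 1, 5])

# Fraction of episodes with *at least* i successful subtasks:
# i=1 -> 3/3, i=2 -> 2/3, i=4 -> 1/3.
for i in range(1, 6):
    print(i, np.sum(episode_solved_subtasks >= i) / len(episode_solved_subtasks))
```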
59
examples/libero/Dockerfile
Normal file
@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.

# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash

FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

RUN apt-get update && \
    apt-get install -y \
    make \
    g++ \
    clang \
    libosmesa6-dev \
    libgl1-mesa-glx \
    libglew-dev \
    libglfw3-dev \
    libgles2-mesa-dev \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero

# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]
39
examples/libero/README.md
Normal file
@@ -0,0 +1,39 @@
# LIBERO Benchmark

This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO

Note: When updating requirements.txt in this directory, an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` must be added to the `uv pip compile` command, as shown below.
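For reference, the full command would look something like this (the base command is taken from the autogenerated header of requirements.txt; only the extra-index-url flag is added):

```bash
uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt \
    --python-version 3.8 --index-strategy=unsafe-best-match \
    --extra-index-url https://download.pytorch.org/whl/cu113
```
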
## With Docker

```bash
# Grant access to the X11 server:
sudo xhost +local:docker

export SERVER_ARGS="--env LIBERO"
docker compose -f examples/libero/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero

# Run the simulation
python examples/libero/main.py
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```
49
examples/libero/compose.yml
Normal file
@@ -0,0 +1,49 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
  runtime:
    image: libero
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/libero/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data
      - /tmp/.X11-unix:/tmp/.X11-unix:ro
    environment:
      - DISPLAY=$DISPLAY
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
215
examples/libero/main.py
Normal file
@@ -0,0 +1,215 @@
import collections
import dataclasses
import logging
import math
import pathlib

import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro

LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256  # resolution used to render training data


@dataclasses.dataclass
class Args:
    #################################################################################################################
    # Model server parameters
    #################################################################################################################
    host: str = "0.0.0.0"
    port: int = 8000
    resize_size: int = 224
    replan_steps: int = 5

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = (
        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/videos"  # Path to save videos

    seed: int = 7  # Random Seed (for reproducibility)


def eval_libero(args: Args) -> None:
    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    if args.task_suite_name == "libero_spatial":
        max_steps = 220  # longest training demo has 193 steps
    elif args.task_suite_name == "libero_object":
        max_steps = 280  # longest training demo has 254 steps
    elif args.task_suite_name == "libero_goal":
        max_steps = 300  # longest training demo has 270 steps
    elif args.task_suite_name == "libero_10":
        max_steps = 520  # longest training demo has 505 steps
    elif args.task_suite_name == "libero_90":
        max_steps = 400  # longest training demo has 373 steps
    else:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []

            logging.info(f"Starting episode {task_episodes + 1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    wrist_img = image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    logging.error(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode
            suffix = "success" if done else "failure"
            task_segment = task_description.replace(" ", "_")
            imageio.mimwrite(
                pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                [np.asarray(x) for x in replay_images],
                fps=10,
            )

            # Log current results
            logging.info(f"Success: {done}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log per-task results
        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")

    # Log final results
    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
    logging.info(f"Total episodes: {total_episodes}")


def _get_libero_env(task, resolution, seed):
    """Initializes and returns the LIBERO environment, along with the task description."""
    task_description = task.language
    task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
    env = OffScreenRenderEnv(**env_args)
    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    return env, task_description


def _quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
    """
    # clip quaternion
    if quat[3] > 1.0:
        quat[3] = 1.0
    elif quat[3] < -1.0:
        quat[3] = -1.0

    den = np.sqrt(1.0 - quat[3] * quat[3])
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(quat[3])) / den


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)
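A quick sanity check of `_quat2axisangle` against scipy (scipy ships with these requirements via robosuite). This is illustrative only, using a hand-picked quaternion:

```python
import numpy as np
from scipy.spatial.transform import Rotation

# 45-degree rotation about z, in [x, y, z, w] order (scalar-last, as robosuite uses).
quat = np.array([0.0, 0.0, np.sin(np.pi / 8), np.cos(np.pi / 8)])

print(_quat2axisangle(quat.copy()))          # ~[0, 0, 0.7854]
print(Rotation.from_quat(quat).as_rotvec())  # ~[0, 0, 0.7854]
```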
11
examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3
136
examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
#    uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
    # via mujoco
certifi==2024.12.14
    # via requests
charset-normalizer==3.4.0
    # via requests
cycler==0.12.1
    # via matplotlib
docstring-parser==0.16
    # via tyro
etils==1.3.0
    # via mujoco
eval-type-backport==0.2.0
    # via tyro
evdev==1.7.1
    # via pynput
fonttools==4.55.3
    # via matplotlib
glfw==1.12.0
    # via mujoco
idna==3.10
    # via requests
imageio==2.35.1
    # via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
    # via imageio
importlib-metadata==8.5.0
    # via typeguard
importlib-resources==6.4.5
    # via etils
kiwisolver==1.4.7
    # via matplotlib
llvmlite==0.36.0
    # via numba
markdown-it-py==3.0.0
    # via rich
matplotlib==3.5.3
    # via -r examples/libero/requirements.in
mdurl==0.1.2
    # via markdown-it-py
mujoco==3.2.3
    # via robosuite
numba==0.53.1
    # via robosuite
numpy==1.22.4
    # via
    #   -r examples/libero/requirements.in
    #   imageio
    #   matplotlib
    #   mujoco
    #   numba
    #   opencv-python
    #   robosuite
    #   scipy
    #   torchvision
opencv-python==4.6.0.66
    # via
    #   -r examples/libero/requirements.in
    #   robosuite
packaging==24.2
    # via matplotlib
pillow==10.4.0
    # via
    #   imageio
    #   matplotlib
    #   robosuite
    #   torchvision
psutil==6.1.0
    # via imageio
pygments==2.18.0
    # via rich
pynput==1.7.7
    # via robosuite
pyopengl==3.1.7
    # via mujoco
pyparsing==3.1.4
    # via matplotlib
python-dateutil==2.9.0.post0
    # via matplotlib
python-xlib==0.33
    # via pynput
pyyaml==6.0.2
    # via -r examples/libero/requirements.in
requests==2.32.3
    # via torchvision
rich==13.9.4
    # via tyro
robosuite==1.4.1
    # via -r examples/libero/requirements.in
scipy==1.10.1
    # via robosuite
setuptools==75.3.0
    # via
    #   imageio-ffmpeg
    #   numba
shtab==1.7.1
    # via tyro
six==1.17.0
    # via
    #   pynput
    #   python-dateutil
    #   python-xlib
termcolor==2.4.0
    # via robosuite
torch==1.11.0+cu113
    # via
    #   -r examples/libero/requirements.in
    #   torchaudio
    #   torchvision
torchaudio==0.11.0+cu113
    # via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
    # via -r examples/libero/requirements.in
tqdm==4.67.1
    # via -r examples/libero/requirements.in
typeguard==4.4.0
    # via tyro
typing-extensions==4.12.2
    # via
    #   etils
    #   rich
    #   torch
    #   torchvision
    #   typeguard
    #   tyro
tyro==0.9.2
    # via -r examples/libero/requirements.in
urllib3==2.2.3
    # via requests
zipp==3.20.2
    # via
    #   etils
    #   importlib-metadata
    #   importlib-resources
134
examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pathlib\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "record_path = pathlib.Path(\"../policy_records\")\n",
    "num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
    "\n",
    "records = []\n",
    "for i in range(num_steps):\n",
    "    record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
    "    records.append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"length of records\", len(records))\n",
    "print(\"keys in records\", records[0].keys())\n",
    "\n",
    "for k in records[0]:\n",
    "    print(f\"{k} shape: {records[0][k].shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "\n",
    "\n",
    "def get_image(step: int, idx: int = 0):\n",
    "    img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
    "    return img[idx].transpose(1, 2, 0)\n",
    "\n",
    "\n",
    "def show_image(step: int, idx_lst: list[int]):\n",
    "    imgs = [get_image(step, idx) for idx in idx_lst]\n",
    "    return Image.fromarray(np.hstack(imgs))\n",
    "\n",
    "\n",
    "for i in range(2):\n",
    "    display(show_image(i, [0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_axis(name, axis):\n",
    "    return np.array([record[name][axis] for record in records])\n",
    "\n",
    "\n",
    "# qpos is [..., 14] of type float:\n",
    "# 0-5: left arm joint angles\n",
    "# 6: left arm gripper\n",
    "# 7-12: right arm joint angles\n",
    "# 13: right arm gripper\n",
    "names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
    "\n",
    "\n",
    "def make_data():\n",
    "    cur_dim = 0\n",
    "    in_data = {}\n",
    "    out_data = {}\n",
    "    for name, dim_size in names:\n",
    "        for i in range(dim_size):\n",
    "            in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
    "            out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
    "            cur_dim += 1\n",
    "    return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
    "\n",
    "\n",
    "in_data, out_data = make_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for name in in_data.columns:\n",
    "    data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
    "    data.plot()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
32
examples/simple_client/Dockerfile
Normal file
@@ -0,0 +1,32 @@
# Dockerfile for the simple client.

# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash

FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/

WORKDIR /app

# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy

# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv

# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml

# Install python dependencies.
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src

CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/simple_client/main.py"]
24
examples/simple_client/README.md
Normal file
@@ -0,0 +1,24 @@
# Simple Client

A minimal client that sends observations to the server and prints the inference rate.

## With Docker

```bash
export SERVER_ARGS="--example aloha"
docker compose -f examples/simple_client/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
uv run examples/simple_client/main.py
```

Terminal window 2:

```bash
uv run scripts/serve_policy.py
```
37
examples/simple_client/compose.yml
Normal file
37
examples/simple_client/compose.yml
Normal file
@@ -0,0 +1,37 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/simple_client/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: simple_client
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/simple_client/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
81
examples/simple_client/main.py
Normal file
81
examples/simple_client/main.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
example: str = "droid"
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
obs_fn = {
|
||||
"aloha": _random_observation_aloha,
|
||||
"droid": _random_observation_droid,
|
||||
"calvin": _random_observation_calvin,
|
||||
"libero": _random_observation_libero,
|
||||
}[args.example]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
|
||||
# Send 1 observation to make sure the model is loaded.
|
||||
policy.infer(obs_fn())
|
||||
|
||||
start = time.time()
|
||||
for _ in range(100):
|
||||
policy.infer(obs_fn())
|
||||
end = time.time()
|
||||
|
||||
print(f"Total time taken: {end - start}")
|
||||
# Note that each inference call returns a chunk of many actions.
|
||||
print(f"Inference rate: {100 / (end - start)} Hz")
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"qpos": np.ones((14,)),
|
||||
"image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_calvin() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(15),
|
||||
"observation/rgb_static": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"observation/rgb_gripper": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"observation/wrist_image": np.random.rand(4, 3, 480, 640).astype(np.float32),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(main)
|
||||
2
examples/simple_client/requirements.in
Normal file
2
examples/simple_client/requirements.in
Normal file
@@ -0,0 +1,2 @@
|
||||
numpy
|
||||
tyro
|
||||
27
examples/simple_client/requirements.txt
Normal file
27
examples/simple_client/requirements.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
|
||||
backports-cached-property==1.0.2
|
||||
# via tyro
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
eval-type-backport==0.1.3
|
||||
# via tyro
|
||||
markdown-it-py==2.2.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.21.6
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.17.2
|
||||
# via rich
|
||||
rich==13.8.1
|
||||
# via tyro
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# markdown-it-py
|
||||
# rich
|
||||
# tyro
|
||||
tyro==0.9.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
25
packages/openpi-client/pyproject.toml
Normal file
25
packages/openpi-client/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "openpi-client"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.7"
|
||||
dependencies = [
|
||||
"dm-tree>=0.1.8",
|
||||
"msgpack>=1.0.5",
|
||||
"numpy>=1.21.6",
|
||||
"pillow>=9.0.0",
|
||||
"tree>=0.2.4",
|
||||
"websockets>=11.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"pytest>=8.3.4",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py37"
|
||||
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,39 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
|
||||
|
||||
class ActionChunkBroker(_base_policy.BasePolicy):
|
||||
"""Wraps a policy to return action chunks one-at-a-time.
|
||||
|
||||
Assumes that the first dimension of all action fields is the chunk size.
|
||||
|
||||
A new inference call to the inner policy is only made when the current
|
||||
list of chunks is exhausted.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
||||
self._policy = policy
|
||||
|
||||
self._action_horizon = action_horizon
|
||||
self._cur_step: int = 0
|
||||
|
||||
self._last_results: Dict[str, np.ndarray] | None = None
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
if self._last_results is None:
|
||||
self._last_results = self._policy.infer(obs)
|
||||
self._cur_step = 0
|
||||
|
||||
results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
|
||||
self._cur_step += 1
|
||||
|
||||
if self._cur_step >= self._action_horizon:
|
||||
self._last_results = None
|
||||
|
||||
return results
|
||||
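The chunking contract above is easiest to see in use. Below is a minimal sketch: `ScriptedPolicy` is a hypothetical stand-in, and the import path assumes the broker lives in an `action_chunk_broker` module inside `openpi_client` (the file banner for this hunk is missing from the diff):

```python
import numpy as np

from openpi_client import base_policy as _base_policy
from openpi_client import action_chunk_broker  # assumed module name


class ScriptedPolicy(_base_policy.BasePolicy):
    """Hypothetical policy: every infer() returns a chunk of 10 actions."""

    def infer(self, obs: dict) -> dict:
        return {"actions": np.arange(10 * 14, dtype=np.float32).reshape(10, 14)}


broker = action_chunk_broker.ActionChunkBroker(ScriptedPolicy(), action_horizon=10)
first = broker.infer({})   # row 0 of the chunk, shape (14,)
second = broker.infer({})  # row 1; no new call to the inner policy
```

Only after ten calls is the inner policy queried again for a fresh chunk.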
8
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
8
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BasePolicy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def infer(self, obs: Dict) -> Dict:
|
||||
"""Infer actions from observations."""
|
||||
48
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
48
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
||||
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
||||
|
||||
Args:
|
||||
images: A batch of images in [..., height, width, channel] format.
|
||||
height: The target height of the image.
|
||||
width: The target width of the image.
|
||||
method: The interpolation method to use. Default is bilinear.
|
||||
|
||||
Returns:
|
||||
The resized images in [..., height, width, channel].
|
||||
"""
|
||||
# If the images are already the correct size, return them as is.
|
||||
if images.shape[-3:-1] == (height, width):
|
||||
return images
|
||||
|
||||
original_shape = images.shape
|
||||
|
||||
images = images.reshape(-1, *original_shape[-3:])
|
||||
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
||||
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
||||
|
||||
|
||||
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
||||
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
||||
width without distortion by padding with zeros.
|
||||
|
||||
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
||||
"""
|
||||
cur_width, cur_height = image.size
|
||||
if cur_width == width and cur_height == height:
|
||||
return image # No need to resize if the image is already the correct size.
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_image = image.resize((resized_width, resized_height), resample=method)
|
||||
|
||||
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
||||
pad_height = max(0, int((height - resized_height) / 2))
|
||||
pad_width = max(0, int((width - resized_width) / 2))
|
||||
zero_image.paste(resized_image, (pad_width, pad_height))
|
||||
assert zero_image.size == (width, height)
|
||||
return zero_image
|
||||
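For a concrete feel for the math: in the 256x320 -> 60x80 case exercised by the test below, `ratio = max(320/80, 256/60) ≈ 4.27`, so the image is resized to 75x60 and pasted into the 80x60 zero canvas at a horizontal offset of `int((80 - 75) / 2) = 2` pixels, with no vertical padding.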
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi_client.image_tools as image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
|
||||
# Test case 1: Resize image with larger dimensions
|
||||
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
|
||||
height = 20
|
||||
width = 20
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (2, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 2: Resize image with smaller dimensions
|
||||
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
|
||||
height = 15
|
||||
width = 15
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (3, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with the same dimensions
|
||||
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
|
||||
height = 50
|
||||
width = 50
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 4: Resize image with odd-numbered padding
|
||||
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
|
||||
height = 60
|
||||
width = 80
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Adds NumPy array support to msgpack.
|
||||
|
||||
msgpack is good for (de)serializing data over a network for multiple reasons:
|
||||
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
||||
- msgpack is widely used and has good cross-language support
|
||||
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
||||
languages like Python and JavaScript
|
||||
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
||||
than pickle for serializing large arrays using the below strategy
|
||||
|
||||
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
||||
that it falls back to pickle for object arrays.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
|
||||
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
|
||||
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
||||
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
b"__ndarray__": True,
|
||||
b"data": obj.tobytes(),
|
||||
b"dtype": obj.dtype.str,
|
||||
b"shape": obj.shape,
|
||||
}
|
||||
|
||||
if isinstance(obj, np.generic):
|
||||
return {
|
||||
b"__npgeneric__": True,
|
||||
b"data": obj.item(),
|
||||
b"dtype": obj.dtype.str,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def unpack_array(obj):
|
||||
if b"__ndarray__" in obj:
|
||||
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
||||
|
||||
if b"__npgeneric__" in obj:
|
||||
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
||||
packb = functools.partial(msgpack.packb, default=pack_array)
|
||||
|
||||
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
||||
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
||||
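As a quick illustration of the strategy the docstring describes, a round trip over these wrappers looks like this (a sketch; the import path matches the test file that follows):

```python
import numpy as np

from openpi_client import msgpack_numpy

obs = {"qpos": np.ones((14,), dtype=np.float32), "step": 3}
wire = msgpack_numpy.packb(obs)        # bytes, ready for a websocket frame
decoded = msgpack_numpy.unpackb(wire)  # arrays come back with dtype and shape intact
assert np.array_equal(decoded["qpos"], obs["qpos"])
```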
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tree
|
||||
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
def _check(expected, actual):
|
||||
if isinstance(expected, np.ndarray):
|
||||
assert expected.shape == actual.shape
|
||||
assert expected.dtype == actual.dtype
|
||||
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
|
||||
else:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
1, # int
|
||||
1.0, # float
|
||||
"hello", # string
|
||||
np.bool_(True), # boolean scalar
|
||||
np.array([1, 2, 3])[0], # int scalar
|
||||
np.str_("asdf"), # string scalar
|
||||
[1, 2, 3], # list
|
||||
{"key": "value"}, # dict
|
||||
{"key": [1, 2, 3]}, # nested dict
|
||||
np.array(1.0), # 0D array
|
||||
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
|
||||
np.array(["asdf", "qwer"]), # string array
|
||||
np.array([True, False]), # boolean array
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
|
||||
np.array([np.nan, np.inf, -np.inf]), # special float values
|
||||
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
|
||||
[np.array([1, 2]), np.array([3, 4])], # list of arrays
|
||||
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
|
||||
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
|
||||
],
|
||||
)
|
||||
def test_pack_unpack(data):
|
||||
packed = msgpack_numpy.packb(data)
|
||||
unpacked = msgpack_numpy.unpackb(packed)
|
||||
tree.map_structure(_check, data, unpacked)
|
||||
13
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
13
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
@@ -0,0 +1,13 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
|
||||
|
||||
Agents receive observations about the state of the world, and return actions
|
||||
to take in response.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
"""Query the agent for the next action."""
|
||||
@@ -0,0 +1,15 @@
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
# TODO: Consider unifying policies and agents.
|
||||
class PolicyAgent(_agent.Agent):
|
||||
"""An agent that uses a policy to determine actions."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy) -> None:
|
||||
self._policy = policy
|
||||
|
||||
@override
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
return self._policy.infer(observation)
|
||||
@@ -0,0 +1,32 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
"""An Environment represents the robot and the environment it inhabits.
|
||||
|
||||
The primary contract of environments is that they can be queried for observations
|
||||
about their state, and have actions applied to them to change that state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the environment to its initial state.
|
||||
|
||||
This will be called once before starting each episode.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def done(self) -> bool:
|
||||
"""Allow the environment to signal that the task is done.
|
||||
|
||||
This will be called after each step. It should return `True` if the task is
|
||||
done (either successfully or unsuccessfully), and `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict:
|
||||
"""Query the environment for the current state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_action(self, action: dict) -> None:
|
||||
"""Take an action in the environment."""
|
||||
78
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
78
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The core module orchestrating interactions between key components of the system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: _environment.Environment,
|
||||
agent: _agent.Agent,
|
||||
subscribers: list[_subscriber.Subscriber],
|
||||
max_hz: float = 0,
|
||||
) -> None:
|
||||
self._environment = environment
|
||||
self._agent = agent
|
||||
self._subscribers = subscribers
|
||||
self._max_hz = max_hz
|
||||
|
||||
self._running = False
|
||||
|
||||
def run(self) -> None:
|
||||
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
|
||||
self._loop()
|
||||
|
||||
def run_in_new_thread(self) -> threading.Thread:
|
||||
"""Runs the runtime loop in a new thread."""
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the runtime loop."""
|
||||
self._running = False
|
||||
|
||||
def _loop(self) -> None:
|
||||
"""The runtime loop."""
|
||||
logging.info("Starting episode...")
|
||||
self._environment.reset()
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_start()
|
||||
|
||||
self._running = True
|
||||
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
|
||||
last_step_time = time.time()
|
||||
|
||||
while self._running:
|
||||
self._step()
|
||||
|
||||
# Sleep to maintain the desired frame rate
|
||||
now = time.time()
|
||||
dt = now - last_step_time
|
||||
if dt < step_time:
|
||||
time.sleep(step_time - dt)
|
||||
last_step_time = time.time()
|
||||
else:
|
||||
last_step_time = now
|
||||
|
||||
logging.info("Episode completed.")
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_end()
|
||||
|
||||
def _step(self) -> None:
|
||||
"""A single step of the runtime loop."""
|
||||
observation = self._environment.get_observation()
|
||||
action = self._agent.get_action(observation)
|
||||
self._environment.apply_action(action)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_step(observation, action)
|
||||
|
||||
if self._environment.done():
|
||||
self.stop()
|
||||
@@ -0,0 +1,20 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Subscriber(abc.ABC):
|
||||
"""Subscribes to events in the runtime.
|
||||
|
||||
Subscribers can be used to save data, visualize, etc.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_start(self) -> None:
|
||||
"""Called when an episode starts."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
"""Append a step to the episode."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_end(self) -> None:
|
||||
"""Called when an episode ends."""
|
||||
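Taken together, Environment, Agent, and Runtime compose into an episode loop. A minimal sketch under stated assumptions — `CountingEnv` and `NoopAgent` are hypothetical, and the import paths mirror the ones `runtime.py` itself uses:

```python
import numpy as np

from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import runtime as _runtime


class CountingEnv(_environment.Environment):
    """Hypothetical environment that declares itself done after 10 steps."""

    def __init__(self) -> None:
        self._steps = 0

    def reset(self) -> None:
        self._steps = 0

    def done(self) -> bool:
        return self._steps >= 10

    def get_observation(self) -> dict:
        return {"qpos": np.zeros((14,))}

    def apply_action(self, action: dict) -> None:
        self._steps += 1


class NoopAgent(_agent.Agent):
    """Hypothetical agent that always returns an empty action."""

    def get_action(self, observation: dict) -> dict:
        return {}


runtime = _runtime.Runtime(environment=CountingEnv(), agent=NoopAgent(), subscribers=[], max_hz=50)
runtime.run()  # reset, step until done() is True, then notify subscribers
```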
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
from typing_extensions import override
|
||||
import websockets.sync.client
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
class WebsocketClientPolicy(_base_policy.BasePolicy):
|
||||
"""Implements the Policy interface by communicating with a server over websocket.
|
||||
|
||||
See WebsocketPolicyServer for a corresponding server implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
|
||||
self._uri = f"ws://{host}:{port}"
|
||||
self._packer = msgpack_numpy.Packer()
|
||||
self._ws = self._wait_for_server()
|
||||
|
||||
def _wait_for_server(self) -> websockets.sync.client.ClientConnection:
|
||||
logging.info(f"Waiting for server at {self._uri}...")
|
||||
while True:
|
||||
try:
|
||||
return websockets.sync.client.connect(self._uri, compression=None, max_size=None)
|
||||
except ConnectionRefusedError:
|
||||
logging.info("Still waiting for server...")
|
||||
time.sleep(5)
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
data = self._packer.pack(obs)
|
||||
self._ws.send(data)
|
||||
response = self._ws.recv()
|
||||
if isinstance(response, str):
|
||||
# we're expecting bytes; if the server sends a string, it's an error.
|
||||
raise RuntimeError(f"Error in inference server:\n{response}")
|
||||
return msgpack_numpy.unpackb(response)
|
||||
123
pyproject.toml
Normal file
123
pyproject.toml
Normal file
@@ -0,0 +1,123 @@
|
||||
[project]
|
||||
name = "openpi"
|
||||
version = "0.1.0"
|
||||
description = "Physical Intelligence open source repo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"augmax>=0.3.4",
|
||||
"dm-tree>=0.1.8",
|
||||
"einops>=0.8.0",
|
||||
"equinox>=0.11.8",
|
||||
"flatbuffers>=24.3.25",
|
||||
"flax==0.10.2",
|
||||
"fsspec[gcs]>=2024.6.0",
|
||||
"gym-aloha>=0.1.1",
|
||||
"imageio>=2.36.1",
|
||||
"jax[cuda12]==0.4.36",
|
||||
"jaxtyping==0.2.36",
|
||||
"lerobot",
|
||||
"ml_collections==1.0.0",
|
||||
"numpy>=1.26.4",
|
||||
"numpydantic>=1.6.6",
|
||||
"opencv-python>=4.10.0.84",
|
||||
"openpi-client",
|
||||
"orbax-checkpoint==0.10.2",
|
||||
"pillow>=11.0.0",
|
||||
"ruff>=0.7.1",
|
||||
"s3fs>=2024.9.0",
|
||||
"sentencepiece>=0.2.0",
|
||||
"torch>=2.5.1",
|
||||
"tqdm-loggable>=0.2",
|
||||
"typing-extensions>=4.12.2",
|
||||
"tyro>=0.9.4",
|
||||
"wandb>=0.19.1",
|
||||
"boto3>=1.35.7",
|
||||
"types-boto3[boto3,s3]>=1.35.7",
|
||||
"filelock>=3.16.1",
|
||||
"beartype>=0.19.0",
|
||||
]
|
||||
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/Physical-Intelligence/openpi"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.3.4",
|
||||
"ruff>=0.8.3",
|
||||
"pre-commit>=4.0.1",
|
||||
"ipykernel>=6.29.5",
|
||||
"ipywidgets>=8.1.5",
|
||||
"matplotlib>=3.10.0",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv.sources]
|
||||
openpi-client = { workspace = true }
|
||||
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "66f87365988cb5424435ea03b428426b4ede98cb" }
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/*"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py311"
|
||||
extend-exclude = ["docker", "third_party"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
# https://docs.astral.sh/ruff/rules/
|
||||
select = [
|
||||
"B",
|
||||
"C4",
|
||||
"DTZ",
|
||||
"E4",
|
||||
"E7",
|
||||
"E9",
|
||||
"F",
|
||||
"FBT",
|
||||
"FURB",
|
||||
"I",
|
||||
"ICN",
|
||||
"ISC",
|
||||
"LOG",
|
||||
"N",
|
||||
"PD",
|
||||
"PERF",
|
||||
"PIE",
|
||||
"PLC",
|
||||
"PLE",
|
||||
"PLR1",
|
||||
"PLR5",
|
||||
"PLW",
|
||||
"PT",
|
||||
"PTH",
|
||||
"Q",
|
||||
"RET",
|
||||
"RUF",
|
||||
"SIM",
|
||||
"SLF",
|
||||
"T10",
|
||||
"T20",
|
||||
"UP",
|
||||
"W",
|
||||
]
|
||||
ignore = [
|
||||
"F722", # Conflicts with array typing.
|
||||
"T201", # We use print statements.
|
||||
"PD008", # Lots of false positives.
|
||||
]
|
||||
unfixable = [
|
||||
"B905", # Fix defaults to strict=False, which is not what we want.
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
force-single-line = true
|
||||
force-sort-within-sections = true
|
||||
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
|
||||
known-third-party = ["wandb"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
479
scripts/aloha_hd5.py
Normal file
479
scripts/aloha_hd5.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# ruff: noqa
|
||||
"""
|
||||
Script courtesy of Raziel90 https://github.com/huggingface/lerobot/pull/586/files
|
||||
|
||||
Example usage
|
||||
python scripts/aloha_hd5.py --raw-path ~/data/ --dataset-repo-id <hf-username>/<dataset-name> --robot-type <aloha-stationary|aloha-mobile> --fps 50 --video-encoding=false --push=false
|
||||
|
||||
The data will be saved locally to the path given by the LEROBOT_HOME environment variable. By default this is set to ~/.cache/huggingface/lerobot
|
||||
If you wish to submit the dataset to the hub, you can do so by setting up the hf cli https://huggingface.co/docs/huggingface_hub/en/guides/cli and setting --push=true
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import traceback
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import torch
|
||||
|
||||
|
||||
class AlohaHD5Extractor:
|
||||
TAGS = ["aloha", "robotics", "hdf5"]
|
||||
aloha_stationary = "aloha-stationary"
|
||||
aloha_mobile = "aloha-mobile"
|
||||
|
||||
@staticmethod
|
||||
def get_cameras(hdf5_data: h5py.File):
|
||||
"""
|
||||
Extracts the list of RGB camera keys from the given HDF5 data.
|
||||
Parameters
|
||||
----------
|
||||
hdf5_data : h5py.File
|
||||
The HDF5 file object containing the dataset.
|
||||
Returns
|
||||
-------
|
||||
list of str
|
||||
A list of keys corresponding to RGB cameras in the dataset.
|
||||
"""
|
||||
|
||||
rgb_cameras = [key for key in hdf5_data["/observations/images"] if "depth" not in key]
|
||||
return rgb_cameras
|
||||
|
||||
@staticmethod
|
||||
def check_format(episode_list: list[str] | list[Path], image_compressed: bool = True):
|
||||
"""
|
||||
Check the format of the given list of HDF5 files.
|
||||
Parameters
|
||||
----------
|
||||
episode_list : list of str or list of Path
|
||||
List of paths to the HDF5 files to be checked.
|
||||
image_compressed : bool, optional
|
||||
Flag indicating whether the images are compressed (default is True).
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the episode_list is empty.
|
||||
If any HDF5 file is missing required keys '/action' or '/observations/qpos'.
|
||||
If the '/action' or '/observations/qpos' keys do not have 2 dimensions.
|
||||
If the number of frames in '/action' and '/observations/qpos' keys do not match.
|
||||
If the number of frames in '/observations/images/{camera}' does not match the number of frames in '/action' and '/observations/qpos'.
|
||||
If the dimensions of images do not match the expected dimensions based on the image_compressed flag.
|
||||
If uncompressed images do not have the expected (h, w, c) format.
|
||||
"""
|
||||
|
||||
if not episode_list:
|
||||
raise ValueError("No hdf5 files found in the raw directory. Make sure they are named 'episode_*.hdf5'")
|
||||
for episode_path in episode_list:
|
||||
with h5py.File(episode_path, "r") as data:
|
||||
if not all(key in data for key in ["/action", "/observations/qpos"]):
|
||||
raise ValueError(
|
||||
"Missing required keys in the hdf5 file. Make sure the keys '/action' and '/observations/qpos' are present."
|
||||
)
|
||||
|
||||
if not data["/action"].ndim == data["/observations/qpos"].ndim == 2:
|
||||
raise ValueError("The '/action' and '/observations/qpos' keys should have both 2 dimensions.")
|
||||
|
||||
if (num_frames := data["/action"].shape[0]) != data["/observations/qpos"].shape[0]:
|
||||
raise ValueError(
|
||||
"The '/action' and '/observations/qpos' keys should have the same number of frames."
|
||||
)
|
||||
|
||||
for camera in AlohaHD5Extractor.get_cameras(data):
|
||||
if num_frames != data[f"/observations/images/{camera}"].shape[0]:
|
||||
raise ValueError(
|
||||
f"The number of frames in '/observations/images/{camera}' should be the same as in '/action' and '/observations/qpos' keys."
|
||||
)
|
||||
|
||||
expected_dims = 2 if image_compressed else 4
|
||||
if data[f"/observations/images/{camera}"].ndim != expected_dims:
|
||||
raise ValueError(
|
||||
f"Expect {expected_dims} dimensions for {'compressed' if image_compressed else 'uncompressed'} images but {data[f'/observations/images/{camera}'].ndim} provided."
|
||||
)
|
||||
if not image_compressed:
|
||||
b, h, w, c = data[f"/observations/images/{camera}"].shape
|
||||
if not c < h and c < w:
|
||||
raise ValueError(f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided.")
|
||||
|
||||
@staticmethod
|
||||
def extract_episode_frames(
|
||||
episode_path: str | Path, features: dict[str, dict], image_compressed: bool
|
||||
) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Extract frames from an episode stored in an HDF5 file.
|
||||
Parameters
|
||||
----------
|
||||
episode_path : str or Path
|
||||
Path to the HDF5 file containing the episode data.
|
||||
features : dict of str to dict
|
||||
Dictionary where keys are feature identifiers and values are dictionaries with feature details.
|
||||
image_compressed : bool
|
||||
Flag indicating whether the images are stored in a compressed format.
|
||||
Returns
|
||||
-------
|
||||
list of dict of str to torch.Tensor
|
||||
List of frames, where each frame is a dictionary mapping feature identifiers to tensors.
|
||||
"""
|
||||
|
||||
frames = []
|
||||
with h5py.File(episode_path, "r") as file:
|
||||
for frame_idx in range(file["/action"].shape[0]):
|
||||
frame = {}
|
||||
for feature_id in features:
|
||||
feature_name_hd5 = feature_id.replace(".", "/")
|
||||
if "images" in feature_id.split("."):
|
||||
image = (
|
||||
(file[feature_name_hd5][frame_idx])
|
||||
if not image_compressed
|
||||
else cv2.imdecode(file[feature_name_hd5][frame_idx], 1)
|
||||
)
|
||||
frame[feature_id] = torch.from_numpy(image.transpose(2, 0, 1))
|
||||
else:
|
||||
frame[feature_id] = torch.from_numpy(file[feature_name_hd5][frame_idx])
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
@staticmethod
|
||||
def define_features(
|
||||
hdf5_file_path: Path, image_compressed: bool = True, encode_as_video: bool = True
|
||||
) -> dict[str, dict]:
|
||||
"""
|
||||
Define features from an HDF5 file.
|
||||
Parameters
|
||||
----------
|
||||
hdf5_file_path : Path
|
||||
The path to the HDF5 file.
|
||||
image_compressed : bool, optional
|
||||
Whether the images are compressed, by default True.
|
||||
encode_as_video : bool, optional
|
||||
Whether to encode images as video or as images, by default True.
|
||||
Returns
|
||||
-------
|
||||
dict[str, dict]
|
||||
A dictionary where keys are topic names and values are dictionaries
|
||||
containing feature information such as dtype, shape, and names.
|
||||
"""
|
||||
|
||||
# Initialize lists to store topics and features
|
||||
topics = []
|
||||
features = {}
|
||||
|
||||
# Open the HDF5 file
|
||||
with h5py.File(hdf5_file_path, "r") as hdf5_file:
|
||||
# Collect all dataset names in the HDF5 file
|
||||
hdf5_file.visititems(lambda name, obj: topics.append(name) if isinstance(obj, h5py.Dataset) else None)
|
||||
|
||||
# Iterate over each topic to define its features
|
||||
for topic in topics:
|
||||
# If the topic is an image, define it as a video feature
|
||||
if "images" in topic.split("/"):
|
||||
sample = hdf5_file[topic][0]
|
||||
features[topic.replace("/", ".")] = {
|
||||
"dtype": "video" if encode_as_video else "image",
|
||||
"shape": cv2.imdecode(hdf5_file[topic][0], 1).transpose(2, 0, 1).shape
|
||||
if image_compressed
|
||||
else sample.transpose(2, 0, 1).shape,  # match the (c, h, w) layout used when frames are extracted
|
||||
"names": [
|
||||
"channel",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
# Skip compressed length topics
|
||||
elif "compress_len" in topic.split("/"):
|
||||
continue
|
||||
# Otherwise, define it as a regular feature
|
||||
else:
|
||||
features[topic.replace("/", ".")] = {
|
||||
"dtype": str(hdf5_file[topic][0].dtype),
|
||||
"shape": (topic_shape := hdf5_file[topic][0].shape),
|
||||
"names": [f"{topic.split('/')[-1]}_{k}" for k in range(topic_shape[0])],
|
||||
}
|
||||
# Return the defined features
|
||||
return features
|
||||
|
||||
|
||||
class DatasetConverter:
|
||||
"""
|
||||
A class to convert datasets to Lerobot format.
|
||||
Parameters
|
||||
----------
|
||||
raw_path : Path or str
|
||||
The path to the raw dataset.
|
||||
dataset_repo_id : str
|
||||
The repository ID where the dataset will be stored.
|
||||
fps : int
|
||||
Frames per second for the dataset.
|
||||
robot_type : str, optional
|
||||
The type of robot, by default "".
|
||||
encode_as_videos : bool, optional
|
||||
Whether to encode images as videos, by default True.
|
||||
image_compressed : bool, optional
|
||||
Whether the images are compressed, by default True.
|
||||
image_writer_processes : int, optional
|
||||
Number of processes for writing images, by default 0.
|
||||
image_writer_threads : int, optional
|
||||
Number of threads for writing images, by default 0.
|
||||
Methods
|
||||
-------
|
||||
extract_episode(episode_path, task_description='')
|
||||
Extracts frames from a single episode and saves it with a description.
|
||||
extract_episodes(episode_description='')
|
||||
Extracts frames from all episodes and saves them with a description.
|
||||
push_dataset_to_hub(dataset_tags=None, private=False, push_videos=True, license="apache-2.0")
|
||||
Pushes the dataset to the Hugging Face Hub.
|
||||
init_lerobot_dataset()
|
||||
Initializes the Lerobot dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
raw_path: Path | str,
|
||||
dataset_repo_id: str,
|
||||
fps: int,
|
||||
robot_type: str = "",
|
||||
encode_as_videos: bool = True,
|
||||
image_compressed: bool = True,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
):
|
||||
self.raw_path = raw_path if isinstance(raw_path, Path) else Path(raw_path)
|
||||
self.dataset_repo_id = dataset_repo_id
|
||||
self.fps = fps
|
||||
self.robot_type = robot_type
|
||||
self.image_compressed = image_compressed
|
||||
self.image_writer_threads = image_writer_threads
|
||||
self.image_writer_processes = image_writer_processes
|
||||
self.encode_as_videos = encode_as_videos
|
||||
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
# Add console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("%(asctime)s - [%(name)s] - %(message)s")
|
||||
console_handler.setFormatter(formatter)
|
||||
self.logger.addHandler(console_handler)
|
||||
|
||||
self.logger.info(f"{'-'*10} Aloha HD5 -> Lerobot Converter {'-'*10}")
|
||||
self.logger.info(f"Processing Aloha HD5 dataset from {self.raw_path}")
|
||||
self.logger.info(f"Dataset will be stored in {self.dataset_repo_id}")
|
||||
self.logger.info(f"FPS: {self.fps}")
|
||||
self.logger.info(f"Robot type: {self.robot_type}")
|
||||
self.logger.info(f"Image compressed: {self.image_compressed}")
|
||||
self.logger.info(f"Encoding images as videos: {self.encode_as_videos}")
|
||||
self.logger.info(f"#writer processes: {self.image_writer_processes}")
|
||||
self.logger.info(f"#writer threads: {self.image_writer_threads}")
|
||||
|
||||
self.episode_list = list(self.raw_path.glob("episode_*.hdf5"))
|
||||
AlohaHD5Extractor.check_format(self.episode_list, image_compressed=self.image_compressed)
|
||||
self.features = AlohaHD5Extractor.define_features(
|
||||
self.episode_list[0],
|
||||
image_compressed=self.image_compressed,
|
||||
encode_as_video=self.encode_as_videos,
|
||||
)
|
||||
|
||||
def extract_episode(self, episode_path, task_description: str = ""):
|
||||
"""
|
||||
Extracts frames from an episode and saves them to the dataset.
|
||||
Parameters
|
||||
----------
|
||||
episode_path : str
|
||||
The path to the episode file.
|
||||
task_description : str, optional
|
||||
A description of the task associated with the episode (default is an empty string).
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
|
||||
for frame in AlohaHD5Extractor.extract_episode_frames(episode_path, self.features, self.image_compressed):
|
||||
self.dataset.add_frame(frame)
|
||||
self.logger.info(f"Saving Episode with Description: {task_description} ...")
|
||||
self.dataset.save_episode(task=task_description)
|
||||
|
||||
def extract_episodes(self, episode_description: str = ""):
|
||||
"""
|
||||
Extracts episodes from the episode list and processes them.
|
||||
Parameters
|
||||
----------
|
||||
episode_description : str, optional
|
||||
A description of the task to be passed to the extract_episode method (default is '').
|
||||
Raises
|
||||
------
|
||||
Exception
|
||||
If an error occurs during the processing of an episode, it will be caught and printed.
|
||||
Notes
|
||||
-----
|
||||
After processing all episodes, the dataset is consolidated.
|
||||
"""
|
||||
|
||||
for episode_path in self.episode_list:
|
||||
try:
|
||||
self.extract_episode(episode_path, task_description=episode_description)
|
||||
except Exception as e:
|
||||
print(f"Error processing episode {episode_path}", f"{e}")
|
||||
traceback.print_exc()
|
||||
continue
|
||||
self.dataset.consolidate()
|
||||
|
||||
def push_dataset_to_hub(
|
||||
self,
|
||||
dataset_tags: list[str] | None = None,
|
||||
private: bool = False,
|
||||
push_videos: bool = True,
|
||||
license: str | None = "apache-2.0",
|
||||
):
|
||||
"""
|
||||
Pushes the dataset to the Hugging Face Hub.
|
||||
Parameters
|
||||
----------
|
||||
dataset_tags : list of str, optional
|
||||
A list of tags to associate with the dataset on the Hub. Default is None.
|
||||
private : bool, optional
|
||||
If True, the dataset will be private. Default is False.
|
||||
push_videos : bool, optional
|
||||
If True, videos will be pushed along with the dataset. Default is True.
|
||||
license : str, optional
|
||||
The license under which the dataset is released. Default is "apache-2.0".
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
"""
|
||||
|
||||
self.logger.info(f"Pushing dataset to Hugging Face Hub. ID: {self.dataset_repo_id} ...")
|
||||
self.dataset.push_to_hub(
|
||||
tags=dataset_tags,
|
||||
license=license,
|
||||
push_videos=push_videos,
|
||||
private=private,
|
||||
)
|
||||
|
||||
def init_lerobot_dataset(self):
|
||||
"""
|
||||
Initializes the LeRobot dataset.
|
||||
This method cleans the cache if the dataset already exists and then creates a new LeRobot dataset.
|
||||
Returns
|
||||
-------
|
||||
LeRobotDataset
|
||||
The initialized LeRobot dataset.
|
||||
"""
|
||||
|
||||
# Clean the cache if the dataset already exists
|
||||
if os.path.exists(LEROBOT_HOME / self.dataset_repo_id):
|
||||
shutil.rmtree(LEROBOT_HOME / self.dataset_repo_id)
|
||||
self.dataset = LeRobotDataset.create(
|
||||
repo_id=self.dataset_repo_id,
|
||||
fps=self.fps,
|
||||
robot_type=self.robot_type,
|
||||
features=self.features,
|
||||
image_writer_threads=self.image_writer_threads,
|
||||
image_writer_processes=self.image_writer_processes,
|
||||
)
|
||||
|
||||
return self.dataset
|
||||
|
||||
|
||||
def str2bool(value):
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
value = value.lower()
|
||||
if value in ("yes", "true", "t", "y", "1"):
|
||||
return True
|
||||
if value in ("no", "false", "f", "n", "0"):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError("Boolean value expected.")
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Convert Aloha HD5 dataset and push to Hugging Face hub.
|
||||
This script processes raw HDF5 files from the Aloha dataset, converts them into a specified format,
|
||||
and optionally uploads the dataset to the Hugging Face hub.
|
||||
Parameters
|
||||
----------
|
||||
--raw-path : Path
|
||||
Directory containing the raw HDF5 files.
|
||||
--dataset-repo-id : str
|
||||
Repository ID where the dataset will be stored.
|
||||
--fps : int
|
||||
Frames per second for the dataset.
|
||||
--robot-type : str, optional
|
||||
Type of robot, either "aloha-stationary" or "aloha-mobile". Default is "aloha-stationary".
|
||||
--private : bool, optional
|
||||
Set to True to make the dataset private. Default is False.
|
||||
--push : bool, optional
|
||||
Set to True to push the dataset to the hub. Default is True.
|
||||
--license : str, optional
|
||||
License for the dataset. Default is "apache-2.0".
|
||||
--image-compressed : bool, optional
|
||||
Set to True if the images are compressed. Default is True.
|
||||
--video-encoding : bool, optional
|
||||
Set to True to encode images as videos. Default is True.
|
||||
--nproc : int, optional
|
||||
Number of image writer processes. Default is 10.
|
||||
--nthreads : int, optional
|
||||
Number of image writer threads. Default is 5.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="Convert Aloha HD5 dataset and push to Hugging Face hub.")
|
||||
parser.add_argument("--raw-path", type=Path, required=True, help="Directory containing the raw hdf5 files.")
|
||||
parser.add_argument(
|
||||
"--dataset-repo-id", type=str, required=True, help="Repository ID where the dataset will be stored."
|
||||
)
|
||||
parser.add_argument("--fps", type=int, required=True, help="Frames per second for the dataset.")
|
||||
parser.add_argument(
|
||||
"--description", type=str, help="Description of the dataset.", default="Aloha recorded dataset."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--robot-type",
|
||||
type=str,
|
||||
choices=["aloha-stationary", "aloha-mobile"],
|
||||
default="aloha-stationary",
|
||||
help="Type of robot.",
|
||||
)
|
||||
parser.add_argument("--private", type=str2bool, default=False, help="Set to True to make the dataset private.")
|
||||
parser.add_argument("--push", type=str2bool, default=True, help="Set to True to push videos to the hub.")
|
||||
parser.add_argument("--license", type=str, default="apache-2.0", help="License for the dataset.")
|
||||
parser.add_argument(
|
||||
"--image-compressed", type=str2bool, default=True, help="Set to True if the images are compressed."
|
||||
)
|
||||
parser.add_argument("--video-encoding", type=str2bool, default=True, help="Set to True to encode images as videos.")
|
||||
|
||||
parser.add_argument("--nproc", type=int, default=10, help="Number of image writer processes.")
|
||||
parser.add_argument("--nthreads", type=int, default=5, help="Number of image writer threads.")
|
||||
|
||||
args = parser.parse_args()
|
||||
print(f"Video encoding: {args.video_encoding}")
|
||||
|
||||
converter = DatasetConverter(
|
||||
raw_path=args.raw_path,
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
fps=args.fps,
|
||||
robot_type=args.robot_type,
|
||||
image_compressed=args.image_compressed,
|
||||
encode_as_videos=args.video_encoding,
|
||||
image_writer_processes=args.nproc,
|
||||
image_writer_threads=args.nthreads,
|
||||
)
|
||||
converter.init_lerobot_dataset()
|
||||
converter.extract_episodes(episode_description=args.description)
|
||||
|
||||
if args.push:
|
||||
converter.push_dataset_to_hub(
|
||||
dataset_tags=AlohaHD5Extractor.TAGS, private=args.private, push_videos=True, license=args.license
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
29
scripts/compose.yml
Normal file
29
scripts/compose.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Run with:
|
||||
# docker compose -f scripts/compose.yml up --build
|
||||
services:
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: scripts/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
# Populate configured openpi data home to /openpi_assets inside the container.
|
||||
# Populate aws credential inside the container.
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
- ~/.aws/:/root/.aws/
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
67
scripts/compute_norm_stats.py
Normal file
67
scripts/compute_norm_stats.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Compute normalization statistics for a config.
|
||||
|
||||
This script is used to compute the normalization statistics for a given config. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config metadata directory.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
|
||||
|
||||
def create_dataset(config: _config.TrainConfig) -> tuple[str, _data_loader.Dataset]:
|
||||
model = config.create_model()
|
||||
data_config = config.data.create(config.metadata_dir, model)
|
||||
if data_config.repo_id is None:
|
||||
raise ValueError("Data config must have a repo_id")
|
||||
dataset = _data_loader.TransformedDataset(
|
||||
_data_loader.create_dataset(data_config, model),
|
||||
[
|
||||
*data_config.repack_transforms.inputs,
|
||||
*data_config.data_transforms.inputs,
|
||||
],
|
||||
)
|
||||
return data_config.repo_id, dataset
|
||||
|
||||
|
||||
def main(config_name: str, max_frames: int | None = None):
|
||||
config = _config.get_config(config_name)
|
||||
repo_id, dataset = create_dataset(config)
|
||||
|
||||
num_frames = len(dataset)
|
||||
shuffle = False
|
||||
|
||||
if max_frames is not None and max_frames < num_frames:
|
||||
num_frames = max_frames
|
||||
shuffle = True
|
||||
|
||||
data_loader = _data_loader.TorchDataLoader(
|
||||
dataset,
|
||||
local_batch_size=1,
|
||||
num_workers=8,
|
||||
shuffle=shuffle,
|
||||
num_batches=num_frames,
|
||||
)
|
||||
|
||||
keys = ["state", "actions"]
|
||||
stats = {key: normalize.RunningStats() for key in keys}
|
||||
|
||||
for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
|
||||
for key in keys:
|
||||
values = np.asarray(batch[key][0])
|
||||
stats[key].update(values.reshape(-1, values.shape[-1]))
|
||||
|
||||
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
||||
|
||||
output_path = config.metadata_dir / repo_id
|
||||
print(f"Writing stats to: {output_path}")
|
||||
normalize.save(output_path, norm_stats)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
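`normalize.RunningStats` is not part of this excerpt; from the call sites above it must accept `(N, dim)` batches via `update()` and report per-dimension statistics via `get_statistics()`. A hypothetical minimal version, for orientation only:

```python
import numpy as np


class RunningStats:
    """Hypothetical sketch: accumulate per-dimension mean/std over (N, dim) batches."""

    def __init__(self) -> None:
        self._count = 0
        self._sum: np.ndarray | None = None
        self._sum_sq: np.ndarray | None = None

    def update(self, batch: np.ndarray) -> None:
        if self._sum is None:
            self._sum = np.zeros(batch.shape[-1])
            self._sum_sq = np.zeros(batch.shape[-1])
        self._count += batch.shape[0]
        self._sum += batch.sum(axis=0)
        self._sum_sq += (batch * batch).sum(axis=0)

    def get_statistics(self) -> dict:
        mean = self._sum / self._count
        # Clamp to avoid tiny negative variances from floating-point error.
        var = np.maximum(self._sum_sq / self._count - mean * mean, 0.0)
        return {"mean": mean, "std": np.sqrt(var)}
```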
37
scripts/install_docker_ubuntu22.sh
Executable file
37
scripts/install_docker_ubuntu22.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Add Docker's official GPG key:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y ca-certificates curl
|
||||
sudo install -m 0755 -d /etc/apt/keyrings
|
||||
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
|
||||
sudo chmod a+r /etc/apt/keyrings/docker.asc
|
||||
|
||||
# Add the repository to Apt sources:
|
||||
echo \
|
||||
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
|
||||
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
|
||||
sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
|
||||
sudo apt-get update
|
||||
|
||||
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
|
||||
|
||||
# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
|
||||
# See https://docs.docker.com/engine/install/linux-postinstall/
|
||||
username=$(whoami)
|
||||
sudo usermod -aG docker $username
|
||||
|
||||
# Configure docker to start automatically on system boot.
|
||||
sudo systemctl enable docker.service
|
||||
sudo systemctl enable containerd.service
|
||||
|
||||
# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
|
||||
if [ -f ~/.docker/config.json ]; then
|
||||
sed -i 's/credsStore/credStore/g' ~/.docker/config.json
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "********************************************************************"
|
||||
echo "**** Restart to allow Docker permission changes to take effect. ****"
|
||||
echo "********************************************************************"
|
||||
echo ""
|
||||
17
scripts/install_nvidia_container_toolkit.sh
Executable file
17
scripts/install_nvidia_container_toolkit.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
|
||||
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
|
||||
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
|
||||
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
|
||||
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
|
||||
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y nvidia-container-toolkit
|
||||
|
||||
sudo nvidia-ctk runtime configure --runtime=docker
|
||||
sudo systemctl restart docker
|
||||
34
scripts/serve_policy.Dockerfile
Normal file
34
scripts/serve_policy.Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
# Dockerfile for serving a PI policy.
|
||||
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t openpi_server -f scripts/serve_policy.Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Needed because LeRobot uses git-lfs.
|
||||
RUN apt-get update && apt-get install -y git git-lfs
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Install the project's dependencies using the lockfile and settings
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
|
||||
GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
|
||||
|
||||
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
|
||||
243
scripts/serve_policy.py
Normal file
243
scripts/serve_policy.py
Normal file
@@ -0,0 +1,243 @@
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import tyro
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import exported as _exported
|
||||
from openpi.models import model as _model
|
||||
from openpi.policies import aloha_policy
|
||||
from openpi.policies import calvin_policy
|
||||
from openpi.policies import droid_policy
|
||||
from openpi.policies import libero_policy
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.serving import websocket_policy_server
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
CALVIN = "calvin"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Exported:
|
||||
"""Load an exported checkpoint."""
|
||||
|
||||
# Checkpoint directory (e.g., "s3://openpi-assets/exported/pi0_aloha/model").
|
||||
dir: str
|
||||
# Processor name to load the norm stats from. If not provided, the default processor for the environment will be used.
|
||||
processor: str | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Checkpoint:
|
||||
"""Load a policy from a trained checkpoint."""
|
||||
|
||||
# Training config name (e.g., "pi0_aloha_sim").
|
||||
config: str
|
||||
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
|
||||
dir: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
"""Arguments for the serve_policy script."""
|
||||
|
||||
# Environment to serve the policy for.
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
|
||||
policy: Checkpoint | Exported | None = None
|
||||
|
||||
# If provided, overrides the default prompt for the policy.
|
||||
default_prompt: str | None = None
|
||||
|
||||
# Port to serve the policy on.
|
||||
port: int = 8000
|
||||
# Record the policy's behavior for debugging.
|
||||
record: bool = False
|
||||


def repack_from_env(env: EnvMode) -> transforms.Group:
    """Creates environment-specific repack transforms."""
    # TODO(ury): Move this to the runtime.
    match env:
        case EnvMode.ALOHA:
            return transforms.Group(
                inputs=[aloha_policy.ActInputsRepack()],
                outputs=[aloha_policy.ActOutputsRepack()],
            )
        case EnvMode.ALOHA_SIM:
            return transforms.Group(
                inputs=[aloha_policy.ActInputsRepack()],
                outputs=[aloha_policy.ActOutputsRepack()],
            )
        case _:
            return transforms.Group()


# Default exported models.
DEFAULT_EXPORTED: dict[EnvMode, Exported] = {
    EnvMode.ALOHA: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha/model",
        processor="trossen_biarm_single_base_cam_24dim",
    ),
    EnvMode.ALOHA_SIM: Exported(
        dir="s3://openpi-assets/exported/pi0_aloha_sim/model",
        processor="huggingface_aloha_sim_transfer_cube",
    ),
    EnvMode.DROID: Exported(
        dir="s3://openpi-assets/exported/pi0_droid/model",
        processor="openx_droid",
    ),
    EnvMode.CALVIN: Exported(
        dir="s3://openpi-assets/exported/pi0_calvin/model",
        processor="calvin",
    ),
    EnvMode.LIBERO: Exported(
        dir="s3://openpi-assets/exported/pi0_libero/model",
        processor="libero",
    ),
}


def create_default_policy(
    env: EnvMode, *, default_prompt: str | None = None, exported: Exported | None = None
) -> _policy.Policy:
    model: _model.BaseModel
    config: _policy_config.PolicyConfig

    default_exported = DEFAULT_EXPORTED[env]
    if exported:
        checkpoint_dir = exported.dir
        processor = exported.processor or default_exported.processor
    else:
        checkpoint_dir = default_exported.dir
        processor = default_exported.processor
        assert processor, "Default processor must always be set"

    logging.info("Loading model...")
    model = _exported.PiModel.from_checkpoint(checkpoint_dir)

    def make_policy_config(
        input_layers: Sequence[transforms.DataTransformFn],
        output_layers: Sequence[transforms.DataTransformFn],
        sample_kwargs: dict[str, Any] | None = None,
    ):
        sample_kwargs = sample_kwargs or {"num_steps": 10}
        return _policy_config.PolicyConfig(
            model=model,
            norm_stats=model.norm_stats(processor),
            default_prompt=default_prompt,
            input_layers=input_layers,
            output_layers=output_layers,
            sample_kwargs=sample_kwargs,
        )

    logging.info("Creating policy...")
    match env:
        case EnvMode.ALOHA:
            delta_action_mask = _policy_config.make_bool_mask(6, -1, 6, -1)
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(
                        action_dim=model.action_dim,
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(
                        delta_action_mask=delta_action_mask,
                        adapt_to_pi=True,
                    ),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.ALOHA_SIM:
            config = make_policy_config(
                input_layers=[
                    aloha_policy.ActInputsRepack(),
                    aloha_policy.AlohaInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    aloha_policy.AlohaOutputs(),
                    aloha_policy.ActOutputsRepack(),
                ],
            )
        case EnvMode.DROID:
            config = make_policy_config(
                input_layers=[
                    droid_policy.DroidInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    droid_policy.DroidOutputs(),
                    transforms.SubsampleActions(stride=5),
                ],
            )
        case EnvMode.CALVIN:
            config = make_policy_config(
                input_layers=[
                    calvin_policy.CalvinInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    calvin_policy.CalvinOutputs(),
                ],
            )
        case EnvMode.LIBERO:
            config = make_policy_config(
                input_layers=[
                    libero_policy.LiberoInputs(action_dim=model.action_dim),
                ],
                output_layers=[
                    libero_policy.LiberoOutputs(),
                ],
            )
        case _:
            raise ValueError(f"Unknown environment mode: {env}")

    return _policy_config.create_policy(config)


def create_policy(args: Args) -> _policy.Policy:
    match args.policy:
        case Checkpoint():
            return _policy_config.create_trained_policy(
                _config.get_config(args.policy.config),
                args.policy.dir,
                repack_transforms=repack_from_env(args.env),
                default_prompt=args.default_prompt,
            )
        case Exported():
            return create_default_policy(args.env, default_prompt=args.default_prompt, exported=args.policy)
        case None:
            return create_default_policy(args.env, default_prompt=args.default_prompt)


def main(args: Args) -> None:
    policy = create_policy(args)

    # Record the policy's behavior.
    if args.record:
        policy = _policy.PolicyRecorder(policy, "policy_records")

    logging.info("Creating server...")
    server = websocket_policy_server.WebsocketPolicyServer(policy=policy, host="0.0.0.0", port=args.port)

    logging.info("Serving...")
    server.serve_forever()


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, force=True)
    main(tyro.cli(Args))
284
scripts/train.py
Normal file
@@ -0,0 +1,284 @@
import dataclasses
from functools import partial
import logging
import platform
from typing import Any

import etils.epath as epath
from flax.training import common_utils
import jax
import jax._src.tree_util as private_tree_util
import jax.experimental
import jax.numpy as jnp
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.common as _common
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers[0].setFormatter(formatter)


def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        wandb.init(
            name=config.exp_name,
            config=dataclasses.asdict(config),
            project=config.project_name,
        )
        (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code:
        wandb.run.log_code(epath.Path(__file__).parent.parent)


def _load_weights_and_validate(weight_loader: _weight_loaders.WeightLoader, params: at.Params) -> at.Params:
    """Runs the weight loader and validates that the params structure, shapes, and dtypes are unchanged."""
    new_params = weight_loader.load(jax.tree.map(lambda x: x, params))

    if errors := list(private_tree_util.equality_errors(params, new_params)):
        raise ValueError(
            "Weight loading changed the params structure:\n"
            + (
                "\n".join(
                    f" - {jax.tree_util.keystr(path)} changed from {thing1} to {thing2}, so {explanation}.\n"
                    for path, thing1, thing2, explanation in errors
                )
            )
        )

    def check(kp, x, y):
        if (x := jax.ShapeDtypeStruct(x.shape, x.dtype)) != (y := jax.ShapeDtypeStruct(y.shape, y.dtype)):
            raise ValueError(
                f"Weight loading changed the params structure: expected {y}, got {x} at {jax.tree_util.keystr(kp)}"
            )

    jax.tree_util.tree_map_with_path(check, params, new_params)

    return new_params
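

# Added example (hypothetical, for illustration only): a minimal loader that
# satisfies the contract checked above -- `load` must return a tree with the
# same structure, shapes, and dtypes as its input.
class _IdentityWeightLoader:
    def load(self, params: at.Params) -> at.Params:
        return params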


@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    model: _model.Model,
    init_rng: at.KeyArrayLike,
    batch: tuple[_common.Observation, _common.Actions],
    mesh: jax.sharding.Mesh,
    data_sharding: jax.sharding.Sharding,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
    weight_decay_mask = None
    freeze_mask = None
    tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask, freeze_mask)

    def init(
        rng: at.KeyArrayLike,
        data: tuple[_common.Observation, _common.Actions],
        params_sharding: jax.sharding.Sharding | None = None,
    ) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        observation, actions = data
        params = model.init_params(model_rng, observation, actions)
        # jax.experimental.io_callback raises SPMD partitioning warnings; we set constraints
        # to replicate the params to avoid the warnings. The returned train state will still be
        # sharded, since FSDP sharding is specified as the output sharding when jitting this function.
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        params = jax.experimental.io_callback(
            partial(_load_weights_and_validate, config.weight_loader),
            params,
            params,
            ordered=True,
        )
        if params_sharding is not None:
            params = jax.lax.with_sharding_constraint(params, params_sharding)
        return training_utils.TrainState(
            step=0,
            params=params,
            opt_state=tx.init(params),
            tx=tx,
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    train_state_shape = jax.eval_shape(init, init_rng, batch)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    train_state = jax.jit(
        init,
        in_shardings=(replicated_sharding, data_sharding),
        out_shardings=state_sharding,
        static_argnums=(2,),
    )(init_rng, batch, replicated_sharding)
    return train_state, state_sharding


@at.typecheck
def train_step(
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    model: _model.Model,
    batch: tuple[_common.Observation, _common.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    def loss_fn(params: at.Params, rng: at.KeyArrayLike, observation: _common.Observation, actions: _common.Actions):
        chunked_loss = model.compute_loss(rng, observation, actions, params=params, train=True)
        return jnp.mean(chunked_loss)

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch
    loss, grads = jax.value_and_grad(loss_fn)(state.params, train_rng, observation, actions)
    updates, new_opt_state = state.tx.update(grads, state.opt_state, state.params)
    new_params = optax.apply_updates(state.params, updates)

    new_state = state.replace(step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None:
        new_state = new_state.replace(
            ema_params=jax.tree.map(
                lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
            )
        )

    kernel_mask = training_utils.mask_from_regex(r".*\['kernel'\]", state.params)
    kernel_params = jax.tree.map(lambda p, m: p if m else None, state.params, kernel_mask)
    info = {
        "loss": loss,
        "grad_norm": optax.global_norm(grads),  # TODO: do not compute the norm for frozen params.
        "param_norm": optax.global_norm(kernel_params),
    }
    return new_state, info


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_threefry_partitionable", True)  # noqa: FBT003
    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    if jax.device_count() % config.fsdp_devices != 0:
        raise ValueError(
            f"Number of devices {jax.device_count()} must be divisible by the number of FSDP devices {config.fsdp_devices}."
        )
    mesh_shape = (jax.device_count() // config.fsdp_devices, config.fsdp_devices)
    mesh = jax.make_mesh(mesh_shape, ("batch", "model"))
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(("batch", "model")))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        config.checkpoint_dir,
        keep_interval=config.keep_interval,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)

    model = config.create_model()

    data_loader = _data_loader.create_data_loader(
        config,
        model,
        sharding=data_sharding,
        num_workers=config.num_workers,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    train_state, train_state_sharding = init_train_state(
        config, model, init_rng, batch, mesh, data_sharding, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        train_step,
        in_shardings=(replicated_sharding, train_state_sharding, None, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(train_state.step)
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    for step in pbar:
        train_state, info = ptrain_step(train_rng, train_state, model, batch)
        infos.append(info)
        if step % config.log_interval == 0:
            stacked_infos = common_utils.stack_forest(infos)
            reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
            info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
            pbar.write(f"Step {step}: {info_str}")
            wandb.log(reduced_info, step=step)
            infos = []
        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())
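
# Added sketch (illustrative; assumes 8 available devices and fsdp_devices=2):
# the mesh construction in main() yields a (batch=4, model=2) mesh, and each
# data batch is split across all eight devices:
#
#   mesh = jax.make_mesh((4, 2), ("batch", "model"))
#   data_sharding = jax.sharding.NamedSharding(
#       mesh, jax.sharding.PartitionSpec(("batch", "model"))
#   )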
27
scripts/train_test.py
Normal file
@@ -0,0 +1,27 @@
import dataclasses
import pathlib

import pytest

from openpi.training import config as _config

from . import train


@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
    config = dataclasses.replace(
        _config._CONFIGS_DICT[config_name],  # noqa: SLF001
        batch_size=2,
        checkpoint_base_dir=tmp_path / "checkpoint",
        exp_name="test",
        overwrite=False,
        resume=False,
        num_train_steps=2,
        log_interval=1,
    )
    train.main(config)

    # Test resuming.
    config = dataclasses.replace(config, resume=True, num_train_steps=4)
    train.main(config)
0
src/openpi/__init__.py
Normal file
0
src/openpi/models/__init__.py
Normal file
77
src/openpi/models/common.py
Normal file
@@ -0,0 +1,77 @@
import abc
import dataclasses
from typing import TypeAlias

from flax import struct
import flax.linen as nn
import numpy as np

from openpi.shared import array_typing as at


@at.typecheck
@struct.dataclass
class Observation:
    """Holds observations, i.e., inputs to the model."""

    # Images, in [-1, 1] float32.
    images: dict[str, at.Float[at.Array, "*b h w c"]]
    # Image masks, with the same keys as images.
    image_masks: dict[str, at.Bool[at.Array, "*b"]]
    # Low-dimensional robot state.
    state: at.Float[at.Array, "*b s"]
    # Tokenized prompt.
    tokenized_prompt: at.Int[at.Array, "*b l"] | None = None
    # Tokenized prompt mask.
    tokenized_prompt_mask: at.Int[at.Array, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[at.ArrayLike]) -> "Observation":
        """Defines the mapping from unstructured data (i.e., a nested dict) to the structured Observation format."""
        # Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
        # If images are uint8, convert them to [-1, 1] float32.
        for key in data["image"]:
            if data["image"][key].dtype == np.uint8:
                data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
        return cls(
            images=data["image"],
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
        )

    def to_dict(self) -> at.PyTree[at.ArrayLike]:
        """Convert the Observation to a nested dict."""
        result = dataclasses.asdict(self)
        # TODO(ury): This is awkward. Adjust the names to be the same.
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result


Actions: TypeAlias = at.Float[at.ArrayLike, "*b ah ad"]


class BaseModule(nn.Module, abc.ABC):
    @at.typecheck
    @abc.abstractmethod
    def compute_loss(
        self,
        obs: Observation,
        target_actions: Actions,
        *,
        timestep: at.Float[at.Array, " b"] | None = None,
    ) -> at.Float[at.Array, "b ah"]: ...

    @at.typecheck
    @abc.abstractmethod
    def sample_actions(
        self,
        action_horizon: int,
        action_dim: int,
        obs: Observation,
        **sample_kwargs,
    ) -> Actions: ...
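
# Added example (illustrative; the camera name and state size are hypothetical):
# from_dict rescales uint8 images so that 0 -> -1.0 and 255 -> 1.0.
#
#   import numpy as np
#   raw = {
#       "image": {"cam_high": np.zeros((224, 224, 3), dtype=np.uint8)},
#       "image_mask": {"cam_high": np.asarray(True)},
#       "state": np.zeros(14, dtype=np.float32),
#   }
#   obs = Observation.from_dict(raw)  # obs.images["cam_high"] is all -1.0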
292
src/openpi/models/exported.py
Normal file
@@ -0,0 +1,292 @@
"""Functionality to handle internal pi checkpoints.
|
||||
|
||||
Used to test internal pi checkpoints and provides utilities to convert them to openpi checkpoints.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
from typing import Any
|
||||
|
||||
import flax.serialization
|
||||
import flax.struct as struct
|
||||
import jax
|
||||
import jax.export
|
||||
import jax.numpy as jnp
|
||||
import orbax.checkpoint as ocp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import common
|
||||
from openpi.models import model as _model
|
||||
from openpi.shared import image_tools
|
||||
from openpi.shared import normalize as _normalize
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
def convert_to_openpi(
|
||||
ckpt_dir: pathlib.Path | str, processor: str, out_dir: pathlib.Path | str, param_path: str = "decoder"
|
||||
) -> None:
|
||||
"""Convert a monopi checkpoint to an openpi checkpoint."""
|
||||
out_dir = pathlib.Path(out_dir)
|
||||
if out_dir.exists():
|
||||
raise FileExistsError(f"Output directory already exists: {out_dir}")
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load params and norm stats.
|
||||
ckpt_dir = download.maybe_download(str(ckpt_dir))
|
||||
sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])
|
||||
params = _load_params(ckpt_dir, sharding=sharding)
|
||||
norm_stats = _import_norm_stats(ckpt_dir, processor)
|
||||
|
||||
for part in param_path.split("/"):
|
||||
if part not in params:
|
||||
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
|
||||
params = params[part]
|
||||
|
||||
# Load the monopi model.
|
||||
# Save params.
|
||||
ckpt = ocp.StandardCheckpointer()
|
||||
ckpt.save(out_dir / "params", {"params": params})
|
||||
ckpt.wait_until_finished()
|
||||
|
||||
# Save norm stats.
|
||||
_normalize.save(out_dir / "assets", norm_stats)
|
||||
|
||||
|
||||
@struct.dataclass
|
||||
class PiModel(_model.BaseModel):
|
||||
"""A model loaded from a monopi checkpoint model directory."""
|
||||
|
||||
params: at.Params
|
||||
|
||||
exported: jax.export.Exported = struct.field(pytree_node=False)
|
||||
example_spec: Any = struct.field(pytree_node=False)
|
||||
sample_spec: Any = struct.field(pytree_node=False)
|
||||
ckpt_dir: pathlib.Path = struct.field(pytree_node=False)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, ckpt_dir: pathlib.Path | str) -> "PiModel":
|
||||
"""Load a model from a monopi model checkpoint directory. Must point at the "model" sub-directory."""
|
||||
ckpt_dir = download.maybe_download(str(ckpt_dir))
|
||||
with (ckpt_dir / "graph").open("rb") as f:
|
||||
exported = jax.export.deserialize(f.read())
|
||||
|
||||
input_spec = jax.tree.unflatten(exported.in_tree, exported.in_avals)[0]
|
||||
params = _load_params(ckpt_dir, input_spec[0])
|
||||
example_spec = input_spec[2]
|
||||
sample_spec = input_spec[3]
|
||||
|
||||
# Extract the action properties from the output spec.
|
||||
output_spec = jax.tree.unflatten(exported.out_tree, exported.out_avals)
|
||||
actions_spec = output_spec["actions"]
|
||||
action_horizon, action_dim = actions_spec.shape
|
||||
|
||||
max_token_len = example_spec["prompt_tokens"].shape[-1]
|
||||
|
||||
return cls(
|
||||
params=params,
|
||||
exported=exported,
|
||||
example_spec=example_spec,
|
||||
sample_spec=sample_spec,
|
||||
ckpt_dir=ckpt_dir,
|
||||
action_horizon=action_horizon,
|
||||
action_dim=action_dim,
|
||||
max_token_len=max_token_len,
|
||||
)
|
||||
|
||||
@jax.jit
|
||||
@override
|
||||
def sample_actions(self, rng: at.KeyArrayLike, observation: common.Observation, **sample_kwargs) -> common.Actions:
|
||||
if observation.state.ndim == 2 and observation.state.shape[0] != 1:
|
||||
raise ValueError("Only batch_size=1 is supported.")
|
||||
|
||||
# Convert to the example format.
|
||||
example = _obs_to_example(observation, self.example_spec)
|
||||
example = _unbatch(example)
|
||||
|
||||
# Resize the input images if needed.
|
||||
def resize_if_needed(key, image):
|
||||
target_shape = self.example_spec["image"][key].shape
|
||||
if len(target_shape) == 3 and image.shape != target_shape:
|
||||
return image_tools.resize_with_pad(image, *target_shape[-3:-1])
|
||||
return image
|
||||
|
||||
example["image"] = {key: resize_if_needed(key, value) for key, value in example["image"].items()}
|
||||
|
||||
if set(sample_kwargs) != set(self.sample_spec):
|
||||
raise ValueError(
|
||||
f"Sample args {list(sample_kwargs)} do not match the expected args {list(self.sample_spec)}"
|
||||
)
|
||||
|
||||
rng_data = jax.random.key_data(rng)
|
||||
result = self.exported.call(self.params, rng_data, example, sample_kwargs)
|
||||
|
||||
return _make_batch(result)["actions"]
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: common.Observation,
|
||||
actions: common.Actions,
|
||||
*,
|
||||
train: bool = False,
|
||||
params: at.Params | None = None,
|
||||
) -> at.Float[at.Array, "*b ah"]:
|
||||
raise NotImplementedError("Not implemented.")
|
||||
|
||||
def fake_obs(self) -> common.Observation:
|
||||
example = jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), self.example_spec)
|
||||
return _example_to_obs(_make_batch(example))
|
||||
|
||||
def norm_stats(self, processor_name: str) -> dict[str, _normalize.NormStats]:
|
||||
return _import_norm_stats(self.ckpt_dir, processor_name)
|
||||
|
||||
def set_module(self, module: common.BaseModule, param_path: str) -> _model.Model:
|
||||
"""Creates a new model that uses the same parameters but a different module.
|
||||
|
||||
Args:
|
||||
module: The module to use for the model.
|
||||
param_path: Location of the parameter sub-tree that should be loaded (e.g., decoder).
|
||||
Can include "/" to support nesting.
|
||||
|
||||
Returns:
|
||||
A new model with the parameters loaded from the checkpoint.
|
||||
"""
|
||||
params = self.params
|
||||
for part in param_path.split("/"):
|
||||
if part not in params:
|
||||
raise ValueError(f"{part} not found in the checkpoint. Available keys: {list(params)}")
|
||||
params = params[part]
|
||||
return _model.Model(
|
||||
module=module,
|
||||
params=params,
|
||||
action_dim=self.action_dim,
|
||||
action_horizon=self.action_horizon,
|
||||
max_token_len=self.max_token_len,
|
||||
)
|
||||
|
||||
|
||||
def _load_params(
|
||||
path: pathlib.Path, params_spec: at.PyTree | None = None, sharding: jax.sharding.Sharding | None = None
|
||||
):
|
||||
if sharding is None:
|
||||
sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])
|
||||
|
||||
def to_restore_args(tree):
|
||||
return jax.tree.map(lambda x: ocp.ArrayRestoreArgs(dtype=x.dtype, sharding=sharding), tree)
|
||||
|
||||
with ocp.PyTreeCheckpointer() as ckptr:
|
||||
if params_spec is None:
|
||||
params_spec = ckptr.metadata(path)["params"]
|
||||
item = {"params": params_spec}
|
||||
return ckptr.restore(
|
||||
path,
|
||||
args=ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=to_restore_args(item),
|
||||
# This is needed to read a partial checkpoint.
|
||||
transforms={},
|
||||
),
|
||||
)["params"]
|
||||
|
||||
|
||||
def _obs_to_example(obs: common.Observation, example_spec: dict) -> dict:
|
||||
def to_uint8(v):
|
||||
return (255.0 * (v + 1.0) / 2.0).astype(jnp.uint8)
|
||||
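    # Note (added): to_uint8 inverts the [-1, 1] normalization applied in
    # common.Observation.from_dict, mapping -1.0 -> 0 and 1.0 -> 255.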

    images = {k: to_uint8(v) for k, v in obs.images.items()}
    image_masks = {f"{k}_mask": v for k, v in obs.image_masks.items()}

    result = {
        "image": {**images, **image_masks},
        "state": obs.state,
        "prompt_tokens": obs.tokenized_prompt,
    }

    # NOTE(ury): This is used to support the new version with DCT co-training.
    if "mask_prompt_input" in example_spec:
        allow_action_diffusion_attention = example_spec["allow_action_diffusion_attention"]
        mask_ar = example_spec["mask_ar"]

        result = {
            **result,
            "mask_prompt_input": obs.tokenized_prompt_mask,
            # NOTE(ury): These values are likely wrong. Put something in for now
            # to make sure that the model doesn't crash.
            "allow_action_diffusion_attention": _make_batch(
                jnp.zeros(allow_action_diffusion_attention.shape, allow_action_diffusion_attention.dtype)
            ),
            "mask_ar": _make_batch(jnp.ones(mask_ar.shape, mask_ar.dtype)),
        }
    else:
        result = {
            **result,
            "mask_input": obs.tokenized_prompt_mask,
        }

    return result


def _example_to_obs(example: dict) -> common.Observation:
    images, image_masks = {}, {}
    for k, v in example["image"].items():
        if k.endswith("_mask"):
            image_masks[k.removesuffix("_mask")] = v
        else:
            images[k] = v

    # NOTE(ury): This is used to support the new version with DCT co-training.
    if "mask_prompt_input" in example:
        example["mask_input"] = example["mask_prompt_input"]

    return common.Observation.from_dict(
        {
            "image": images,
            "image_mask": image_masks,
            "state": example["state"],
            "tokenized_prompt": example["prompt_tokens"],
            "tokenized_prompt_mask": example["mask_input"],
        }
    )


def _import_norm_stats(ckpt_dir: pathlib.Path | str, processor_name: str) -> dict[str, _normalize.NormStats]:
    ckpt_dir = pathlib.Path(ckpt_dir).resolve()
    path = ckpt_dir / "processors" / processor_name
    if not path.exists():
        raise FileNotFoundError(f"Processor {processor_name} not found in {ckpt_dir}")

    if not (found_files := list(path.glob("*/norm_stats.msgpack"))):
        raise FileNotFoundError(f"norm_stats.msgpack not found in {path}")

    outputs = []

    for file in sorted(found_files):
        with file.open("rb") as f:
            norm_stats = flax.serialization.msgpack_restore(f.read())

        # This is the new Normalize processor.
        if "input_norms" in norm_stats:
            actions = norm_stats["output_norms"]["actions"]
            outputs.append(_normalize.NormStats(mean=actions["mean"], std=actions["std"]))

            state = norm_stats["input_norms"]["state"]
            outputs.append(_normalize.NormStats(mean=state["mean"], std=state["std"]))

        # This is to support the old NormalizeActions / NormalizeState processor combo.
        else:
            outputs.append(_normalize.NormStats(mean=norm_stats["mean"], std=norm_stats["std"]))

    return {
        "actions": outputs[0],
        "state": outputs[1],
    }


def _make_batch(data: at.PyTree) -> at.PyTree:
    return jax.tree.map(lambda x: x[jnp.newaxis, ...], data)


def _unbatch(data: at.PyTree) -> at.PyTree:
    return jax.tree.map(lambda x: x[0, ...], data)
47
src/openpi/models/exported_test.py
Normal file
@@ -0,0 +1,47 @@
import pathlib

import jax
import jax.numpy as jnp

import openpi.models.exported as exported
import openpi.models.model as _model
import openpi.models.pi0 as pi0
import openpi.training.checkpoints as _checkpoints


def test_sample_actions():
    model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
    actions = model.sample_actions(jax.random.key(0), model.fake_obs(), num_steps=10)

    assert actions.shape == (1, model.action_horizon, model.action_dim)


def test_exported_as_pi0():
    pi_model = exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")
    model = pi_model.set_module(pi0.Module(), param_path="decoder")

    key = jax.random.key(0)
    obs = model.fake_obs()

    pi_actions = pi_model.sample_actions(key, obs, num_steps=10)
    actions = model.sample_actions(key, obs, num_steps=10)

    assert pi_actions.shape == (1, model.action_horizon, model.action_dim)
    assert actions.shape == (1, model.action_horizon, model.action_dim)

    diff = jnp.max(jnp.abs(pi_actions - actions))
    assert diff < 10.0


def test_convert_to_openpi(tmp_path: pathlib.Path):
    output_dir = tmp_path / "output"

    exported.convert_to_openpi(
        "s3://openpi-assets/exported/pi0_aloha_sim/model",
        "huggingface_aloha_sim_transfer_cube",
        output_dir,
    )

    # Make sure that we can load the params and norm stats.
    _ = _model.restore_params(output_dir / "params")
    _ = _checkpoints.load_norm_stats(output_dir / "assets")
600
src/openpi/models/gemma.py
Normal file
@@ -0,0 +1,600 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""gemma adaptation for Pi, taken from big_vision.

We follow this einsum axis naming convention:
  B: batch
  T: query length
  S: k/v length
  N: num query heads
  K: num k/v heads
  G: num query heads per k/v head
  H: head dim
  D: d_model ("features")
"""
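# Added note (illustrative): with the "gemma_2b" config below (num_heads=8,
# num_kv_heads=1), G = N / K = 8 query heads share a single k/v head, which is
# why queries are reshaped as "B T (K G) H -> B T K G H" in Attention.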

from collections.abc import Callable, Sequence
import dataclasses
import logging
import math
from typing import Literal

import einops
import flax.linen as nn
import flax.traverse_util as traverse_util
import jax
import jax.numpy as jnp

import openpi.shared.array_typing as at

PALIGEMMA_VOCAB_SIZE = 257_152


@dataclasses.dataclass
class LoRAConfig:
    rank: int
    alpha: float
    dropout: float = 0.0
    # https://arxiv.org/pdf/2312.03732
    rslora: bool = False
    rank_annotation: str = "L"

    def __post_init__(self):
        if self.rank != int(self.alpha):
            logging.warning(
                "Rank and alpha are not the same; this will currently cause an accuracy error when merging LoRA params."
            )


@dataclasses.dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    projection_lora: LoRAConfig | None = None
    projection_kv_lora: LoRAConfig | None = None
    output_lora: LoRAConfig | None = None


Variant = Literal["dummy", "gemma_300m", "gemma_2b", "gemma_2b_lora"]


def get_config(variant: Variant) -> Config:
    """Returns the config for the specified gemma variant."""
    if variant == "dummy":
        return Config(
            width=64,
            depth=4,
            mlp_dim=128,
            num_heads=8,
            num_kv_heads=1,
            head_dim=16,
        )
    if variant == "gemma_300m":
        # 311M params.
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b_lora":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            projection_lora=LoRAConfig(rank=64, alpha=64.0),
            projection_kv_lora=LoRAConfig(rank=64, alpha=64.0),
            output_lora=LoRAConfig(rank=64, alpha=64.0),
        )
    raise ValueError(f"Unknown variant: {variant}")


@at.typecheck
class Einsum(nn.Module):
    shape: tuple[int, ...]
    init_fn: nn.initializers.Initializer

    @nn.compact
    def __call__(self, eqn, x):
        dtype = x.dtype  # Original dtype; could be half-precision.
        w = self.param("w", self.init_fn, self.shape).astype(dtype)
        return jnp.einsum(eqn, x, w)


_LORA_A_KEY = "lora_a"
_LORA_B_KEY = "lora_b"


@at.typecheck
class LoRAEinsum(nn.Module):
    base: Einsum
    lora_config: LoRAConfig
    merge_eqn: str
    lora_a_init_fn: nn.initializers.Initializer
    lora_b_init_fn: nn.initializers.Initializer

    def setup(self):
        nn.share_scope(self, self.base)

    @nn.compact
    def __call__(self, eqn, x, *, deterministic=True):
        orig_x = x
        eqn_lora_a, eqn_lora_b = self._get_lora_eqn(eqn, self.merge_eqn)
        if self.lora_config.dropout > 0.0:
            x = nn.Dropout(rate=self.lora_config.dropout, deterministic=deterministic)(x)
        lora_a_shape, lora_b_shape = self._parse_shape(self.merge_eqn)
        lora_a = self.param(_LORA_A_KEY, self.lora_a_init_fn, lora_a_shape).astype(x.dtype)
        lora_b = self.param(_LORA_B_KEY, self.lora_b_init_fn, lora_b_shape).astype(x.dtype)
        lora_a = jnp.einsum(eqn_lora_a, x, lora_a)
        lora_b = jnp.einsum(eqn_lora_b, lora_a, lora_b)

        # TODO: scaling_value should ideally be a self.variable; however, the base model currently doesn't allow
        # any auxiliary variables.
        scaling_value = (
            self.lora_config.alpha / self.lora_config.rank
            if not self.lora_config.rslora
            else self.lora_config.alpha / math.sqrt(self.lora_config.rank)
        )

        return self.base(eqn, orig_x) + lora_b * scaling_value

    def _get_lora_eqn(self, eqn: str, lora_merge_eqn: str) -> tuple[str, str]:
        """Figures out the lora_a and lora_b equations from eqn and lora_merge_eqn.

        Input:
            eqn: x,w->y
            lora_merge_eqn: lora_a,lora_b->w

        Output:
            lora_a_eqn: x,lora_a->?
            lora_b_eqn: ?,lora_b->y
        """
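        # Worked example (added for clarity): for eqn="BTD,NDH->BTNH" and
        # lora_merge_eqn="NDL,NLNH->NDH", y="BTNH" and lora_b="NLNH" share no
        # common prefix and the common suffix "NH", so the intermediate
        # annotation is "BTNL" and the returned equations are
        # "BTD, NDL -> BTNL" and "BTNL, NLNH -> BTNH".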
        (x_repr, w_repr), y_repr = _parse_einops_eqn(eqn)
        (lora_a_repr, lora_b_repr), w_repr_p = _parse_einops_eqn(lora_merge_eqn)
        assert len(w_repr) == len(self.base.shape), f"w_repr={w_repr}, shape={self.base.shape}"
        assert w_repr == w_repr_p, f"w_repr={w_repr}, w_repr_p={w_repr_p} should be the same."

        # Figure out the output annotation of x,lora_a by using y and lora_b.
        # The way to do this is to:
        # 1. remove the common prefix and suffix from lora_b and y;
        # 2. then the ? will be (common prefix) (stripped y) (stripped lora_b).
        # The equation will look like:
        # [(prefix) (stripped y) (lora b)], [(prefix) (lora b) (suffix)] -> [(prefix) (y) (suffix)]
        prefix, _, y_repr_stripped, lora_b_repr_stripped = self._remove_common_prefix_suffix(y_repr, lora_b_repr)
        lora_intermediate_repr = prefix + y_repr_stripped + lora_b_repr_stripped

        eqn_lora_a_lhs = ", ".join([x_repr, lora_a_repr])
        eqn_lora_b_lhs = ", ".join([lora_intermediate_repr, lora_b_repr])
        return eqn_lora_a_lhs + " -> " + lora_intermediate_repr, eqn_lora_b_lhs + " -> " + y_repr

    def _remove_common_prefix_suffix(self, str1, str2):
        # Get the common prefix.
        prefix = ""
        for i in range(min(len(str1), len(str2))):
            if str1[i] == str2[i]:
                prefix += str1[i]
            else:
                break

        # Get the common suffix.
        suffix = ""
        for i in range(1, min(len(str1), len(str2)) + 1):
            if str1[-i] == str2[-i]:
                suffix = str1[-i] + suffix
            else:
                break

        return prefix, suffix, str1[len(prefix) : -len(suffix)], str2[len(prefix) : -len(suffix)]

    def _parse_shape(self, lora_merge_eqn: str) -> tuple[tuple[int, ...], tuple[int, ...]]:
        (lora_lhs_part_0, lora_lhs_part_1), lora_rhs = _parse_einops_eqn(lora_merge_eqn)
        ann_to_dim = dict(zip(lora_rhs, self.base.shape, strict=True))
        ann_to_dim[self.lora_config.rank_annotation] = self.lora_config.rank
        return tuple(ann_to_dim[ann] for ann in lora_lhs_part_0), tuple(ann_to_dim[ann] for ann in lora_lhs_part_1)


def merge_lora_params(lora_params: at.PyTree, get_lora_transform_eqn: Callable[[str], str]) -> at.PyTree:
    params = lora_params["params"]
    flattened_params = traverse_util.flatten_dict(params, sep="/")
    merged_params = {}
    for k in flattened_params:
        if _LORA_A_KEY not in k:
            continue
        lora_b_key = k.replace(_LORA_A_KEY, _LORA_B_KEY)
        orig_w_key = k.replace(_LORA_A_KEY, "w")
        assert lora_b_key in flattened_params
        assert orig_w_key in flattened_params
        lora_merge = jnp.einsum(get_lora_transform_eqn(k), flattened_params[k], flattened_params[lora_b_key])
        # TODO: We currently don't handle the LoRA scaling value here because the base model doesn't support
        # auxiliary variables.
        merged_params[orig_w_key] = flattened_params[orig_w_key] + lora_merge
    for k in flattened_params:
        if _LORA_A_KEY in k or _LORA_B_KEY in k:
            continue
        if k not in merged_params:
            merged_params[k] = flattened_params[k]
    return {"params": traverse_util.unflatten_dict(merged_params, sep="/")}
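
# Added note (illustrative): for each LoRA'd weight, merge_lora_params computes
# W' = W + einsum(lora_a, lora_b). The alpha/rank scaling applied at call time
# is omitted here (see the TODO above), which is why LoRAConfig warns when
# rank != alpha -- the merge is only exact when the scaling value is 1.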


def _parse_einops_eqn(eqn: str) -> tuple[tuple[str, str], str]:
    lhs, rhs = eqn.split("->")
    lhs_parts = lhs.split(",")
    assert len(lhs_parts) == 2

    def strip_space(s):
        return s.replace(" ", "")

    lhs_parts[0] = strip_space(lhs_parts[0])
    lhs_parts[1] = strip_space(lhs_parts[1])
    rhs = strip_space(rhs)
    return ((lhs_parts[0], lhs_parts[1]), rhs)
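
# For example (added note): _parse_einops_eqn("BSD,3KDH->3BSKH") returns
# (("BSD", "3KDH"), "3BSKH").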


@at.typecheck
class RMSNorm(nn.Module):
    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # Original dtype; could be half-precision.
        scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)  # Compute variance in float32.
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))  # Compute normalization in float32.
        normed_inputs = normed_inputs * (
            1 + scale
        )  # Scale by the learned parameter in float32 (matches the Flax implementation).
        return normed_inputs.astype(dtype)  # Return in the original dtype.


@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x

    def decode(self, x):
        return jnp.dot(x, self.input_embedding_table.T)


@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, decode: bool):  # noqa: FBT001
        # All experts must share the same head dim, num heads, and num kv heads for self-attention to work.
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)  # Original dtype; could be half-precision.

        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                qkv_einsum = Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                )
                if config.projection_lora is not None:
                    qkv_einsum = LoRAEinsum(
                        qkv_einsum,
                        config.projection_lora,
                        merge_eqn="3KDL,3KLKH->3KDH",
                        lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
                        lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
                    )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                q_einsum = Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                )
                if config.projection_lora is not None:
                    q_einsum = LoRAEinsum(
                        q_einsum,
                        config.projection_lora,
                        merge_eqn="NDL,NLNH->NDH",
                        lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                        lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 2)),
                    )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                )
                if config.projection_kv_lora is not None:
                    kv_einsum = LoRAEinsum(
                        kv_einsum,
                        config.projection_kv_lora,
                        merge_eqn="2KDL,2KLKH->2KDH",
                        lora_a_init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                        lora_b_init_fn=nn.initializers.lecun_normal(in_axis=-3, out_axis=-1, batch_axis=(0, 1, 3)),
                    )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5

        k = _apply_rope(k, positions=positions)

        # Should still be half-precision here (if the input was half-precision).
        assert q.dtype == k.dtype == v.dtype == dtype

        if decode:
            if not self.has_variable("cache", "k_cache"):
                # Initial prefill.
                self.put_variable("cache", "k_cache", k)
                self.put_variable("cache", "v_cache", v)
            else:
                # Decoding.
                k = jnp.concatenate([self.get_variable("cache", "k_cache"), k], axis=1)
                v = jnp.concatenate([self.get_variable("cache", "v_cache"), v], axis=1)

        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
            )

        # big_neg = jnp.finfo(logits.dtype).min
        big_neg = -2.3819763e38  # See gemma/modules.py.
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                )
                if config.projection_lora is not None:
                    out_einsum = LoRAEinsum(
                        out_einsum,
                        config.projection_lora,
                        merge_eqn="NHNL,NLD->NHD",
                        lora_a_init_fn=nn.initializers.lecun_normal(in_axis=(-4, -3), out_axis=(-2, -1)),
                        lora_b_init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out


@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype  # Original dtype; could be half-precision.
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs


@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    configs: Sequence[Config]

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, unused_scan_arg, positions, attn_mask, decode, deterministic=True):  # noqa: FBT002
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        pre_attn = []
        for i, x in enumerate(xs):
            if x is not None:
                x = RMSNorm(name=_name("pre_attention_norm", i))(x)  # noqa: PLW2901
            pre_attn.append(x)

        post_attn = attn(pre_attn, positions, attn_mask, decode)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        xs = jax.tree.map(lambda x, y: x + y, xs, post_attn)

        out = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x = RMSNorm(name=_name("pre_ffw_norm", i))(x)  # noqa: PLW2901
                x = FeedForward(  # noqa: PLW2901
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                )(x)
            out.append(x)

        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = jax.tree.map(lambda x, y: x + y, xs, out)

        return xs, unused_scan_arg


@at.typecheck
class Module(nn.Module):
    """Transformer model, supporting a mixture of different weights for different tokens."""

    configs: Sequence[Config]  # List of configs, one for each expert.
    embed_dtype: str

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()  # Every float is dropped independently.

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        *,
        tokens: at.Int[at.Array, "b t"] | None,
        # List of embedded arrays, one for each expert, or None if that expert should not be run.
        embedded: Sequence[at.Float[at.Array, "b _t _d"] | None] | None,
        positions: at.Int[at.Array, "b t"] | None = None,
        mask: at.Bool[at.Array, "b t s"] | None = None,
        decode: bool = False,
        deterministic: bool = True,
    ) -> at.Float[at.Array, "b t d"] | Sequence[at.Float[at.Array, "b _t _d"] | None]:
        # All experts must have the same depth.
        assert all(config.depth == self.configs[0].depth for config in self.configs)

        # Embedder for the first expert only.
        embedder = Embedder(
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            embed_dim=self.configs[0].width,
            name="embedder",
        )

        if tokens is not None:
            # Embed only.
            assert embedded is None, "Cannot pass both tokens and embedded"
            return embedder.encode(tokens).astype(self.embed_dtype)

        assert embedded is not None
        assert positions is not None
        assert mask is not None

        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)

        mask = jnp.asarray(mask)[:, None, :, :]

        block_cls = nn.remat(
            Block,
            prevent_cse=False,
            static_argnums=(5, 6),  # 0=self, 5=decode, 6=deterministic
            policy=jax.checkpoint_policies.nothing_saveable,
        )

        block = nn.scan(
            block_cls,
            # The cache has axis 1 since we want the leading dimension to be the batch size.
            variable_axes={"params": 0, "cache": 1},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.configs[0].depth,
        )(
            parent=self.scope.push("layers"),
            configs=self.configs,
            dropout=self.dropout,
            dropout_bdims=self.dropout_bdims,
        )

        embedded, _ = block(embedded, (), positions, mask, decode, deterministic)

        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)

        return [RMSNorm(name=_name("final_norm", i))(e) if e is not None else e for i, e in enumerate(embedded)]


def _apply_rope(x, *, positions, max_wavelength=10_000):
    """Applies RoPE positions [B, L] to x [B, L, H, D]."""
    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    assert radians.dtype == jnp.float32
    # radians.shape = [..., L, 1, d=D/2]
    sin, cos = jnp.sin(radians), jnp.cos(radians)
    x1, x2 = jnp.split(x, 2, axis=-1)
    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    assert res.dtype == jnp.float32
    # The original big_vision impl allows RoPE to upcast to float32. It is then immediately downcast again to the
    # cache dtype when in inference mode (but not in training mode). I don't think any of this was intentional.
    # Based on the original DeepMind impl, as well as the widely used transformers impl, it is ok to always
    # downcast back to bfloat16 here.
    return res.astype(x.dtype)


def _name(name, i):
    # We name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so
    # that they can be loaded seamlessly from the existing PaliGemma checkpoint. Subsequent experts have a suffix
    # (e.g., "attn_1") and their weights are initialized from scratch. In practice, we only use two experts --
    # PaliGemma and the action expert.
    if i == 0:
        return name
    return f"{name}_{i}"
97
src/openpi/models/lora_test.py
Normal file
@@ -0,0 +1,97 @@
import chex
import flax.linen as nn
import jax
import pytest

import openpi.models.gemma as gemma


def get_annotation_to_dim_size() -> dict[str, int]:
    return {
        "B": 8,
        "T": 13,
        "S": 7,
        "N": 4,
        "M": 4,
        "K": 2,
        "H": 48,
        "D": 64,
    }


def eqn_to_shape(eqn: str, annotation_to_dim_size: dict[str, int]) -> tuple[tuple[int, ...], ...]:
    (lhs_part_0, lhs_part_1), _ = gemma._parse_einops_eqn(eqn)  # noqa: SLF001
    return tuple(int(ann) if ann.isdigit() else annotation_to_dim_size[ann] for ann in lhs_part_0), tuple(
        int(ann) if ann.isdigit() else annotation_to_dim_size[ann] for ann in lhs_part_1
    )


@pytest.mark.parametrize(
    ("eqn", "lora_annotation"),
    [
        ("BSD,3KDH->3BSKH", "3KDL,3KLKH->3KDH"),
        ("BTD,NDH->BTNH", "NDL,NLNH->NDH"),
        ("BSD,2KDH->2BSKH", "2KDL,2KLKH->2KDH"),
        ("BTNH,NHD->BTD", "NHNL,NLD->NHD"),
    ],
)
def test_lora_einsum_equivalent_to_original(eqn: str, lora_annotation: str):
    annotation_to_dim_size = get_annotation_to_dim_size()
    x_shape, w_shape = eqn_to_shape(eqn, annotation_to_dim_size)
    einsum = gemma.Einsum(shape=w_shape, name="einsum", init_fn=nn.initializers.lecun_normal())
    lora_einsum = gemma.LoRAEinsum(
        einsum,
        gemma.LoRAConfig(rank=4, alpha=4.0),
        lora_annotation,
        nn.initializers.zeros_init(),
        nn.initializers.zeros_init(),
    )

    x = jax.random.normal(jax.random.key(0), x_shape)

    def module_call(instance, x):
        return instance(eqn, x)

    einsum_variables = einsum.init(jax.random.key(0), x, method=module_call)
    lora_einsum_variables = lora_einsum.init(jax.random.key(0), x, method=module_call)
    # Copy over the weights from the original einsum to the LoRA einsum, since the initialization order is
    # not the same.
    lora_einsum_variables["params"]["w"] = einsum_variables["params"]["w"]

    y = einsum.apply(einsum_variables, x, rngs={}, method=module_call)
    y_lora = lora_einsum.apply(lora_einsum_variables, x, rngs={}, method=module_call)
    chex.assert_trees_all_close(y, y_lora)


@pytest.mark.parametrize(
    ("eqn", "lora_annotation"),
    [
        ("BSD,3KDH->3BSKH", "3KDL,3KLKH->3KDH"),
        ("BTD,NDH->BTNH", "NDL,NLNH->NDH"),
        ("BSD,2KDH->2BSKH", "2KDL,2KLKH->2KDH"),
        ("BTNH,NHD->BTD", "NHNL,NLD->NHD"),
    ],
)
def test_lora_einsum_param_merge_works(eqn: str, lora_annotation: str):
    annotation_to_dim_size = get_annotation_to_dim_size()
    x_shape, w_shape = eqn_to_shape(eqn, annotation_to_dim_size)
    einsum = gemma.Einsum(shape=w_shape, name="einsum", init_fn=nn.initializers.lecun_normal())
    lora_einsum = gemma.LoRAEinsum(
        einsum,
        gemma.LoRAConfig(rank=4, alpha=4.0),
        lora_annotation,
        nn.initializers.lecun_normal(),
        nn.initializers.lecun_normal(),
    )

    x = jax.random.uniform(jax.random.key(0), x_shape)

    def module_call(instance, x):
        return instance(eqn, x)

    lora_einsum_variables = lora_einsum.init(jax.random.key(0), x, method=module_call)
    einsum_variables = gemma.merge_lora_params(lora_einsum_variables, lambda x: lora_annotation)

    y = einsum.apply(einsum_variables, x, rngs={}, method=module_call)
    y_lora = lora_einsum.apply(lora_einsum_variables, x, rngs={}, method=module_call)
    chex.assert_trees_all_close(y, y_lora, atol=0.001)
260
src/openpi/models/model.py
Normal file
@@ -0,0 +1,260 @@
import abc
from collections.abc import Sequence
import dataclasses
import logging
import pathlib

import augmax
from flax import struct
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from typing_extensions import override

from openpi.models import common
from openpi.shared import image_tools
import openpi.shared.array_typing as at

logger = logging.getLogger("openpi")


# The model always expects these images.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)


# This may need to change if we release a small model.
IMAGE_RESOLUTION = (224, 224)
def preprocess_observation(
    rng: at.KeyArrayLike,
    observation: common.Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> common.Observation:
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")

    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        if image.shape[1:3] != image_resolution:
            logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
            image = image_tools.resize_with_pad(image, *image_resolution)

        if train:
            # Convert from [-1, 1] to [0, 1] for augmax.
            image = image / 2.0 + 0.5

            transforms = []
            if "wrist" not in key:
                height, width = image.shape[1:3]
                transforms += [
                    augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
                    augmax.Resize(width, height),
                    augmax.Rotate((-5, 5)),
                ]
            transforms += [
                augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
            ]
            sub_rngs = jax.random.split(rng, image.shape[0])
            image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)

            # Back to [-1, 1].
            image = image * 2.0 - 1.0

        out_images[key] = image

    # obtain mask
    out_masks = {}
    for key in out_images:
        if key not in observation.image_masks:
            # do not mask by default
            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
        else:
            out_masks[key] = jnp.asarray(observation.image_masks[key])

    return common.Observation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
    )
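

# A minimal sketch of calling `preprocess_observation` directly (illustrative only; it assumes
# `common.Observation` can be constructed with the keyword fields used above, and leans on the
# all-ones default for missing image masks). Shapes mirror `create_inputs_spec` below.
def _example_preprocess_observation():
    obs = common.Observation(
        images={k: jnp.zeros((1, 224, 224, 3), jnp.float32) for k in IMAGE_KEYS},
        image_masks={},  # missing masks fall back to all-ones above
        state=jnp.zeros((1, 24), jnp.float32),
        tokenized_prompt=None,
        tokenized_prompt_mask=None,
    )
    return preprocess_observation(jax.random.key(0), obs, train=False)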


@struct.dataclass
class BaseModel(abc.ABC):
    # Action space dimension.
    action_dim: int = struct.field(pytree_node=False)
    # Action sequence length.
    action_horizon: int = struct.field(pytree_node=False)
    # Tokenized prompt maximum length.
    max_token_len: int = struct.field(pytree_node=False)

    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: common.Observation,
        actions: common.Actions,
        *,
        train: bool = False,
        params: at.Params | None = None,
    ) -> at.Float[at.Array, "*b ah"]: ...

    @abc.abstractmethod
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: common.Observation,
        **sample_kwargs,
    ) -> common.Actions: ...


@struct.dataclass
class Model(BaseModel):
    module: common.BaseModule = struct.field(pytree_node=False)
    params: at.Params | None = None

    def init_params(self, rng: at.KeyArrayLike, observation: common.Observation, actions: common.Actions) -> at.Params:
        """Initialize and return the parameters by tracing the module's `compute_loss` function."""
        preprocess_rng, init_rng = jax.random.split(rng)
        obs = preprocess_observation(preprocess_rng, observation)

        return self.module.init(init_rng, obs, actions, method=self.module.compute_loss)["params"]

    @at.typecheck
    @override
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: common.Observation,
        actions: common.Actions,
        params: at.Params | None = None,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, ""]:
        if params is None:
            if self.params is None:
                raise ValueError(
                    "No parameters found. Either bind the model to parameters using `set_params` or provide params directly."
                )
            params = self.params

        loss_rng, preprocess_rng = jax.random.split(rng)

        obs = preprocess_observation(preprocess_rng, observation, train=train)
        loss_args = (obs, actions)

        return jnp.mean(
            self.module.apply({"params": params}, *loss_args, rngs={"loss": loss_rng}, method=self.module.compute_loss)  # type: ignore
        )

    @jax.jit
    @at.typecheck
    @override
    def sample_actions(
        self,
        rng: at.KeyArrayLike,
        observation: common.Observation,
        **sample_kwargs,
    ) -> common.Actions:
        if self.params is None:
            raise ValueError(
                "No parameters found. Bind the model to parameters using `set_params` before calling `sample_actions`."
            )

        preprocess_rng, sample_rng = jax.random.split(rng)

        obs = preprocess_observation(preprocess_rng, observation)
        sample_args = (self.action_horizon, self.action_dim, obs)

        actions, _ = self.module.apply(
            {"params": self.params},
            *sample_args,
            rngs={"sample": sample_rng},
            method=self.module.sample_actions,
            mutable=["cache"],
            **sample_kwargs,
        )
        return actions

    def set_params(self, params: at.Params) -> "Model":
        """Returns a copy of the model bound to `params`."""
        return dataclasses.replace(self, params=params)

    def fake_obs(self, batch_size: int = 1) -> common.Observation:
        observation_spec, _ = create_inputs_spec(self, batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), observation_spec)

    def fake_act(self, batch_size: int = 1) -> common.Actions:
        _, action_spec = create_inputs_spec(self, batch_size=batch_size)
        return jax.tree.map(lambda x: jnp.zeros(x.shape, x.dtype), action_spec)
def restore_params(
    params_path: pathlib.Path | str,
    *,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Restores unstructured params PyTree from a checkpoint. This works with checkpoints saved with `save_state` during
    openpi training (see `training/checkpoints.py`) as well as pre-trained checkpoints released for openpi.
    """
    params_path = pathlib.Path(params_path).resolve()
    if not params_path.exists():
        raise FileNotFoundError(f"Model params not found at: {params_path}")

    restore_type = np.ndarray if sharding is None else jax.Array

    with ocp.PyTreeCheckpointer() as ckptr:
        metadata = ckptr.metadata(params_path)
        # Use EMA params if they exist, otherwise regular params. See `training.utils.TrainState`.
        params_name = "ema_params" if metadata.get("ema_params") is not None else "params"
        item = {params_name: metadata[params_name]}

        return ckptr.restore(
            params_path,
            ocp.args.PyTreeRestore(
                item=item,
                restore_args=jax.tree.map(
                    lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
                ),
                transforms={},  # required to load a partial PyTree (e.g., only params from a full TrainState)
            ),
        )[params_name]
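

# A minimal sketch of restoring parameters and binding them to a model. The local checkpoint
# path here is hypothetical; `model_test.py` shows a download-based variant, and the pi0
# dimensions below mirror that test. Illustrative only.
def _example_restore_and_bind(module: common.BaseModule) -> Model:
    params = restore_params("./checkpoints/pi0/model", dtype=jnp.bfloat16)
    model = Model(module=module, action_dim=24, action_horizon=50, max_token_len=48)
    return model.set_params(params)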


def create_inputs_spec(model: Model, *, batch_size: int = 1) -> tuple[common.Observation, at.Float[at.Array, "ah ad"]]:
    image_spec = jax.ShapeDtypeStruct([batch_size, 224, 224, 3], jnp.float32)
    image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)

    with at.disable_typechecking():
        observation_spec = common.Observation(
            images={
                "base_0_rgb": image_spec,
                "left_wrist_0_rgb": image_spec,
                "right_wrist_0_rgb": image_spec,
            },
            image_masks={
                "base_0_rgb": image_mask_spec,
                "left_wrist_0_rgb": image_mask_spec,
                "right_wrist_0_rgb": image_mask_spec,
            },
            state=jax.ShapeDtypeStruct([batch_size, model.action_dim], jnp.float32),
            tokenized_prompt=jax.ShapeDtypeStruct([batch_size, model.max_token_len], jnp.int32),
            tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, model.max_token_len], jnp.int32),
        )
    action_spec = jax.ShapeDtypeStruct([batch_size, model.action_horizon, model.action_dim], jnp.float32)

    return observation_spec, action_spec
47
src/openpi/models/model_test.py
Normal file
@@ -0,0 +1,47 @@
import jax
import jax.numpy as jnp

from openpi.models import model as _model
from openpi.models import pi0
from openpi.shared import download


def make_from_spec(spec: jax.ShapeDtypeStruct):
    return jnp.zeros(shape=spec.shape, dtype=spec.dtype)


def create_pi0_model():
    return _model.Model(module=pi0.Module(), action_dim=24, action_horizon=50, max_token_len=48)


def test_model():
    model = create_pi0_model()

    batch_size = 2
    obs, act = model.fake_obs(batch_size), model.fake_act(batch_size)

    rng = jax.random.key(0)
    model = model.set_params(model.init_params(rng, obs, act))

    loss = model.compute_loss(rng, obs, act)
    assert loss.shape == ()

    actions = model.sample_actions(rng, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)


def test_model_restore():
    model = create_pi0_model()

    batch_size = 2
    obs, act = model.fake_obs(batch_size), model.fake_act(batch_size)

    params = _model.restore_params(download.maybe_download("s3://openpi-assets/exported/pi0_aloha_sim/model"))
    model = model.set_params(params)

    rng = jax.random.key(0)
    loss = model.compute_loss(rng, obs, act)
    assert loss.shape == ()

    actions = model.sample_actions(rng, obs, num_steps=10)
    assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
306
src/openpi/models/pi0.py
Normal file
@@ -0,0 +1,306 @@
import logging
from typing import Literal

import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from typing_extensions import override

from openpi.models import common
import openpi.models.gemma as _gemma
import openpi.models.siglip as _siglip
from openpi.shared import array_typing as at

logger = logging.getLogger("openpi")


def make_attn_mask(input_mask, mask_ar):
    """Copied from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller than or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      input_mask: bool[B, N] true if it's part of the input, false if padding.
      mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
        it and 0 where it shares the same attention mask as the previous token.
    """
    cumsum = jnp.cumsum(mask_ar, axis=1)
    attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
    valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
    return jnp.logical_and(attn_mask, valid_mask)
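

# A small illustrative check of `make_attn_mask`, reproducing the prefix-lm example from the
# docstring (toy sizes, illustrative only). With mask_ar = [0, 0, 0, 1, 1, 1] and no padding,
# the first three tokens attend bidirectionally among themselves while the last three are causal.
def _example_prefix_lm_mask():
    input_mask = jnp.ones((1, 6), dtype=bool)
    mask_ar = jnp.array([[0, 0, 0, 1, 1, 1]], dtype=jnp.int32)
    mask = make_attn_mask(input_mask, mask_ar)  # bool[1, 6, 6]
    # a prefix token sees all three prefix tokens but none of the suffix tokens
    assert bool(jnp.all(mask[0, 0] == jnp.array([True, True, True, False, False, False])))
    return mask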


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    period = min_period * (max_period / min_period) ** fraction
    sinusoid_input = jnp.einsum(
        "i,j->ij",
        pos,
        1.0 / period * 2 * jnp.pi,
        precision=jax.lax.Precision.HIGHEST,
    )
    return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
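

# A minimal sketch of how the flow-matching timestep is embedded in `Module.forward` below:
# periods are log-spaced between min_period and max_period, so timesteps in [0, 1] are well
# resolved. The embedding dim here is illustrative; the model uses the action expert's width.
def _example_timestep_embedding():
    timestep = jnp.array([0.0, 0.5, 1.0])
    emb = posemb_sincos(timestep, 64, min_period=4e-3, max_period=4.0)
    assert emb.shape == (3, 64)
    return emb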


class Module(common.BaseModule):
    """Pi0 module (transfusion-style decoder-only flow matching)."""

    dtype: str = "bfloat16"
    paligemma_variant: _gemma.Variant = "gemma_2b"
    action_expert_variant: _gemma.Variant = "gemma_300m"

    @at.typecheck
    @override
    def compute_loss(
        self,
        obs: common.Observation,
        target_actions: common.Actions,
        *,
        timestep: at.Float[at.Array, " b"] | None = None,
    ) -> at.Float[at.Array, "b ah"]:
        batch_size = target_actions.shape[0]

        noise = jax.random.normal(self.make_rng("loss"), target_actions.shape)
        if timestep is None:
            timestep = jax.random.beta(self.make_rng("loss"), 1.5, 1, (batch_size,)) * 0.999 + 0.001

        time_expanded = timestep[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * target_actions
        u_t = noise - target_actions
        pred = self.forward(obs, x_t, timestep, mode="train")
        return jnp.mean(jnp.square(pred - u_t), axis=2)

    @at.typecheck
    @override
    def sample_actions(
        self,
        action_horizon: int,
        action_dim: int,
        obs: common.Observation,
        *,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
        num_steps: int | at.Int[at.Array, ""] = 10,
    ) -> common.Actions:
        # note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
        # distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
        dt = -1.0 / num_steps
        batch_size = obs.state.shape[0]
        if noise is None:
            noise = jax.random.normal(self.make_rng("sample"), (batch_size, action_horizon, action_dim))

        # first fill KV cache (in-place)
        self.forward(obs, None, None, mode="fill_cache")

        @at.typecheck
        def sample_step(
            module: Module,
            carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
        ) -> tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]]:
            x_t, time = carry
            time_batched = einops.repeat(time, "-> b", b=batch_size)
            v_t = module.forward(obs, x_t, time_batched, mode="decode")
            # Euler step
            x_tilde = x_t + dt * v_t
            return x_tilde, time + dt

        @at.typecheck
        def cond_fn(
            module: Module,
            carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
        ) -> at.Bool[at.Array, ""]:
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        time = jnp.array(1.0, dtype=jnp.float32)
        x_0, _ = nn.while_loop(cond_fn, sample_step, self, (noise, time))
        return x_0

    @nn.compact
    @at.typecheck
    def forward(
        self,
        obs: common.Observation,
        noisy_actions: at.Float[at.Array, "b ah ad"] | None,
        timestep: at.Float[at.Array, " b"] | None,
        mode: Literal["train", "fill_cache", "decode"],
    ):
        """Main forward pass of the transformer. It operates in 3 modes:

        1. mode="train": This is the full forward pass, used during training.
        2. mode="fill_cache": This is used to compute the KV cache for the prefix (image + language inputs).
        3. mode="decode": This is used to perform a flow matching integration step; it uses the KV cache computed in the
           fill_cache mode.
        """
        paligemma_scope = self.scope.push("PaliGemma")
        llm_scope = paligemma_scope.push("llm")
        img_scope = paligemma_scope.push("img")

        paligemma_config = _gemma.get_config(self.paligemma_variant)
        action_expert_config = _gemma.get_config(self.action_expert_variant)
        gemma = _gemma.Module(
            configs=[paligemma_config, action_expert_config],
            embed_dtype=self.dtype,
            parent=llm_scope,
        )
        siglip = _siglip.Module(
            num_classes=paligemma_config.width,
            variant="So400m/14",
            pool_type="none",
            scan=True,
            dtype_mm=self.dtype,
            parent=img_scope,
        )

        batch_size = obs.state.shape[0]

        input_mask: list[at.Bool[at.Array, "b s"]] = []
        ar_mask: list[int] = []

        if mode in ["train", "fill_cache"]:
            prefix_tokens: list[at.Float[at.Array, "b s emb"]] = []
            # embed images
            for name in obs.images:
                image_tokens, _ = siglip(obs.images[name], train=False)

                prefix_tokens.append(image_tokens)
                input_mask.append(
                    einops.repeat(
                        obs.image_masks[name],
                        "b -> b s",
                        s=image_tokens.shape[1],
                    )
                )
                # image tokens attend to each other
                ar_mask += [0] * image_tokens.shape[1]

            # add language (aka tokenized inputs)
            if obs.tokenized_prompt is not None:
                # run gemma in embed-only mode
                tokenized_inputs = gemma(tokens=obs.tokenized_prompt, embedded=None)
                prefix_tokens.append(tokenized_inputs)
                input_mask.append(obs.tokenized_prompt_mask)
                # full attention between image and language inputs
                ar_mask += [0] * tokenized_inputs.shape[1]
            prefix_tokens = jnp.concatenate(prefix_tokens, axis=1)
            prefix_len = prefix_tokens.shape[1]

        if mode in ["train", "decode"]:
            assert noisy_actions is not None

            suffix_tokens: list[at.Float[at.Array, "b s emb"]] = []
            # add a single state token
            state_token = nn.Dense(action_expert_config.width, name="state_proj")(obs.state)
            suffix_tokens.append(state_token[:, None, :])
            input_mask.append(jnp.ones((batch_size, 1), dtype=jnp.bool_))
            # image/language inputs do not attend to state or actions
            ar_mask += [1]

            action_horizon = noisy_actions.shape[1]
            # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
            time_emb = posemb_sincos(timestep, action_expert_config.width, min_period=4e-3, max_period=4.0)
            # mix timestep + action information using an MLP
            action_tokens = nn.Dense(action_expert_config.width, name="action_in_proj")(noisy_actions)
            time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=action_horizon)
            action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
            action_time_tokens = nn.Dense(action_expert_config.width, name="action_time_mlp_in")(action_time_tokens)
            action_time_tokens = nn.swish(action_time_tokens)
            action_time_tokens = nn.Dense(action_expert_config.width, name="action_time_mlp_out")(action_time_tokens)
            # add to input tokens
            suffix_tokens.append(action_time_tokens)
            input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
            # image/language/state inputs do not attend to action tokens
            ar_mask += [1] + ([0] * (action_horizon - 1))

            suffix_tokens = jnp.concatenate(suffix_tokens, axis=1)
            suffix_len = suffix_tokens.shape[1]

        if mode == "train":
            # due to prefix-lm decoding, it is very important that the prefix cannot attend to the suffix
            assert ar_mask[prefix_len] == 1

        # create attention mask (shared between prefix and suffix)
        input_mask = jnp.concatenate(input_mask, axis=1)
        ar_mask = np.array(ar_mask, dtype=np.int32)

        ar_mask = einops.repeat(ar_mask, "s -> b s", b=batch_size)
        attn_mask = make_attn_mask(input_mask, ar_mask)

        if mode in ["train", "decode"]:
            out_proj = nn.Dense(noisy_actions.shape[-1], name="action_out_proj")

        if mode == "train":
            # full forward pass on prefix + suffix at once
            positions = jnp.cumsum(input_mask, axis=1) - 1
            _, out = gemma(
                tokens=None,
                embedded=[prefix_tokens, suffix_tokens],
                mask=attn_mask,
                positions=positions,
                decode=False,
            )
            return out_proj(out[:, -action_horizon:])
        if mode == "fill_cache":
            # fill the KV cache using the prefix tokens. this mutates the "cache" variable in place.
            self.put_variable("cache", "prefix_mask", input_mask.astype(bool))
            positions = jnp.cumsum(input_mask, axis=-1) - 1
            gemma(
                tokens=None,
                embedded=[prefix_tokens, None],
                positions=positions,
                mask=attn_mask,
                decode=True,
            )
            return None
        if mode == "decode":
            # decode using the existing KV cache
            prefix_len = gemma.variables["cache"]["layers"]["attn"]["k_cache"].shape[2]
            # `prefix_mask` is shape (b, prefix_len); after the repeat below it becomes
            # (b, suffix_len, prefix_len), indicating how the suffix tokens can attend to the prefix tokens
            prefix_mask = self.get_variable("cache", "prefix_mask")
            assert prefix_mask.shape == (batch_size, prefix_len)
            prefix_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_len)
            # `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
            # generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
            combined_mask = jnp.concatenate([prefix_mask, attn_mask], axis=-1)
            assert combined_mask.shape == (
                batch_size,
                suffix_len,
                prefix_len + suffix_len,
            )
            # `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
            positions = (
                jnp.sum(self.get_variable("cache", "prefix_mask"), axis=-1)[:, None]
                + jnp.cumsum(input_mask, axis=-1)
                - 1
            )
            unused, out = gemma(
                tokens=None,
                embedded=[None, suffix_tokens],
                mask=combined_mask,
                positions=positions,
                decode=True,
            )
            assert unused is None
            return out_proj(out[:, -action_horizon:])
        raise ValueError(f"Invalid mode: {mode}")
191
src/openpi/models/pi0_small.py
Normal file
@@ -0,0 +1,191 @@
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
from typing_extensions import override

from openpi.models import common
import openpi.models.transformer as _transformer
import openpi.models.vit as _vit
from openpi.shared import array_typing as at


@at.typecheck
def posemb_sincos(
    pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
) -> at.Float[at.Array, "b {embedding_dim}"]:
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if embedding_dim % 2 != 0:
        raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")

    fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
    period = min_period * (max_period / min_period) ** fraction
    sinusoid_input = jnp.einsum(
        "i,j->ij",
        pos,
        1.0 / period * 2 * jnp.pi,
        precision=jax.lax.Precision.HIGHEST,
    )
    return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)


class ViTEncoder(nn.Module):
    """ViT encoder from the Google vision_transformer codebase."""

    dtype: str = "bfloat16"

    @nn.compact
    @at.typecheck
    def __call__(self, image: at.Float[at.Array, "b h w c"]) -> at.Float[at.Array, "b seq emb"]:
        vit = _vit.VisionTransformer(
            name="VisionTransformer",
            dtype=self.dtype,
            # Removes class token.
            num_classes=None,
            classifier="unpooled",
            # R26+ViT-S_32 config.
            patches=ml_collections.ConfigDict({"size": (1, 1)}),
            transformer=ml_collections.ConfigDict({"mlp_dim": 1536, "num_heads": 6, "num_layers": 12}),
            hidden_size=384,
            resnet=ml_collections.ConfigDict({"num_layers": (2, 2, 2, 2), "width_factor": 1}),
        )

        # VisionTransformer expects images in [0, 1] range.
        image = (image + 1) / 2
        return vit(image, train=False)


class Encoder(nn.Module):
    """Transformer encoder that combines ViTEncoders for each image, plus state information."""

    variant: _transformer.Variant = "small"
    dtype: str = "bfloat16"

    @nn.compact
    @at.typecheck
    def __call__(self, obs: common.Observation) -> _transformer.TokenSequence:
        transformer, embed_dim = _transformer.get_variant(self.variant, dtype=self.dtype)

        image_tokens: list[_transformer.TokenSequence] = []
        for name in obs.images:
            zimg = ViTEncoder(name=f"backbone_{name}", dtype=self.dtype)(obs.images[name])
            zimg = nn.Dense(embed_dim, name=f"proj_{name}")(zimg)
            posemb = self.param(f"posemb_image_{name}", nn.initializers.normal(0.02), (embed_dim,))
            image_tokens.append(
                _transformer.TokenSequence(
                    tokens=zimg,
                    pos=jnp.broadcast_to(posemb, zimg.shape),
                    mask=jnp.broadcast_to(obs.image_masks[name][:, None], zimg.shape[:-1]),
                )
            )

        state_token = _transformer.TokenSequence(
            tokens=nn.Dense(embed_dim, name="state_proj")(obs.state)[:, None],
            pos=self.param("posemb_state", nn.initializers.normal(0.02), (embed_dim,))[None],
        )

        input_tokens = _transformer.TokenSequence.concatenate(*image_tokens, state_token)

        return transformer(input_tokens)


class Decoder(nn.Module):
    variant: _transformer.Variant = "small"
    dtype: str = "bfloat16"

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        noisy_actions: at.Float[at.Array, "b ah ad"],
        timestep: at.Float[at.Array, " b"],
        cond_tokens: _transformer.TokenSequence,
    ) -> at.Float[at.Array, "b ah ad"]:
        transformer, embed_dim = _transformer.get_variant(self.variant, dtype=self.dtype)

        tokens = _transformer.TokenSequence(
            # project actions to embedding dimension
            tokens=nn.Dense(embed_dim, name="in_proj")(noisy_actions),
            # use learned positional embedding for actions
            pos=self.param("posemb_actions", nn.initializers.normal(0.02), (noisy_actions.shape[1], embed_dim)),
        )

        # embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
        time_emb = posemb_sincos(timestep, embed_dim, min_period=4e-3, max_period=4.0)
        # time MLP
        time_emb = nn.Dense(embed_dim, name="time_mlp_in")(time_emb)
        time_emb = nn.swish(time_emb)
        time_emb = nn.Dense(embed_dim, name="time_mlp_out")(time_emb)

        output_tokens = transformer(tokens, xattn_cond=cond_tokens, adaln_cond=time_emb)
        return nn.Dense(noisy_actions.shape[-1], name="out_proj")(output_tokens.tokens)


class Module(common.BaseModule):
    encoder: Encoder = Encoder()
    decoder: Decoder = Decoder()

    @at.typecheck
    @override
    def compute_loss(
        self,
        obs: common.Observation,
        target_actions: common.Actions,
        *,
        timestep: at.Float[at.Array, " b"] | None = None,
    ) -> at.Float[at.Array, "b ah"]:
        batch_size = target_actions.shape[0]

        noise = jax.random.normal(self.make_rng("loss"), target_actions.shape)
        if timestep is None:
            timestep = jax.random.beta(self.make_rng("loss"), 1.5, 1, (batch_size,)) * 0.999 + 0.001

        time_expanded = timestep[:, None, None]
        x_t = time_expanded * noise + (1 - time_expanded) * target_actions
        u_t = noise - target_actions
        pred = self.decoder(x_t, timestep, self.encoder(obs))
        return jnp.mean(jnp.square(pred - u_t), axis=2)

    @at.typecheck
    @override
    def sample_actions(
        self,
        action_horizon: int,
        action_dim: int,
        obs: common.Observation,
        *,
        noise: at.Float[at.Array, "b ah ad"] | None = None,
        num_steps: int = 10,
    ) -> common.Actions:
        dt = -1.0 / num_steps
        batch_size = obs.state.shape[0]
        if noise is None:
            noise = jax.random.normal(self.make_rng("sample"), (batch_size, action_horizon, action_dim))

        cond_tokens = self.encoder(obs)

        @at.typecheck
        def sample_step(
            module: Module,
            carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
        ) -> tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]]:
            x_t, time = carry
            time_batched = einops.repeat(time, "-> b", b=batch_size)
            v_t = module.decoder(x_t, time_batched, cond_tokens)
            # Euler step
            x_tilde = x_t + dt * v_t
            return x_tilde, time + dt

        @at.typecheck
        def cond_fn(
            module: Module,
            carry: tuple[at.Float[at.Array, "b ah ad"], at.Float[at.Array, ""]],
        ) -> at.Bool[at.Array, ""]:
            x_t, time = carry
            # robust to floating-point error
            return time >= -dt / 2

        time = jnp.array(1.0, dtype=jnp.float32)
        x_0, _ = nn.while_loop(cond_fn, sample_step, self, (noise, time))
        return x_0
82
src/openpi/models/resnet.py
Normal file
@@ -0,0 +1,82 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet implementation copied from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_resnet.py."""

from collections.abc import Callable, Sequence
from typing import TypeVar

from flax import linen as nn
import jax.numpy as jnp

T = TypeVar("T")


def weight_standardize(w, axis, eps):
    """Subtracts mean and divides by standard deviation."""
    w = w - jnp.mean(w, axis=axis)
    return w / (jnp.std(w, axis=axis) + eps)
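

# A tiny numeric sketch of `weight_standardize` (illustrative values): standardizing over all
# axes yields a kernel with zero mean and unit standard deviation (up to eps).
def _example_weight_standardize():
    w = jnp.arange(8.0).reshape(2, 2, 2)
    w_std = weight_standardize(w, axis=[0, 1, 2], eps=1e-5)
    assert jnp.allclose(jnp.mean(w_std), 0.0, atol=1e-6)
    return w_std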


class StdConv(nn.Conv):
    """Convolution with weight standardization."""

    def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
        param = super().param(name, init_fn, *init_args)
        if name == "kernel":
            param = weight_standardize(param, axis=[0, 1, 2], eps=1e-5)
        return param


class ResidualUnit(nn.Module):
    """Bottleneck ResNet block."""

    features: int
    strides: Sequence[int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        needs_projection = x.shape[-1] != self.features * 4 or self.strides != (1, 1)

        residual = x
        if needs_projection:
            residual = StdConv(
                features=self.features * 4, kernel_size=(1, 1), strides=self.strides, use_bias=False, name="conv_proj"
            )(residual)
            residual = nn.GroupNorm(name="gn_proj")(residual)

        y = StdConv(features=self.features, kernel_size=(1, 1), use_bias=False, name="conv1")(x)
        y = nn.GroupNorm(name="gn1")(y)
        y = nn.relu(y)
        y = StdConv(features=self.features, kernel_size=(3, 3), strides=self.strides, use_bias=False, name="conv2")(y)
        y = nn.GroupNorm(name="gn2")(y)
        y = nn.relu(y)
        y = StdConv(features=self.features * 4, kernel_size=(1, 1), use_bias=False, name="conv3")(y)

        y = nn.GroupNorm(name="gn3", scale_init=nn.initializers.zeros)(y)
        return nn.relu(residual + y)


class ResNetStage(nn.Module):
    """A ResNet stage."""

    # Number of residual units in this stage (consumed as `range(1, block_size)` below).
    block_size: int
    nout: int
    first_stride: Sequence[int]

    @nn.compact
    def __call__(self, x):
        x = ResidualUnit(self.nout, strides=self.first_stride, name="unit1")(x)
        for i in range(1, self.block_size):
            x = ResidualUnit(self.nout, strides=(1, 1), name=f"unit{i + 1}")(x)
        return x
372
src/openpi/models/siglip.py
Normal file
@@ -0,0 +1,372 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A refactored and simplified ViT adaptation for Pi, taken from big_vision."""

from collections.abc import Sequence

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np


def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
    """Follows the MoCo v3 logic."""
    y, x = jnp.mgrid[:h, :w]

    assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
    omega = jnp.arange(width // 4) / (width // 4 - 1)
    omega = 1.0 / (temperature**omega)
    y = jnp.einsum("m,d->md", y.flatten(), omega)
    x = jnp.einsum("m,d->md", x.flatten(), omega)
    pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
    return jnp.asarray(pe, dtype)[None, :, :]
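

# A quick shape sketch for `posemb_sincos_2d` (illustrative sizes): a 16x16 patch grid at
# width 1152 yields a single (1, 256, 1152) embedding, broadcast over the batch when added.
def _example_posemb_sincos_2d_shape():
    pe = posemb_sincos_2d(16, 16, 1152)
    assert pe.shape == (1, 16 * 16, 1152)
    return pe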


def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
    if typ == "learn":
        return self.param(
            name,
            nn.initializers.normal(stddev=1 / np.sqrt(width)),
            (1, np.prod(seqshape), width),
            dtype,
        )
    if typ == "sincos2d":
        return posemb_sincos_2d(*seqshape, width, dtype=dtype)
    raise ValueError(f"Unknown posemb type: {typ}")


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    dropout: float = 0.0
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        """Applies Transformer MlpBlock module."""
        inits = {
            "kernel_init": nn.initializers.xavier_uniform(),
            "bias_init": nn.initializers.normal(stddev=1e-6),
        }

        _, _, d = x.shape  # n,l,d
        x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic)
        return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)


class Encoder1DBlock(nn.Module):
    """Single transformer encoder block (MHSA + MLP)."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        out = {}
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["sa"] = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            deterministic=deterministic,
            dtype=self.dtype_mm,
        )(y, y)
        y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+sa"] = x + y

        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        y = out["mlp"] = MlpBlock(
            mlp_dim=self.mlp_dim,
            dropout=self.dropout,
            dtype_mm=self.dtype_mm,
        )(y, deterministic)
        y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
        y = nn.Dropout(rate=self.dropout)(y, deterministic)
        x = out["+mlp"] = x + y
        x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
        return x, out


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation."""

    depth: int
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dropout: float = 0.0
    scan: bool = False
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x, deterministic=True):  # noqa: FBT002
        out = {}

        if self.scan:
            block = nn.remat(
                Encoder1DBlock,
                prevent_cse=False,
                static_argnums=(2,),  # 0=self, 2=deterministic
                policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
            )
            x, scan_out = nn.scan(
                block,
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=nn.broadcast,
                length=self.depth,
            )(
                name="encoderblock",
                dtype_mm=self.dtype_mm,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout=self.dropout,
            )(x, deterministic)
            for lyr in range(self.depth):
                out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
        else:
            # Input Encoder
            for lyr in range(self.depth):
                block_cur = Encoder1DBlock(
                    name=f"encoderblock_{lyr}",
                    dtype_mm=self.dtype_mm,
                    mlp_dim=self.mlp_dim,
                    num_heads=self.num_heads,
                    dropout=self.dropout,
                )
                x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
            out["pre_ln"] = x  # Alias for last block, but without the number in it.

        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out


class MAPHead(nn.Module):
    """Multihead Attention Pooling."""

    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, x):
        n, _, d = x.shape  # n,l,d
        probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
        probe = jnp.tile(probe, [n, 1, 1])

        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype_mm,
            kernel_init=nn.initializers.xavier_uniform(),
        )(probe, x)

        # TODO: dropout on head?
        y = nn.LayerNorm(dtype=self.dtype_mm)(x)
        x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype_mm=self.dtype_mm)(y)
        return x[:, 0]


class _Module(nn.Module):
    """ViT model."""

    num_classes: int | None = None
    patch_size: Sequence[int] = (16, 16)
    width: int = 768
    depth: int = 12
    mlp_dim: int | None = None  # Defaults to 4x input dim
    num_heads: int = 12
    posemb: str = "learn"  # Can also be "sincos2d"
    rep_size: int | bool = False
    dropout: float = 0.0
    pool_type: str = "gap"  # Can also be "map" or "tok"
    head_zeroinit: bool = True
    scan: bool = False
    # or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
    remat_policy: str = "nothing_saveable"
    dtype_mm: str = "float32"

    @nn.compact
    def __call__(self, image, *, train=False):
        out = {}

        # Kevin edit: do patch extraction and posemb in float32,
        # because I feel like it's a bit safer.
        image = jnp.asarray(image, jnp.float32)

        # Patch extraction
        x = out["stem"] = nn.Conv(
            self.width,
            self.patch_size,
            strides=self.patch_size,
            padding="VALID",
            name="embedding",
            dtype=jnp.float32,
        )(image)

        n, h, w, c = x.shape
        x = jnp.reshape(x, [n, h * w, c])

        # Add posemb before adding extra token.
        x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)

        if self.pool_type == "tok":
            cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
            x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)

        n, _, c = x.shape  # n,l,d
        x = nn.Dropout(rate=self.dropout)(x, not train)

        # Kevin edit: now cast back to dtype_mm (potentially half precision)
        x = x.astype(self.dtype_mm)

        x, out["encoder"] = Encoder(
            depth=self.depth,
            mlp_dim=self.mlp_dim,
            num_heads=self.num_heads,
            dropout=self.dropout,
            scan=self.scan,
            remat_policy=self.remat_policy,
            dtype_mm=self.dtype_mm,
            name="Transformer",
        )(x, deterministic=not train)
        encoded = out["encoded"] = x

        if self.pool_type == "map":
            x = out["head_input"] = MAPHead(
                num_heads=self.num_heads,
                mlp_dim=self.mlp_dim,
                dtype_mm=self.dtype_mm,
            )(x)
        elif self.pool_type == "gap":
            x = out["head_input"] = jnp.mean(x, axis=1)
        elif self.pool_type == "0":
            x = out["head_input"] = x[:, 0]
        elif self.pool_type == "tok":
            x = out["head_input"] = x[:, 0]
            encoded = encoded[:, 1:]
        elif self.pool_type == "none":
            pass
        else:
            raise ValueError(f"Unknown pool type: '{self.pool_type}'")

        x_2d = jnp.reshape(encoded, [n, h, w, -1])

        if self.rep_size:
            rep_size = self.width if self.rep_size is True else self.rep_size
            hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
            # NOTE: In the past we did not include tanh in pre_logits.
            # For few-shot, it should not matter much, as it whitens anyways.
            x_2d = nn.tanh(hid(x_2d))
            x = nn.tanh(hid(x))

        out["pre_logits_2d"] = x_2d
        out["pre_logits"] = x

        if self.num_classes:
            kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
            head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
            x_2d = out["logits_2d"] = head(x_2d)
            x = out["logits"] = head(x)

        return x, out


def Module(num_classes=None, *, variant=None, **kw):  # pylint: disable=invalid-name  # noqa: N802
    """Factory function, because linen really doesn't like what I'm doing!"""
    return _Module(num_classes, **{**decode_variant(variant), **kw})


def decode_variant(variant):
    """Converts a string like "B" or "B/32" into a params dict."""
    if variant is None:
        return {}

    v, patch = variant, {}
    if "/" in variant:
        v, patch = variant.split("/")
        patch = {"patch_size": (int(patch), int(patch))}

    return {
        # pylint:disable=line-too-long
        # Reference: Table 2 of https://arxiv.org/abs/2106.04560.
        "width": {
            "mu": 32,
            "Ti": 192,
            "S": 384,
            "M": 512,
            "B": 768,
            "L": 1024,
            "So400m": 1152,
            "H": 1280,
            "g": 1408,
            "g-opt": 1536,
            "G": 1664,
            "G-opt": 1536,
            "e": 1792,
        }[v],
        "depth": {
            "mu": 1,
            "Ti": 12,
            "S": 12,
            "M": 12,
            "B": 12,
            "L": 24,
            "So400m": 27,
            "H": 32,
            "g": 40,
            "g-opt": 40,
            "G": 48,
            "G-opt": 48,
            "e": 56,
        }[v],
        "mlp_dim": {
            "mu": 128,
            "Ti": 768,
            "S": 1536,
            "M": 2048,
            "B": 3072,
            "L": 4096,
            "So400m": 4304,
            "H": 5120,
            "g": 6144,
            "g-opt": 6144,
            "G": 8192,
            "G-opt": 8192,
            "e": 15360,
        }[v],
        "num_heads": {
            "mu": 2,
            "Ti": 3,
            "S": 6,
            "M": 8,
            "B": 12,
            "L": 16,
            "So400m": 16,
            "H": 16,
            "g": 16,
            "g-opt": 16,
            "G": 16,
            "G-opt": 16,
            "e": 16,
        }[v],
        # pylint:enable=line-too-long
        **patch,
    }
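

# A minimal sketch of `decode_variant` on the variant used for pi0's image tower ("So400m/14");
# the expected values are read directly off the tables above.
def _example_decode_variant():
    cfg = decode_variant("So400m/14")
    assert cfg["width"] == 1152
    assert cfg["depth"] == 27
    assert cfg["mlp_dim"] == 4304
    assert cfg["num_heads"] == 16
    assert cfg["patch_size"] == (14, 14)
    return cfg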
51
src/openpi/models/tokenizer.py
Normal file
@@ -0,0 +1,51 @@
import abc

import numpy as np
import sentencepiece
from typing_extensions import override

import openpi.shared.download as download


class Tokenizer(abc.ABC):
    @abc.abstractmethod
    def tokenize(self, batch: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Tokenize a batch of prompts.

        Args:
            batch: A batch of text prompts to tokenize.

        Returns:
            A tuple containing the tokenized prompts and the corresponding masks.
        """


class PaligemmaTokenizer(Tokenizer):
    def __init__(self, max_len: int = 48):
        self._max_len = max_len

        path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
        with path.open("rb") as f:
            self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())

    @override
    def tokenize(self, batch: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        batch_tokens, batch_masks = [], []

        for text in batch:
            cleaned_text = text.lower().strip().replace("_", " ").replace("\n", " ")
            # tokenize "\n" separately as the "start of answer" token
            tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
            tokens_len = len(tokens)
            if tokens_len < self._max_len:
                padding = [0] * (self._max_len - tokens_len)
                mask = [1] * tokens_len + padding
                tokens = tokens + padding
            else:
                tokens = tokens[: self._max_len]
                mask = [1] * self._max_len

            batch_tokens.append(tokens)
            batch_masks.append(mask)

        return np.array(batch_tokens), np.array(batch_masks)
9
src/openpi/models/tokenizer_test.py
Normal file
@@ -0,0 +1,9 @@
from openpi.models import tokenizer as _tokenizer


def test_tokenize():
    tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
    tokens, masks = tokenizer.tokenize(["Hello, world!", "This is a test"])

    assert tokens.shape == (2, 10)
    assert masks.shape == (2, 10)
440
src/openpi/models/transformer.py
Normal file
@@ -0,0 +1,440 @@
from collections.abc import Callable
import enum
import functools as ft
import logging
from typing import Literal

import einops
from flax import struct
import flax.linen as nn
from flax.linen import dtypes
import jax.ad_checkpoint
import jax.numpy as jnp

import openpi.shared.array_typing as at

logger = logging.getLogger(__name__)

AFTER_ATTN_CHECKPOINT_NAME = "after_attn"
AFTER_XATTN_CHECKPOINT_NAME = "after_xattn"
QKV_CHECKPOINT_NAME = "qkv"


def _custom_dot_product_attention(
    query,
    key,
    value,
    bias=None,
    mask=None,
    broadcast_dropout: bool = True,  # noqa
    dropout_rng=None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,  # noqa
    dtype=None,
    precision=None,
    module=None,
):
    """Mostly copied from nn.dot_product_attention, except for adding checkpointing logic and enforcing float32 logits
    for stability.
    """
    assert module is None
    assert dropout_rate == 0.0
    assert dropout_rng is None
    assert bias is None

    query, key, value = dtypes.promote_dtype(query, key, value, dtype=dtype)

    # save post-projection query, key, value for backward pass
    query = jax.ad_checkpoint.checkpoint_name(query, QKV_CHECKPOINT_NAME)
    key = jax.ad_checkpoint.checkpoint_name(key, QKV_CHECKPOINT_NAME)
    value = jax.ad_checkpoint.checkpoint_name(value, QKV_CHECKPOINT_NAME)

    dtype = query.dtype
    assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
    assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
    assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
    assert key.shape[-3] == value.shape[-3], "k, v lengths must match."

    # calculate attention matrix
    depth = query.shape[-1]
    query = query / jnp.sqrt(depth).astype(dtype)
    assert query.dtype == dtype

    # calculate logits in float32 for stability
    logits = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision, preferred_element_type=jnp.float32)
    # apply attention mask
    if mask is not None:
        big_neg = jnp.finfo(jnp.float32).min
        logits = jnp.where(mask, logits, big_neg)

    # normalize the attention weights and cast back to the original dtype (if not float32)
    attn_weights = jax.nn.softmax(logits).astype(dtype)

    # return weighted sum over values for each query position
    out = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value, precision=precision)

    assert out.dtype == dtype
    return out
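

# A minimal sketch of calling `_custom_dot_product_attention` directly with toy shapes
# (batch=2, q_len=4, kv_len=5, heads=2, head_dim=8); logits are computed in float32 even
# though the inputs are bfloat16. Illustrative only.
def _example_custom_attention():
    rng = jax.random.key(0)
    q = jax.random.normal(rng, (2, 4, 2, 8), dtype=jnp.bfloat16)
    k = jax.random.normal(rng, (2, 5, 2, 8), dtype=jnp.bfloat16)
    v = jax.random.normal(rng, (2, 5, 2, 8), dtype=jnp.bfloat16)
    out = _custom_dot_product_attention(q, k, v)
    assert out.shape == (2, 4, 2, 8)
    assert out.dtype == jnp.bfloat16
    return out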


@at.typecheck
@struct.dataclass
class TokenSequence:
    """Holds a sequence of tokens alongside positional information."""

    tokens: at.Float[at.ArrayLike, "b seq emb"]
    # pos may or may not have a batch dimension
    pos: at.Float[at.Array, "b seq emb"] | at.Float[at.Array, "seq emb"]
    # optional masking information
    mask: at.Bool[at.Array, "b seq"] | None = None

    def __len__(self):
        return self.tokens.shape[1]

    @property
    def emb_dim(self):
        return self.tokens.shape[-1]

    @staticmethod
    def concatenate(*sequences: "TokenSequence") -> "TokenSequence":
        """Concatenates multiple sequences along the sequence dimension."""
        tokens = jnp.concatenate([seq.tokens for seq in sequences], axis=1)
        # if any sequence's pos has a batch dimension, broadcast the others to have one
        if any(seq.pos.ndim == 3 for seq in sequences):
            batch_size = next(seq.pos.shape[0] for seq in sequences if seq.pos.ndim == 3)
            pos = jnp.concatenate(
                [
                    seq.pos if seq.pos.ndim == 3 else jnp.broadcast_to(seq.pos, (batch_size, *seq.pos.shape))
                    for seq in sequences
                ],
                axis=1,
            )
        else:
            pos = jnp.concatenate([seq.pos for seq in sequences], axis=0)

        # if any sequence has a mask, create True masks for the others
        if any(seq.mask is not None for seq in sequences):
            mask = jnp.concatenate(
                [
                    seq.mask
                    if seq.mask is not None
                    else jnp.ones((seq.tokens.shape[0], seq.tokens.shape[1]), dtype=jnp.bool_)
                    for seq in sequences
                ],
                axis=1,
            )
        else:
            mask = None

        return TokenSequence(tokens=tokens, pos=pos, mask=mask)
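

# A small sketch of `TokenSequence.concatenate` with toy shapes: the first sequence carries a
# batched pos and an explicit mask, the second carries neither, so both the pos broadcasting
# and the all-True default mask paths are exercised. Illustrative only.
def _example_token_sequence_concat():
    a = TokenSequence(
        tokens=jnp.zeros((2, 3, 8)),
        pos=jnp.zeros((2, 3, 8)),
        mask=jnp.ones((2, 3), dtype=jnp.bool_),
    )
    b = TokenSequence(tokens=jnp.zeros((2, 5, 8)), pos=jnp.zeros((5, 8)))
    c = TokenSequence.concatenate(a, b)
    assert c.tokens.shape == (2, 8, 8)
    assert c.pos.shape == (2, 8, 8)
    assert c.mask.shape == (2, 8)
    return c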
|
||||
|
||||
|
||||
class PosembStrategy(enum.Enum):
    """Controls how positional embeddings are incorporated into the transformer. Configured separately for the
    input sequence and the cross-attention sequence. Note that for cross-attention, ADD_AT_ATTN and
    ADD_AT_BEGINNING are very similar, since the key and value token sequences are the same for every
    attention operation. The only difference is that ADD_AT_ATTN adds the positional embeddings to the key
    sequence only, while ADD_AT_BEGINNING adds them to both the key and value sequences.

    NONE:
        Ignore positional embeddings.
    ADD_AT_BEGINNING:
        Adds the positional embeddings to the token sequence at the beginning of the transformer call.
    ADD_AT_ATTN:
        Adds the positional embeddings to the query and key (but not value) sequences at every attention
        operation.
    """

    NONE = enum.auto()
    ADD_AT_BEGINNING = enum.auto()
    ADD_AT_ATTN = enum.auto()


class AttentionBlock(nn.Module):
    """Implements either self-attention (if q == kv) or cross-attention (if q != kv)."""

    num_heads: int
    normalize_qk: bool = True

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        *,
        q: at.Float[at.Array, "b q_seq q_emb"],
        kv: at.Float[at.Array, "b kv_seq kv_emb"],
        q_pos: at.Float[at.Array, "*bq q_seq q_emb"],
        kv_pos: at.Float[at.Array, "*bkv kv_seq kv_emb"],
        mask: at.Bool[at.Array, "b q_seq kv_seq"] | None = None,
        dtype: at.DTypeLike,
    ) -> at.Float[at.Array, "b q_seq q_emb"]:
        # broadcast mask to have a head dimension
        if mask is not None:
            mask = einops.repeat(mask, "b q_seq kv_seq -> b n q_seq kv_seq", n=self.num_heads)
        # we add posembs to queries and keys, but not values
        q = q + q_pos
        k = kv + kv_pos
        v = kv
        return nn.MultiHeadAttention(
            num_heads=self.num_heads,
            normalize_qk=self.normalize_qk,
            use_bias=False,
            kernel_init=nn.initializers.lecun_normal(),
            attention_fn=_custom_dot_product_attention,
            dtype=dtype,
        )(q, k, v, mask=mask)
class MLPBlock(nn.Module):
    dim: int

    @nn.compact
    @at.typecheck
    def __call__(self, x: at.Float[at.Array, "b seq emb"], *, dtype: at.DTypeLike) -> at.Float[at.Array, "b seq emb"]:
        embed_dim = x.shape[-1]
        # SwiGLU MLP.
        # fuse the first 2 matmuls into one in case it's more efficient
        out = nn.DenseGeneral((2, self.dim), use_bias=False, kernel_init=nn.initializers.lecun_normal(), dtype=dtype)(x)
        gating, hidden = einops.rearrange(out, "b seq n emb -> n b seq emb")
        return nn.Dense(embed_dim, use_bias=False, kernel_init=nn.initializers.lecun_normal(), dtype=dtype)(
            nn.swish(gating) * hidden
        )
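# In other words (an added summary, not original to the file): the block computes
# W_out(swish(W_gate @ x) * (W_hidden @ x)), with W_gate and W_hidden realized as one
# fused DenseGeneral that emits both the gating and hidden branches.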
class AdaLNGeneral(nn.Module):
    """Generalized LayerNorm module, optionally adaptive based on conditioning information.

    If `cond` is None, applies standard LayerNorm with learned scale and bias. If `cond` is not None, applies
    adaptive LayerNorm (AdaLN):

    >>> out = LayerNorm(x) * (1 + scale) + shift

    Where `scale` and `shift` are learned from conditioning information and initialized to always be 0 (so
    that the output is initially equal to LayerNorm(x)), and LayerNorm here is the version without learned
    parameters.

    If `fn` is not None, this module applies normalization, `fn`, and then a residual connection. For example,
    with `cond == None`:

    >>> out = x + fn(LayerNorm(x))

    With `cond != None`, this becomes AdaLNZero (from the DiT paper):

    >>> out = x + gate * fn(LayerNorm(x) * (1 + scale) + shift)

    where `gate`, `scale`, and `shift` are once again initialized to always be 0, so the output is initially
    equal to the input.
    """

    fn: Callable | None = None

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        x: at.Float[at.Array, "b seq emb"],
        cond: at.Float[at.Array, "b cond_emb"] | at.Float[at.Array, "b seq cond_emb"] | None = None,
        *,
        dtype: at.DTypeLike,
    ) -> at.Float[at.Array, "b seq emb"]:
        if cond is None:
            if self.fn is None:
                return nn.LayerNorm(dtype=dtype)(x)
            return x + self.fn(nn.LayerNorm(dtype=dtype)(x))
        # number of learned AdaLN vectors
        n = 2 if self.fn is None else 3
        adaln = nn.DenseGeneral(
            features=(n, x.shape[-1]),
            kernel_init=nn.zeros,
            dtype=dtype,
        )(nn.swish(cond))
        if cond.ndim == 2:
            adaln = einops.rearrange(adaln, "b n emb -> n b 1 emb")
        elif cond.ndim == 3:
            adaln = einops.rearrange(adaln, "b seq n emb -> n b seq emb")
        else:
            raise ValueError(f"Invalid number of dimensions for cond: {cond.ndim}")

        modulated = nn.LayerNorm(use_bias=False, use_scale=False, dtype=dtype)(x) * (1 + adaln[0]) + adaln[1]

        if self.fn is None:
            return modulated
        return x + adaln[2] * self.fn(modulated)
class TransformerBlock(nn.Module):
    """Transformer block (no attention mask applied by default) with optional AdaLN and cross-attention
    conditioning."""

    attn: AttentionBlock
    mlp: MLPBlock

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        x: TokenSequence,
        xattn_cond: TokenSequence | None = None,
        adaln_cond: at.Float[at.Array, "b adaln_emb"] | at.Float[at.Array, "b seq adaln_emb"] | None = None,
        self_attn_mask: at.Bool[at.Array, "b seq seq"] | None = None,
        *,
        dtype: at.DTypeLike,
    ) -> TokenSequence:
        # if x.mask is not None, apply it to the self-attention mask
        if x.mask is not None:
            if self_attn_mask is None:
                self_attn_mask = jnp.ones((x.tokens.shape[0], x.tokens.shape[1], x.tokens.shape[1]), dtype=jnp.bool_)
            # take the outer product of x.mask with itself to form a full (b, seq, seq) attention mask and then
            # combine it with the existing attention mask
            self_attn_mask = jnp.logical_and(self_attn_mask, jnp.logical_and(x.mask[:, None, :], x.mask[:, :, None]))

        def self_attn_fn(y):
            return self.attn.copy(name="self_attn")(
                q=y, kv=y, q_pos=x.pos, kv_pos=x.pos, mask=self_attn_mask, dtype=dtype
            )

        # self-attention
        tokens = AdaLNGeneral(self_attn_fn)(x.tokens, adaln_cond, dtype=dtype)
        tokens = jax.ad_checkpoint.checkpoint_name(tokens, AFTER_ATTN_CHECKPOINT_NAME)

        # cross-attention
        if xattn_cond is not None:
            # if xattn_cond.mask is not None, generate a cross-attention mask
            if xattn_cond.mask is not None:
                xattn_mask = einops.repeat(xattn_cond.mask, "b xseq -> b seq xseq", seq=x.tokens.shape[1])
            else:
                xattn_mask = None

            def xattn_fn(y):
                return self.attn.copy(name="cross_attn")(
                    q=y, kv=xattn_cond.tokens, q_pos=x.pos, kv_pos=xattn_cond.pos, mask=xattn_mask, dtype=dtype
                )

            tokens = AdaLNGeneral(xattn_fn)(tokens, adaln_cond, dtype=dtype)
            tokens = jax.ad_checkpoint.checkpoint_name(tokens, AFTER_XATTN_CHECKPOINT_NAME)

        # mlp
        tokens = AdaLNGeneral(ft.partial(self.mlp, dtype=dtype))(tokens, adaln_cond, dtype=dtype)

        return x.replace(tokens=tokens)
class Transformer(nn.Module):
    """Transformer stack with optional AdaLN and cross-attention conditioning.

    AdaLN conditioning is a single vector. Cross-attention conditioning is a sequence of vectors, where the
    sequence length may be different from the input sequence length. The input, AdaLN conditioning, and cross-
    attention conditioning may all have different embedding dimensions.
    """

    num_layers: int
    transformer_block: TransformerBlock
    self_attn_posemb_strategy: PosembStrategy = PosembStrategy.ADD_AT_BEGINNING
    xattn_posemb_strategy: PosembStrategy = PosembStrategy.NONE
    dtype: str = "bfloat16"

    @nn.compact
    @at.typecheck
    def __call__(
        self,
        x: TokenSequence,
        xattn_cond: TokenSequence | None = None,
        adaln_cond: at.Float[at.Array, "b adaln_emb"] | at.Float[at.Array, "b seq adaln_emb"] | None = None,
        self_attn_mask: at.Bool[at.Array, "b seq seq"] | None = None,
    ) -> TokenSequence:
        orig_pos = x.pos  # save because we always want to include it in the output sequence
        # the attention block always adds positional embeddings to the queries and keys, so for strategies
        # other than ADD_AT_ATTN we disable that behavior by zeroing the embeddings out here
        if self.self_attn_posemb_strategy == PosembStrategy.ADD_AT_BEGINNING:
            x = x.replace(tokens=x.tokens + x.pos)
        if self.self_attn_posemb_strategy != PosembStrategy.ADD_AT_ATTN:
            x = x.replace(pos=jnp.zeros_like(x.pos, dtype=self.dtype))
        x = x.replace(tokens=x.tokens.astype(self.dtype))

        if xattn_cond is not None:
            if self.xattn_posemb_strategy == PosembStrategy.ADD_AT_BEGINNING:
                xattn_cond = xattn_cond.replace(tokens=xattn_cond.tokens + xattn_cond.pos)
            if self.xattn_posemb_strategy != PosembStrategy.ADD_AT_ATTN:
                xattn_cond = xattn_cond.replace(pos=jnp.zeros_like(xattn_cond.pos, dtype=self.dtype))
            xattn_cond = xattn_cond.replace(tokens=xattn_cond.tokens.astype(self.dtype))

        if adaln_cond is not None:
            adaln_cond = adaln_cond.astype(self.dtype)

        def block_call(module, x):
            return module(x, xattn_cond, adaln_cond, self_attn_mask, dtype=self.dtype), None

        # Enables rematerialization (aka gradient checkpointing). This configuration saves only the post-projection
        # query, key, and value tensors, as well as the activations after the full attention and cross-attention
        # blocks. This is based on seqax.
        block_call_remat = nn.remat(
            block_call,
            policy=jax.checkpoint_policies.save_only_these_names(
                (AFTER_ATTN_CHECKPOINT_NAME, AFTER_XATTN_CHECKPOINT_NAME, QKV_CHECKPOINT_NAME)
            ),
        )
        # scanning over layers significantly speeds up compilation time
        x, _ = nn.scan(
            block_call_remat,
            length=self.num_layers,
            variable_axes={"params": 0},  # create new parameters for each iteration
            split_rngs={"params": True},
        )(self.transformer_block, x)

        x = x.replace(tokens=AdaLNGeneral(name="final_norm")(x.tokens, adaln_cond, dtype=self.dtype))

        # restore original posemb for downstream use
        return x.replace(pos=orig_pos)
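# A hedged usage sketch (illustrative; the layer sizes are made up): build a two-layer
# stack that adds input posembs once at the beginning.
#
#   model = Transformer(
#       num_layers=2,
#       transformer_block=TransformerBlock(attn=AttentionBlock(num_heads=2), mlp=MLPBlock(dim=64)),
#       self_attn_posemb_strategy=PosembStrategy.ADD_AT_BEGINNING,
#   )
#   x = TokenSequence(tokens=jnp.zeros((1, 4, 16)), pos=jnp.zeros((4, 16)))
#   params = model.init(jax.random.key(0), x)
#   out = model.apply(params, x)  # out.tokens: (1, 4, 16), cast to bfloat16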
Variant = Literal["dummy", "tiny", "small", "base", "large"]


def get_variant(variant: Variant, **kwargs) -> tuple[Transformer, int]:
    if variant == "dummy":
        return Transformer(
            num_layers=2,
            transformer_block=TransformerBlock(
                attn=AttentionBlock(num_heads=2),
                mlp=MLPBlock(dim=4),
            ),
            **kwargs,
        ), 4
    if variant == "tiny":
        return Transformer(
            num_layers=4,
            transformer_block=TransformerBlock(
                attn=AttentionBlock(num_heads=2),
                mlp=MLPBlock(dim=512),
            ),
            **kwargs,
        ), 128
    if variant == "small":
        return Transformer(
            num_layers=12,
            transformer_block=TransformerBlock(
                attn=AttentionBlock(num_heads=6),
                mlp=MLPBlock(dim=1536),
            ),
            **kwargs,
        ), 384
    if variant == "base":
        return Transformer(
            num_layers=12,
            transformer_block=TransformerBlock(
                attn=AttentionBlock(num_heads=12),
                mlp=MLPBlock(dim=3072),
            ),
            **kwargs,
        ), 768
    if variant == "large":
        return Transformer(
            num_layers=24,
            transformer_block=TransformerBlock(
                attn=AttentionBlock(num_heads=16),
                mlp=MLPBlock(dim=4096),
            ),
            **kwargs,
        ), 1024
    raise ValueError(f"Invalid transformer variant: {variant}")
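# Usage (illustrative): the second return value is the embedding width the variant expects.
#
#   model, emb_dim = get_variant("base")
#   assert emb_dim == 768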
307
src/openpi/models/vit.py
Normal file
@@ -0,0 +1,307 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""

from collections.abc import Callable
from typing import Any

import flax.linen as nn
import jax
import jax.numpy as jnp

from openpi.models import resnet as models_resnet

Array = Any
PRNGKey = Any
Shape = tuple[int, ...]
Dtype = Any
class IdentityLayer(nn.Module):
    """Identity layer, convenient for giving a name to an array."""

    @nn.compact
    def __call__(self, x):
        return x


class AddPositionEmbs(nn.Module):
    """Adds learned positional embeddings to the inputs.

    Attributes:
        posemb_init: positional embedding initializer.
    """

    posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
    param_dtype: Dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs):
        """Applies the AddPositionEmbs module.

        Args:
            inputs: Inputs to the layer.

        Returns:
            Output tensor with shape `(bs, timesteps, in_dim)`.
        """
        # inputs.shape is (batch_size, seq_len, emb_dim).
        assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
        return inputs + pe


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    param_dtype: Dtype = jnp.float32
    out_dim: int | None = None
    dropout_rate: float = 0.1
    kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies the Transformer MlpBlock module."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(inputs)  # pytype: disable=wrong-arg-types
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
            features=actual_out_dim,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)  # pytype: disable=wrong-arg-types
        return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
        mlp_dim: dimension of the mlp on top of the attention block.
        num_heads: number of heads in nn.MultiHeadDotProductAttention.
        dtype: the dtype of the computation (default: float32).
        dropout_rate: dropout rate.
        attention_dropout_rate: dropout rate for the attention heads.
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, deterministic):
        """Applies the Encoder1DBlock module.

        Args:
            inputs: Inputs to the layer.
            deterministic: Dropout will not be applied when set to true.

        Returns:
            Output after the transformer encoder block.
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
            # why isn't this true by default???
            force_fp32_for_softmax=True,
        )(x, x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
            y, deterministic=deterministic
        )

        # The second return value keeps the signature compatible with nn.scan in Encoder.
        return x + y, None
class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
        dtype: the dtype of the computation.
        num_layers: number of layers.
        mlp_dim: dimension of the mlp on top of the attention block.
        num_heads: number of heads in nn.MultiHeadDotProductAttention.
        dropout_rate: dropout rate.
        attention_dropout_rate: dropout rate in self attention.
        add_position_embedding: whether to add learned positional embeddings to the input.
    """

    dtype: jax.typing.DTypeLike
    num_layers: int
    mlp_dim: int
    num_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = True

    @nn.compact
    def __call__(self, x, *, train):
        """Applies the Transformer model on the inputs.

        Args:
            x: Inputs to the layer.
            train: Set to `True` when training.

        Returns:
            Output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        x = x.astype(self.dtype)
        # Input Encoder
        block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
        x, _ = nn.scan(
            block,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=nn.broadcast,
            length=self.num_layers,
        )(
            name="encoderblock",
            mlp_dim=self.mlp_dim,
            dropout_rate=self.dropout_rate,
            attention_dropout_rate=self.attention_dropout_rate,
            dtype=self.dtype,
            num_heads=self.num_heads,
        )(x, not train)
        return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
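# Note (added summary): because the blocks are created via nn.scan with
# variable_axes={"params": 0}, the checkpoint stores a single "encoderblock"
# parameter subtree whose leaves carry a leading axis of size num_layers,
# rather than num_layers separate subtrees.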
class VisionTransformer(nn.Module):
    """VisionTransformer."""

    dtype: jax.typing.DTypeLike
    num_classes: int
    patches: Any
    transformer: Any
    hidden_size: int
    resnet: Any | None = None
    representation_size: int | None = None
    classifier: str = "token"
    head_bias_init: float = 0.0
    encoder: type[nn.Module] = Encoder
    model_name: str | None = None

    @nn.compact
    def __call__(self, inputs, *, train):
        x = inputs
        # (Possibly partial) ResNet root.
        if self.resnet is not None:
            width = int(64 * self.resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(
                features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
            )(x)
            x = nn.GroupNorm(name="gn_root")(x)
            x = nn.relu(x)
            x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")

            # ResNet stages.
            if self.resnet.num_layers:
                x = models_resnet.ResNetStage(
                    block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
                )(x)
                for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                    x = models_resnet.ResNetStage(
                        block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
                    )(x)

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(
            features=self.hidden_size,
            kernel_size=self.patches.size,
            strides=self.patches.size,
            padding="VALID",
            name="embedding",
        )(x)

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if self.transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if self.classifier in ["token", "token_unpooled"]:
                cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)

        if self.classifier == "token":
            x = x[:, 0]
        elif self.classifier == "gap":
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
        elif self.classifier in ["unpooled", "token_unpooled"]:
            pass
        else:
            raise ValueError(f"Invalid classifier={self.classifier}")

        if self.representation_size is not None:
            x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
            x = nn.tanh(x)
        else:
            x = IdentityLayer(name="pre_logits")(x)

        if self.num_classes:
            x = nn.Dense(
                features=self.num_classes,
                name="head",
                kernel_init=nn.initializers.zeros,
                bias_init=nn.initializers.constant(self.head_bias_init),
            )(x)
        return x
253
src/openpi/policies/aloha_policy.py
Normal file
@@ -0,0 +1,253 @@
from collections.abc import Sequence

import einops
import numpy as np

from openpi import transforms


def make_aloha_example() -> dict:
    return {
        "qpos": np.ones((14,)),
        "image": np.random.rand(4, 3, 480, 640).astype(np.float32),
    }


class ActInputsRepack(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # images is [..., num_cams, channel, height, width] of type uint8.
        # The number of cameras (num_cams) depends on the environment.
        images = np.asarray(data["image"])

        num_cams = images.shape[-4]
        if num_cams == 4:
            cam_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
        elif num_cams == 1:
            cam_names = ["cam_high"]
        else:
            raise ValueError(f"Expected 1 or 4 cameras, got {num_cams}")

        # `images` have shape [..., cam_idx, channel, height, width].
        image_splits = [np.squeeze(x, axis=-4) for x in np.split(images, num_cams, axis=-4)]
        images_dict = dict(zip(cam_names, image_splits, strict=True))

        return {
            "images": images_dict,
            "state": data["qpos"],
        }


class ActOutputsRepack(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        return {"qpos": data["actions"]}
class AlohaInputs(transforms.DataTransformFn):
    """Inputs for the Aloha policy.

    Expected inputs:
    - images: dict[name, img] where img is [..., channel, height, width]. name must be in EXPECTED_CAMERAS.
    - state: [..., 14]
    - actions: [..., action_horizon, action_dim]

    Args:
        action_dim: The dimension of the action space.
        delta_action_mask: A boolean mask for the action dimensions. If None, absolute actions are used.
        adapt_to_pi: If true, will adapt the joint and gripper values to match the pi runtime.
    """

    EXPECTED_CAMERAS = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")

    def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None, adapt_to_pi: bool = False):
        self._action_dim = action_dim
        self._delta_action_mask = delta_action_mask
        self._adapt_to_pi = adapt_to_pi

    def __call__(self, data: dict) -> dict:
        data = _decode_aloha(data, adapt_to_pi=self._adapt_to_pi)

        # Get the state. We are padding from 14 to the model action dim.
        state = transforms.pad_to_dim(data["state"], self._action_dim)

        in_images = data["images"]
        if set(in_images) - set(self.EXPECTED_CAMERAS):
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # Assume that the base image always exists.
        base_image = in_images["cam_high"]
        batch_size = base_image.shape[:-3]

        images = {
            "base_0_rgb": base_image,
        }
        image_masks = {
            "base_0_rgb": np.ones(batch_size, dtype=np.bool_),
        }

        # Add the extra images.
        extra_image_names = {
            "left_wrist_0_rgb": "cam_left_wrist",
            "right_wrist_0_rgb": "cam_right_wrist",
        }
        for dest, source in extra_image_names.items():
            if source in in_images:
                images[dest] = in_images[source]
                image_masks[dest] = np.ones(batch_size, dtype=np.bool_)
            else:
                images[dest] = np.zeros_like(base_image)
                image_masks[dest] = np.zeros(batch_size, dtype=np.bool_)

        inputs = {
            "image": images,
            "image_mask": image_masks,
            "state": state,
        }

        # Actions are only available during training.
        if "actions" in data:
            actions = np.asarray(data["actions"])
            actions = _encode_actions_inv(actions, adapt_to_pi=self._adapt_to_pi)

            if self._delta_action_mask is not None:
                mask = np.asarray(self._delta_action_mask[:14])
                actions = actions - np.expand_dims(np.where(mask, state[..., :14], 0), axis=-2)

            inputs["actions"] = transforms.pad_to_dim(actions, self._action_dim)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


class AlohaOutputs(transforms.DataTransformFn):
    """Outputs for the Aloha policy.

    Args:
        delta_action_mask: A boolean mask for the action dimensions. If None, absolute actions are used.
        adapt_to_pi: If true, will adapt the joint and gripper values to match the pi runtime.
    """

    def __init__(self, *, delta_action_mask: Sequence[bool] | None = None, adapt_to_pi: bool = False):
        self._delta_action_mask = delta_action_mask
        self._adapt_to_pi = adapt_to_pi

    def __call__(self, data: dict) -> dict:
        # Only return the first 14 dims.
        actions = np.asarray(data["actions"][..., :14])

        # Apply the delta action mask.
        if self._delta_action_mask is not None:
            state = np.asarray(data["state"][..., :14])
            mask = np.asarray(self._delta_action_mask[:14])
            actions = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)

        return {"actions": _encode_actions(actions, adapt_to_pi=self._adapt_to_pi)}
def joint_flip_mask() -> np.ndarray:
    """Used to convert between aloha and pi joint angles."""
    return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])


def normalize(x, min_val, max_val):
    return (x - min_val) / (max_val - min_val)


def unnormalize(x, min_val, max_val):
    return x * (max_val - min_val) + min_val


def gripper_to_angular(value):
    # Aloha transforms the gripper positions into a linear space. The following code
    # reverses this transformation to be consistent with pi0 which is pretrained in
    # angular space.
    #
    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
    value = unnormalize(value, min_val=0.01844, max_val=0.05800)

    # This is the inverse of the angular-to-linear transformation inside the Interbotix code.
    def linear_to_radian(linear_position, arm_length, horn_radius):
        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
        return np.arcsin(np.clip(value, -1.0, 1.0))

    # The constants are taken from the Interbotix code.
    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)

    # Normalize to [0, 1].
    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    return normalize(value, min_val=0.4, max_val=1.5)


def gripper_from_angular(value):
    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
    # Note that the units are still angular but the range is different.

    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
    value = unnormalize(value, min_val=0.4, max_val=1.5)

    # These values are coming from the Aloha code:
    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
    return normalize(value, min_val=-0.6213, max_val=1.4910)


def gripper_from_angular_inv(value):
    # Directly inverts the gripper_from_angular function.
    value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
    return normalize(value, min_val=0.4, max_val=1.5)
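# Sanity check (illustrative, not in the original file): gripper_from_angular_inv
# composes the two affine maps in reverse order, so it inverts gripper_from_angular
# exactly:
#
#   x = 0.5
#   assert np.isclose(gripper_from_angular_inv(gripper_from_angular(x)), x)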
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
    # state is [left_arm_joint_angles, right_arm_joint_angles, left_arm_gripper, right_arm_gripper]
    # dim sizes: [6, 1, 6, 1]
    state = np.asarray(data["state"])
    state = _decode_state(state, adapt_to_pi=adapt_to_pi)

    def convert_image(img):
        img = np.asarray(img)
        # Convert to uint8 if using float images.
        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)
        # Convert from [..., channel, height, width] to [..., height, width, channel].
        return einops.rearrange(img, "... c h w -> ... h w c")

    images = data["images"]
    images_dict = {name: convert_image(img) for name, img in images.items()}

    data["images"] = images_dict
    data["state"] = state
    return data


def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    if adapt_to_pi:
        # Flip the joints.
        state = joint_flip_mask() * state

        # Reverse the gripper transformation that is being applied by the Aloha runtime.
        state[..., 6] = gripper_to_angular(state[..., 6])
        state[..., 13] = gripper_to_angular(state[..., 13])

    return state


def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    if adapt_to_pi:
        # Flip the joints.
        actions = joint_flip_mask() * actions

        actions[..., 6] = gripper_from_angular(actions[..., 6])
        actions[..., 13] = gripper_from_angular(actions[..., 13])

    return actions


def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
    if adapt_to_pi:
        actions = joint_flip_mask() * actions

        actions[..., 6] = gripper_from_angular_inv(actions[..., 6])
        actions[..., 13] = gripper_from_angular_inv(actions[..., 13])

    return actions
38
src/openpi/policies/calvin_policy.py
Normal file
@@ -0,0 +1,38 @@
import jax.numpy as jnp

from openpi import transforms


class CalvinInputs(transforms.DataTransformFn):
    def __init__(self, action_dim: int):
        self._action_dim = action_dim

    def __call__(self, data: dict) -> dict:
        state = transforms.pad_to_dim(data["observation/state"], self._action_dim)

        inputs = {
            "state": state,
            "image": {
                "rgb_static": data["observation/rgb_static"],
                "rgb_gripper": data["observation/rgb_gripper"],
            },
            "image_mask": {
                "rgb_static": jnp.ones(1, dtype=jnp.bool_),
                "rgb_gripper": jnp.ones(1, dtype=jnp.bool_),
            },
        }

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


class CalvinOutputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Only return the first 15 dims.
        actions = jnp.asarray(data["actions"][..., :15])
        return {"actions": actions}
53
src/openpi/policies/droid_policy.py
Normal file
@@ -0,0 +1,53 @@
from collections.abc import Sequence

import numpy as np

from openpi import transforms


class DroidInputs(transforms.DataTransformFn):
    def __init__(self, action_dim: int, *, delta_action_mask: Sequence[bool] | None = None):
        self._action_dim = action_dim
        self._delta_action_mask = delta_action_mask

    def __call__(self, data: dict) -> dict:
        state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]], axis=1)
        state = transforms.pad_to_dim(state, self._action_dim)

        base_image = data["observation/exterior_image_1_left"]

        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": data["observation/wrist_image_left"],
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.ones(1, dtype=np.bool_),
                "left_wrist_0_rgb": np.ones(1, dtype=np.bool_),
                "right_wrist_0_rgb": np.zeros(1, dtype=np.bool_),
            },
        }

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


class DroidOutputs(transforms.DataTransformFn):
    def __init__(self, *, delta_action_mask: Sequence[bool] | None = None):
        self._delta_action_mask = delta_action_mask

    def __call__(self, data: dict) -> dict:
        # Only return the first 8 dims.
        actions = np.asarray(data["actions"][..., :8])

        # Apply the delta action mask.
        if self._delta_action_mask is not None:
            state = np.asarray(data["state"][..., :8])
            mask = np.asarray(self._delta_action_mask[:8])
            actions = actions + np.expand_dims(np.where(mask, state, 0), axis=-2)

        return {"actions": actions}
35
src/openpi/policies/libero_policy.py
Normal file
@@ -0,0 +1,35 @@
import jax.numpy as jnp

from openpi import transforms


class LiberoInputs(transforms.DataTransformFn):
    def __init__(self, action_dim: int):
        self._action_dim = action_dim

    def __call__(self, data: dict) -> dict:
        state = transforms.pad_to_dim(data["observation/state"], self._action_dim)

        inputs = {
            "state": state,
            "image": {
                "image": data["observation/image"],
                "wrist_image": data["observation/wrist_image"],
            },
            "image_mask": {
                "image": jnp.ones(1, dtype=jnp.bool_),
                "wrist_image": jnp.ones(1, dtype=jnp.bool_),
            },
        }

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


class LiberoOutputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Only return the first 8 dims.
        actions = jnp.asarray(data["actions"][..., :8])
        return {"actions": actions}
87
src/openpi/policies/policy.py
Normal file
@@ -0,0 +1,87 @@
from collections.abc import Sequence
import logging
import pathlib
from typing import Any, TypeAlias

import flax
import flax.traverse_util
import jax
import jax.numpy as jnp
import numpy as np
from openpi_client import base_policy as _base_policy
from typing_extensions import override

from openpi import transforms as _transforms
from openpi.models import common
from openpi.models import model as _model
from openpi.shared import array_typing as at

BasePolicy: TypeAlias = _base_policy.BasePolicy


class Policy(BasePolicy):
    def __init__(
        self,
        model: _model.BaseModel,
        *,
        rng: at.KeyArrayLike | None = None,
        transforms: Sequence[_transforms.DataTransformFn] = (),
        output_transforms: Sequence[_transforms.DataTransformFn] = (),
        sample_kwargs: dict[str, Any] | None = None,
    ):
        self._model = model
        self._input_transform = _transforms.CompositeTransform(transforms)
        self._output_transform = _transforms.CompositeTransform(output_transforms)
        # Test for None explicitly: the truth value of a JAX key array is ambiguous, so `rng or ...` is unsafe.
        self._rng = rng if rng is not None else jax.random.key(0)
        self._sample_kwargs = sample_kwargs or {"num_steps": 10}

    @override
    def infer(self, obs: dict) -> dict:  # type: ignore[misc]
        inputs = self._input_transform(_make_batch(obs))
        inputs = jax.tree_util.tree_map(jnp.asarray, inputs)

        self._rng, sample_rng = jax.random.split(self._rng)
        outputs = {
            "state": inputs["state"],
            "actions": self._model.sample_actions(
                sample_rng, common.Observation.from_dict(inputs), **self._sample_kwargs
            ),
        }
        outputs = self._output_transform(outputs)
        return _unbatch(jax.device_get(outputs))
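# A hedged usage sketch (illustrative; `model` and the transform lists are assumed to
# be constructed elsewhere — none of these values are defined in this file):
#
#   policy = Policy(model, transforms=[...], output_transforms=[...])
#   outputs = policy.infer(obs)  # obs: dict of unbatched numpy arrays
#
# `infer` batches the observation, applies the input transforms, samples actions from
# the model, applies the output transforms, and strips the batch dimension again.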
class PolicyRecorder(_base_policy.BasePolicy):
    """Records the policy's behavior to disk."""

    def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
        self._policy = policy

        logging.info(f"Dumping policy records to: {record_dir}")
        self._record_dir = pathlib.Path(record_dir)
        self._record_dir.mkdir(parents=True, exist_ok=True)
        self._record_step = 0

    @override
    def infer(self, obs: dict) -> dict:  # type: ignore[misc]
        results = self._policy.infer(obs)

        data = {"inputs": obs, "outputs": results}
        data = flax.traverse_util.flatten_dict(data, sep="/")

        output_path = self._record_dir / f"step_{self._record_step}"
        self._record_step += 1

        np.save(output_path, np.asarray(data))
        return results


def _make_batch(data: at.PyTree[np.ndarray]) -> at.PyTree[np.ndarray]:
    def _transform(x: np.ndarray) -> np.ndarray:
        return np.asarray(x)[np.newaxis, ...]

    return jax.tree_util.tree_map(_transform, data)


def _unbatch(data: at.PyTree[np.ndarray]) -> at.PyTree[np.ndarray]:
    return jax.tree_util.tree_map(lambda x: np.asarray(x[0, ...]), data)
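# Round-trip property (illustrative): _unbatch undoes _make_batch for any pytree of
# numpy arrays, since one prepends and the other strips a leading batch axis of size 1.
#
#   tree = {"a": np.arange(3)}
#   assert np.array_equal(_unbatch(_make_batch(tree))["a"], tree["a"])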
123
src/openpi/policies/policy_config.py
Normal file
@@ -0,0 +1,123 @@
from collections.abc import Sequence
import dataclasses
import logging
import pathlib
from typing import Any

import jax.numpy as jnp

from openpi.models import tokenizer
import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
import openpi.transforms as transforms


@dataclasses.dataclass
class PolicyConfig:
    model: _model.BaseModel

    norm_stats: dict[str, transforms.NormStats]

    input_layers: Sequence[transforms.DataTransformFn]
    output_layers: Sequence[transforms.DataTransformFn]

    default_prompt: str | None = None
    sample_kwargs: dict[str, Any] | None = None


def create_policy(config: PolicyConfig) -> _policy.Policy:
    """Creates a default pi0 policy."""
    return _policy.Policy(
        config.model,
        transforms=[
            *config.input_layers,
            transforms.Normalize(config.norm_stats),
            transforms.TokenizePrompt(
                tokenizer.PaligemmaTokenizer(config.model.max_token_len), default_prompt=config.default_prompt
            ),
        ],
        output_transforms=[
            transforms.Unnormalize(config.norm_stats),
            *config.output_layers,
        ],
        sample_kwargs=config.sample_kwargs,
    )
def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.

    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: The directory to load the model from.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
            from the checkpoint directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))

    logging.info("Loading model...")
    model = train_config.create_model()
    model = model.set_params(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))

    data_config = train_config.data.create(train_config.metadata_dir, model)
    if norm_stats is None:
        # Load the norm stats from the checkpoint, rather than the metadata dir, to make sure
        # that the policy uses the same normalization stats as the original training process.
        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets")

    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(norm_stats),
            *data_config.model_transforms.inputs,
        ],
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(norm_stats),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
    )


def make_bool_mask(*dims: int) -> tuple[bool, ...]:
    """Make a boolean mask for the given dimensions.

    Example:
        make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
        make_bool_mask(2, 0, 2) == (True, True, True, True)

    Args:
        dims: The dimensions to make the mask for.

    Returns:
        A tuple of booleans.
    """
    result = []
    for dim in dims:
        if dim > 0:
            result.extend([True] * dim)
        else:
            result.extend([False] * (-dim))
    return tuple(result)
17
src/openpi/policies/policy_config_test.py
Normal file
@@ -0,0 +1,17 @@
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


def test_make_bool_mask():
    assert _policy_config.make_bool_mask(2, -2, 2) == (True, True, False, False, True, True)
    assert _policy_config.make_bool_mask(2, 0, 2) == (True, True, True, True)


def test_create_trained_policy():
    policy = _policy_config.create_trained_policy(
        _config.get_config("debug"),
        "s3://openpi-assets/checkpoints/pi0_base",
        # The base checkpoint doesn't have norm stats.
        norm_stats={},
    )
    assert policy is not None
55
src/openpi/policies/policy_test.py
Normal file
@@ -0,0 +1,55 @@
from openpi_client import action_chunk_broker

from openpi.models import exported as _exported
from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config


def create_policy_config() -> _policy_config.PolicyConfig:
    model = _exported.PiModel.from_checkpoint("s3://openpi-assets/exported/pi0_aloha_sim/model")

    return _policy_config.PolicyConfig(
        model=model,
        norm_stats=model.norm_stats("huggingface_aloha_sim_transfer_cube"),
        input_layers=[
            aloha_policy.ActInputsRepack(),
            aloha_policy.AlohaInputs(
                action_dim=model.action_dim,
                delta_action_mask=None,
                adapt_to_pi=False,
            ),
        ],
        output_layers=[
            aloha_policy.AlohaOutputs(
                delta_action_mask=None,
                adapt_to_pi=False,
            ),
            aloha_policy.ActOutputsRepack(),
        ],
    )


def test_infer():
    config = create_policy_config()
    policy = _policy_config.create_policy(config)

    example = aloha_policy.make_aloha_example()
    outputs = policy.infer(example)

    assert outputs["qpos"].shape == (config.model.action_horizon, 14)


def test_broker():
    config = create_policy_config()
    policy = _policy_config.create_policy(config)

    broker = action_chunk_broker.ActionChunkBroker(
        policy,
        # Only execute the first half of the chunk.
        action_horizon=config.model.action_horizon // 2,
    )

    example = aloha_policy.make_aloha_example()
    for _ in range(config.model.action_horizon):
        outputs = broker.infer(example)
        assert outputs["qpos"].shape == (14,)
0
src/openpi/py.typed
Normal file
55
src/openpi/serving/websocket_policy_server.py
Normal file
@@ -0,0 +1,55 @@
import asyncio
import logging
import traceback

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server
import websockets.frames


class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `load` and `infer` methods.
    TODO: Implement the other methods.
    """

    def __init__(self, policy: _base_policy.BasePolicy, host: str = "0.0.0.0", port: int = 8000) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        asyncio.run(self.run())

    async def run(self):
        async with websockets.asyncio.server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
        logging.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        while True:
            try:
                obs = msgpack_numpy.unpackb(await websocket.recv())
                action = self._policy.infer(obs)
                await websocket.send(packer.pack(action))
            except websockets.ConnectionClosed:
                logging.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
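# A minimal client sketch (illustrative, not part of this file; assumes the synchronous
# `websockets.sync.client` API available in recent websockets releases, and an `obs`
# dict of numpy arrays):
#
#   from openpi_client import msgpack_numpy
#   import websockets.sync.client
#
#   with websockets.sync.client.connect("ws://localhost:8000", max_size=None) as ws:
#       ws.send(msgpack_numpy.Packer().pack(obs))
#       action = msgpack_numpy.unpackb(ws.recv())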
Some files were not shown because too many files have changed in this diff