Initial commit
3
.dockerignore
Normal file
@@ -0,0 +1,3 @@
.venv
checkpoints
data
16
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1,16 @@
# The CODEOWNERS file defines individuals or teams that are automatically requested for
# review when someone opens a pull request that modifies certain code. When a draft pull
# request is marked as ready for review, code owners are automatically notified.
#
# See: https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
#
# This is a comment.
# Each line is a file pattern followed by one or more owners.

# Global owners.
* @jimmyt857 @Michael-Equi @uzhilinsky

src/openpi/models/ @kvablack @uzhilinsky
src/openpi/training/ @kvablack @uzhilinsky

scripts/ @jimmyt857 @kvablack @uzhilinsky
17
.github/workflows/pre-commit.yml
vendored
Normal file
@@ -0,0 +1,17 @@
name: pre-commit
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - "*"
jobs:
  pre-commit:
    runs-on: ubuntu-latest
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v3
      - uses: pre-commit/action@v3.0.1
26
.github/workflows/test.yml
vendored
Normal file
@@ -0,0 +1,26 @@
name: Test
on:
  pull_request:
    branches:
      - "*"

jobs:
  run_tests:
    name: Run Tests
    runs-on: verylarge
    env:
      GIT_LFS_SKIP_SMUDGE: true
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5

      - name: Set up Python
        run: uv python install

      - name: Install the project
        run: uv sync --all-extras --dev

      - name: Run tests
        run: uv run pytest --strict-markers -m "not manual"
168
.gitignore
vendored
Normal file
@@ -0,0 +1,168 @@
# Data directories.
assets/
checkpoints/
data/
wandb/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
6
.gitmodules
vendored
Normal file
@@ -0,0 +1,6 @@
[submodule "third_party/aloha"]
    path = third_party/aloha
    url = git@github.com:Physical-Intelligence/aloha.git
[submodule "third_party/libero"]
    path = third_party/libero
    url = git@github.com:Lifelong-Robot-Learning/LIBERO.git
16
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
exclude: third_party/

repos:
  - repo: https://github.com/astral-sh/uv-pre-commit
    # uv version.
    rev: 0.5.14
    hooks:
      - id: uv-lock
  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.8.6
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      - id: ruff-format
1
.python-version
Normal file
@@ -0,0 +1 @@
3.11
11
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,11 @@
{
    "[python]": {
        "editor.defaultFormatter": "charliermarsh.ruff",
        "editor.formatOnSave": true,
    },
    "python.testing.pytestArgs": [
        "src"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
33
CONTRIBUTING.md
Normal file
@@ -0,0 +1,33 @@
# Contributing to openpi

We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.

## Issues and feature requests

You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.

If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on GitHub under Issues). If the issue has not yet been reported, please include this information when filing a GitHub issue:

- Your OS type and version and the version of Python you are using
- Code that allows us to reproduce your bug, including all dependencies
- Traceback of any exception
- Any other information that would help us, such as a screenshot

In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.

If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:

- The motivation for the feature
- A description of the problem you are trying to solve or your use case
- Enough information for us to understand the nature of the request
- Some information about how you intend to use it (this might help us in understanding the motivation!)

We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!

## Submitting a pull request

If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend that every contribution consider the following:

- Make sure that your PR has a clear title and description
- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .` (see the example commands below)
- Make sure your PR passes all tests
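As a minimal sketch of the local checks described above (the exact hooks are defined in `.pre-commit-config.yaml`, and the test invocation mirrors the CI workflow):

```bash
# Install the git hooks once per clone, then run them against all files.
pre-commit install
pre-commit run --all-files

# Lint and format directly with ruff.
ruff check .
ruff format .

# Run the test suite the same way CI does.
uv run pytest --strict-markers -m "not manual"
```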
201
LICENSE
Normal file
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
184
README.md
Normal file
@@ -0,0 +1,184 @@
# openpi

openpi holds open-source models and packages for robotics, published by the [Physical Intelligence team](https://www.physicalintelligence.company/).

Currently, this repo contains two types of models:
- the [π₀ model](https://www.physicalintelligence.company/blog/pi0), a flow-based diffusion vision-language-action model (VLA)
- the [π₀-FAST model](https://www.physicalintelligence.company/research/fast), an autoregressive VLA, based on the FAST action tokenizer.

For both models, we provide _base model_ checkpoints, pre-trained on 10k+ hours of robot data, and examples for using them out of the box or fine-tuning them to your own datasets.

This is an experiment: $\pi_0$ was developed for our own robots, which differ from the widely used platforms such as [ALOHA](https://tonyzhaozh.github.io/aloha/) and [DROID](https://droid-dataset.github.io/), and though we are optimistic that researchers and practitioners will be able to run creative new experiments adapting $\pi_0$ to their own platforms, we do not expect every such attempt to be successful. All this is to say: $\pi_0$ may or may not work for you, but you are welcome to try it and see!


## Requirements

To run the models in this repository, you will need an NVIDIA GPU with at least the following specifications. These estimates assume a single GPU, but you can also use multiple GPUs with model parallelism to reduce per-GPU memory requirements by configuring `fsdp_devices` in the training config. Please also note that the current training script does not yet support multi-node training.

| Mode               | Memory Required | Example GPU        |
| ------------------ | --------------- | ------------------ |
| Inference          | > 8 GB          | RTX 4090           |
| Fine-Tuning (LoRA) | > 22.5 GB       | RTX 4090           |
| Fine-Tuning (Full) | > 70 GB         | A100 (80GB) / H100 |

The repo has been tested with Ubuntu 22.04; we do not currently support other operating systems.

## Installation

When cloning this repo, make sure to update submodules:

```bash
git clone --recurse-submodules git@github.com:Physical-Intelligence/openpi.git

# Or if you already cloned the repo:
git submodule update --init --recursive
```

We use [uv](https://docs.astral.sh/uv/) to manage Python dependencies. See the [uv installation instructions](https://docs.astral.sh/uv/getting-started/installation/) to set it up. Once uv is installed, run the following to set up the environment:

```bash
GIT_LFS_SKIP_SMUDGE=1 uv sync
```

NOTE: `GIT_LFS_SKIP_SMUDGE=1` is needed to pull LeRobot as a dependency.

**Docker**: As an alternative to uv installation, we provide instructions for installing openpi using Docker. If you encounter issues with your system setup, consider using Docker to simplify installation. See [Docker Setup](docs/docker.md) for more details.


## Model Checkpoints

### Base Models
We provide multiple base VLA model checkpoints. These checkpoints have been pre-trained on 10k+ hours of robot data, and can be used for fine-tuning.

| Model        | Use Case    | Description                                                                                                  | Checkpoint Path                                |
| ------------ | ----------- | ------------------------------------------------------------------------------------------------------------ | ---------------------------------------------- |
| $\pi_0$      | Fine-Tuning | Base diffusion [π₀ model](https://www.physicalintelligence.company/blog/pi0) for fine-tuning                 | `s3://openpi-assets/checkpoints/pi0_base`      |
| $\pi_0$-FAST | Fine-Tuning | Base autoregressive [π₀-FAST model](https://www.physicalintelligence.company/research/fast) for fine-tuning | `s3://openpi-assets/checkpoints/pi0_fast_base` |

### Fine-Tuned Models
We also provide "expert" checkpoints for various robot platforms and tasks. These models are fine-tuned from the base models above and intended to run directly on the target robot. These may or may not work on your particular robot. Since these checkpoints were fine-tuned on relatively small datasets collected with more widely available robots, such as ALOHA and the DROID Franka setup, they might not generalize to your particular setup, though we found some of these, especially the DROID checkpoint, to generalize quite broadly in practice.

| Model                    | Use Case    | Description                                                                                                                                                                                               | Checkpoint Path                                       |
| ------------------------ | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------- |
| $\pi_0$-FAST-DROID       | Inference   | $\pi_0$-FAST model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), can perform a wide range of simple table-top manipulation tasks 0-shot in new scenes on the DROID robot platform | `s3://openpi-assets/checkpoints/pi0_fast_droid`       |
| $\pi_0$-DROID            | Fine-Tuning | $\pi_0$ model fine-tuned on the [DROID dataset](https://droid-dataset.github.io/), faster inference than $\pi_0$-FAST-DROID, but may not follow language commands as well                                | `s3://openpi-assets/checkpoints/pi0_droid`            |
| $\pi_0$-ALOHA-towel      | Inference   | $\pi_0$ model fine-tuned on internal ALOHA data, can fold diverse towels 0-shot on [ALOHA](https://tonyzhaozh.github.io/aloha/) robot platforms                                                           | `s3://openpi-assets/checkpoints/pi0_aloha_towel`      |
| $\pi_0$-ALOHA-tupperware | Inference   | $\pi_0$ model fine-tuned on internal ALOHA data, can unpack food from a tupperware container                                                                                                              | `s3://openpi-assets/checkpoints/pi0_aloha_tupperware` |
| $\pi_0$-ALOHA-pen-uncap  | Inference   | $\pi_0$ model fine-tuned on [public ALOHA data](XXX), can uncap a pen                                                                                                                                     | `s3://openpi-assets/checkpoints/pi0_aloha_pen_uncap`  |

By default, checkpoints are automatically downloaded from `s3://openpi-assets` and are cached in `~/.cache/openpi` when needed. You can override the download path by setting the `OPENPI_DATA_HOME` environment variable.

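For example, to cache checkpoints on a different drive, you can point `OPENPI_DATA_HOME` at another directory before running any of the scripts below (the path here is illustrative):

```bash
# Cache openpi checkpoints and assets on a larger drive (example path).
export OPENPI_DATA_HOME=/mnt/storage/openpi_cache
```
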
## Running Inference for a Pre-Trained Model

Our pre-trained model checkpoints can be run with a few lines of code (here our $\pi_0$-FAST-DROID model):
```python
from openpi.training import config
from openpi.policies import policy_config
from openpi.shared import download

config = config.get_config("pi0_fast_droid")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_fast_droid")

# Create a trained policy.
policy = policy_config.create_trained_policy(config, checkpoint_dir)

# Run inference on a dummy example.
example = {
    "observation/exterior_image_1_left": ...,
    "observation/wrist_image_left": ...,
    ...
    "prompt": "pick up the fork"
}
action_chunk = policy.infer(example)["actions"]
```
You can also test this out in the [example notebook](examples/inference.ipynb).

We provide detailed step-by-step examples for running inference of our pre-trained checkpoints on [DROID](examples/droid/README.md) and [ALOHA](examples/aloha_real/README.md) robots.

**Remote Inference**: We provide [examples and code](docs/remote_inference.md) for running inference of our models **remotely**: the model can run on a different server and stream actions to the robot via a websocket connection. This makes it easy to use more powerful GPUs off-robot and keep robot and policy environments separate.

**Test inference without a robot**: We provide a [script](examples/simple_client/README.md) for testing inference without a robot. This script will generate a random observation and run inference with the model. See [here](examples/simple_client/README.md) for more details.


## Fine-Tuning Base Models on Your Own Data

We will fine-tune the $\pi_0$-FAST model on the [Libero dataset](https://libero-project.github.io/datasets) as a running example for how to fine-tune a base model on your own data. We will explain three steps:
1. Convert your data to a LeRobot dataset (which we use for training)
2. Defining training configs and running training
3. Spinning up a policy server and running inference

### 1. Convert your data to a LeRobot dataset

We provide a minimal example script for converting Libero data to a LeRobot dataset in [`examples/libero/convert_libero_data_to_lerobot.py`](examples/libero/convert_libero_data_to_lerobot.py). You can easily modify it to convert your own data! You can download the raw Libero dataset from [here](https://huggingface.co/datasets/openvla/modified_libero_rlds), and run the script with:

```bash
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/libero/data
```

### 2. Defining training configs and running training

To fine-tune a base model on your own data, you need to define configs for data processing and training. We provide example configs with detailed comments for Libero below, which you can modify for your own dataset:

- [`LiberoInputs` and `LiberoOutputs`](src/openpi/policies/libero_policy.py): Defines the data mapping from the Libero environment to the model and vice versa. Used for both training and inference.
- [`LeRobotLiberoDataConfig`](src/openpi/training/config.py): Defines how to process raw Libero data from the LeRobot dataset for training.
- [`TrainConfig`](src/openpi/training/config.py): Defines fine-tuning hyperparameters, data config, and weight loader.

We provide example fine-tuning configs for both [π₀](src/openpi/training/config.py) and [π₀-FAST](src/openpi/training/config.py) on Libero data.

Before we can run training, we need to compute the normalization statistics for the training data. Run the script below with the name of your training config:

```bash
uv run scripts/compute_norm_stats.py --config-name pi0_fast_libero
```

Now we can kick off training with the following command (the `--overwrite` flag is used to overwrite existing checkpoints if you rerun fine-tuning with the same config):

```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run scripts/train.py pi0_fast_libero --exp-name=my_experiment --overwrite
```

The command will log training progress to the console and save checkpoints to the `checkpoints` directory. You can also monitor training progress on the Weights & Biases dashboard. To make maximal use of GPU memory, set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training -- this enables JAX to use up to 90% of the GPU memory (vs. the default of 75%).

### 3. Spinning up a policy server and running inference

Once training is complete, we can run inference by spinning up a policy server and then querying it from a Libero evaluation script. Launching a model server is easy (we use the checkpoint for iteration 20,000 for this example, modify as needed):

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_libero --policy.dir=checkpoints/pi0_fast_libero/my_experiment/20000
```

This will spin up a server that listens on port 8000 and waits for observations to be sent to it. We can then run the Libero evaluation script to query the server. For instructions on how to install Libero and run the evaluation script, see the [Libero README](examples/libero/README.md).

### More Examples

We provide more examples for how to fine-tune and run inference with our models on the ALOHA platform in the following READMEs:
- [ALOHA Simulator](examples/aloha_sim)
- [ALOHA Real](examples/aloha_real)


## Troubleshooting

We will collect common issues and their solutions here. If you encounter an issue, please check here first. If you can't find a solution, please file an issue on the repo (see [here](CONTRIBUTING.md) for guidelines).

| Issue                                      | Resolution                                                                                                                                                                                    |
| ------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `uv sync` fails with dependency conflicts  | Try removing the virtual environment directory (`rm -rf .venv`) and running `uv sync` again. If issues persist, check that you have the latest version of `uv` installed (`uv self update`). |
| Training runs out of GPU memory            | Make sure you set `XLA_PYTHON_CLIENT_MEM_FRACTION=0.9` before running training to allow JAX to use more GPU memory. You can also try reducing the batch size in your training config.        |
| Policy server connection errors            | Check that the server is running and listening on the expected port. Verify network connectivity and firewall settings between client and server.                                            |
| Missing norm stats error when training     | Run `scripts/compute_norm_stats.py` with your config name before starting training.                                                                                                          |
| Dataset download fails                     | Check your internet connection. If using `local_files_only=True`, verify the dataset exists locally. For HuggingFace datasets, ensure you're logged in (`huggingface-cli login`).            |
| CUDA/GPU errors                            | Verify NVIDIA drivers and CUDA toolkit are installed correctly. For Docker, ensure nvidia-container-toolkit is installed. Check GPU compatibility.                                           |
| Import errors when running examples        | Make sure you've installed all dependencies with `uv sync` and activated the virtual environment. Some examples may have additional requirements listed in their READMEs.                    |
| Action dimensions mismatch                 | Verify your data processing transforms match the expected input/output dimensions of your robot. Check the action space definitions in your policy classes.                                  |
7
docs/docker.md
Normal file
@@ -0,0 +1,7 @@
### Docker Setup

All of the examples in this repo provide instructions for running them both directly and with Docker. Although not required, we recommend the Docker option: it simplifies software installation, produces a more stable environment, and, for examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.

Docker installation instructions are [here](https://docs.docker.com/engine/install/). If using a GPU, you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html). If your host machine is Ubuntu 22.04, you can use the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.

During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
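As a concrete illustration, the ALOHA real-robot example is started through Docker Compose with the command below (taken from that example's README; the other examples ship their own compose files and follow the same pattern):

```bash
# Build the images and start all services for the ALOHA real-robot example.
docker compose -f examples/aloha_real/compose.yml up --build
```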
42
docs/remote_inference.md
Normal file
@@ -0,0 +1,42 @@

# Running openpi models remotely

We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).

## Starting a remote policy server

To start a remote policy server, you can simply run the following command:

```bash
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
```

The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):

```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
```

This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).

## Querying the remote policy server from your robot code

We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.

First, install the `openpi-client` package in your robot environment:

```bash
cd $OPENPI_ROOT/packages/openpi-client
pip install -e .
```

Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:

```python
from openpi_client import websocket_client_policy

policy_client = websocket_client_policy.WebsocketClientPolicy(host="10.32.255.0", port=8000)
action_chunk = policy_client.infer(example)["actions"]
```

Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `example` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
70
examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.

# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile

# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash

FROM ros:noetic-robot@sha256:0e12e4db836e78c74c4b04c6d16f185d9a18d2b13cf5580747efa075eb6dc6e0
SHELL ["/bin/bash", "-c"]

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
    cmake \
    curl \
    libffi-dev \
    python3-rosdep \
    python3-rosinstall \
    python3-rosinstall-generator \
    whiptail \
    git \
    wget \
    openssh-client \
    ros-noetic-cv-bridge \
    ros-noetic-usb-cam \
    ros-noetic-realsense2-camera \
    keyboard-configuration

WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n

COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make

# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
    cd /python && \
    wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
    tar -zxvf Python-3.10.14.tgz && \
    cd Python-3.10.14 && \
    ls -lhR && \
    ./configure --enable-optimizations && \
    make install && \
    echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
    cd ~ && rm -rf /python && \
    rm -rf /var/lib/apt/lists/*

COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml

ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app

# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh

ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]
126
examples/aloha_real/README.md
Normal file
@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)

This example demonstrates how to run openpi with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../../openpi/docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.

## Prerequisites

This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.

1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.

## With Docker

```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client

# Run the robot
python examples/aloha_real/main.py
```

Terminal window 2:

```bash
roslaunch --wait aloha ros_nodes.launch
```

Terminal window 3:

```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```

## **ALOHA Checkpoint Guide**

The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.

While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.

---

### **Toast Task**

This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
  - Works on both real toast and rubber fake toast
  - Compatible with standard 2-slice toasters
  - Works with plates of varying colors

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />

- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).

### **Towel Task**

This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
  - Works on towels of varying solid colors
  - Performance is worse on heavily textured or striped towels

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />

- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.

### **Tupperware Task**

This task involves opening a tupperware filled with food and pouring the contents onto a plate.

- **Checkpoint path**: `s3://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
  - Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
  - Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
  - The policy has seen plates of varying solid colors.

### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />

- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
  - Tupperware should be on the left.
  - Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.

## Training on your own Aloha dataset

1. Convert the dataset to the LeRobot dataset v2.0 format.

   We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format (an example invocation is shown at the end of this README). As an example, we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).

2. Define a training config that uses the custom dataset.

   We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.

   IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
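For reference, the conversion script's usage, copied from its docstring (substitute your own raw-data path and Hugging Face repo id):

```bash
uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
```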
66
examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
  runtime:
    image: aloha_real
    depends_on:
      - aloha_ros_nodes
      - ros_master
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  aloha_ros_nodes:
    image: aloha_real
    depends_on:
      - ros_master
    build:
      context: ../..
      dockerfile: examples/aloha_real/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - /dev:/dev
    command: roslaunch --wait aloha ros_nodes.launch

  ros_master:
    image: ros:noetic-robot
    network_mode: host
    privileged: true
    command:
      - roscore

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
71
examples/aloha_real/constants.py
Normal file
71
examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
|
||||
### Task parameters
|
||||
|
||||
### ALOHA fixed constants
|
||||
DT = 0.001
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
||||
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
||||
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = (
|
||||
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
PUPPET_POS2JOINT = (
|
||||
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
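# Worked example (illustrative only, not part of the original constants): the master-to-puppet
# mapping first normalizes a master gripper joint value to [0, 1], then unnormalizes it into the
# puppet joint range. For a fully open master gripper:
#   MASTER_GRIPPER_JOINT_NORMALIZE_FN(MASTER_GRIPPER_JOINT_OPEN)  -> 1.0
#   MASTER2PUPPET_JOINT_FN(MASTER_GRIPPER_JOINT_OPEN)             -> PUPPET_GRIPPER_JOINT_OPEN (1.4910)
# and for a fully closed master gripper the same chain yields PUPPET_GRIPPER_JOINT_CLOSE (-0.6213).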
272
examples/aloha_real/convert_aloha_data_to_lerobot.py
Normal file
272
examples/aloha_real/convert_aloha_data_to_lerobot.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
||||
|
||||
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Literal
|
||||
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DatasetConfig:
|
||||
use_videos: bool = True
|
||||
tolerance_s: float = 0.0001
|
||||
image_writer_processes: int = 10
|
||||
image_writer_threads: int = 5
|
||||
video_backend: str | None = None
|
||||
|
||||
|
||||
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
||||
|
||||
|
||||
def create_empty_dataset(
|
||||
repo_id: str,
|
||||
robot_type: str,
|
||||
mode: Literal["video", "image"] = "video",
|
||||
*,
|
||||
has_velocity: bool = False,
|
||||
has_effort: bool = False,
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
) -> LeRobotDataset:
|
||||
motors = [
|
||||
"right_waist",
|
||||
"right_shoulder",
|
||||
"right_elbow",
|
||||
"right_forearm_roll",
|
||||
"right_wrist_angle",
|
||||
"right_wrist_rotate",
|
||||
"right_gripper",
|
||||
"left_waist",
|
||||
"left_shoulder",
|
||||
"left_elbow",
|
||||
"left_forearm_roll",
|
||||
"left_wrist_angle",
|
||||
"left_wrist_rotate",
|
||||
"left_gripper",
|
||||
]
|
||||
cameras = [
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
]
|
||||
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
if has_velocity:
|
||||
features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
if has_effort:
|
||||
features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
for cam in cameras:
|
||||
features[f"observation.images.{cam}"] = {
|
||||
"dtype": mode,
|
||||
"shape": (3, 480, 640),
|
||||
"names": [
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
|
||||
if Path(LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
return LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=50,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
use_videos=dataset_config.use_videos,
|
||||
tolerance_s=dataset_config.tolerance_s,
|
||||
image_writer_processes=dataset_config.image_writer_processes,
|
||||
image_writer_threads=dataset_config.image_writer_threads,
|
||||
video_backend=dataset_config.video_backend,
|
||||
)
|
||||
|
||||
|
||||
def get_cameras(hdf5_files: list[Path]) -> list[str]:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
# ignore depth channel, not currently handled
|
||||
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
|
||||
|
||||
def has_velocity(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/qvel" in ep
|
||||
|
||||
|
||||
def has_effort(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/effort" in ep
|
||||
|
||||
|
||||
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
|
||||
imgs_per_cam = {}
|
||||
for camera in cameras:
|
||||
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
||||
|
||||
if uncompressed:
|
||||
# load all images in RAM
|
||||
imgs_array = ep[f"/observations/images/{camera}"][:]
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# load one compressed image after the other in RAM and uncompress
|
||||
imgs_array = []
|
||||
for data in ep[f"/observations/images/{camera}"]:
|
||||
imgs_array.append(cv2.imdecode(data, 1))
|
||||
imgs_array = np.array(imgs_array)
|
||||
|
||||
imgs_per_cam[camera] = imgs_array
|
||||
return imgs_per_cam
|
||||
|
||||
|
||||
def load_raw_episode_data(
|
||||
ep_path: Path,
|
||||
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
|
||||
effort = None
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
imgs_per_cam = load_raw_images_per_camera(
|
||||
ep,
|
||||
[
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
],
|
||||
)
|
||||
|
||||
return imgs_per_cam, state, action, velocity, effort
|
||||
|
||||
|
||||
def populate_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
hdf5_files: list[Path],
|
||||
task: str,
|
||||
episodes: list[int] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
if episodes is None:
|
||||
episodes = range(len(hdf5_files))
|
||||
|
||||
for ep_idx in tqdm.tqdm(episodes):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
|
||||
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
|
||||
num_frames = state.shape[0]
|
||||
|
||||
for i in range(num_frames):
|
||||
frame = {
|
||||
"observation.state": state[i],
|
||||
"action": action[i],
|
||||
}
|
||||
|
||||
for camera, img_array in imgs_per_cam.items():
|
||||
frame[f"observation.images.{camera}"] = img_array[i]
|
||||
|
||||
if velocity is not None:
|
||||
frame["observation.velocity"] = velocity[i]
|
||||
if effort is not None:
|
||||
frame["observation.effort"] = effort[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=task)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def port_aloha(
|
||||
raw_dir: Path,
|
||||
repo_id: str,
|
||||
raw_repo_id: str | None = None,
|
||||
task: str = "DEBUG",
|
||||
*,
|
||||
episodes: list[int] | None = None,
|
||||
push_to_hub: bool = True,
|
||||
is_mobile: bool = False,
|
||||
mode: Literal["video", "image"] = "image",
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
):
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
if raw_repo_id is None:
|
||||
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
|
||||
download_raw(raw_dir, repo_id=raw_repo_id)
|
||||
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
|
||||
dataset = create_empty_dataset(
|
||||
repo_id,
|
||||
robot_type="mobile_aloha" if is_mobile else "aloha",
|
||||
mode=mode,
|
||||
has_effort=has_effort(hdf5_files),
|
||||
has_velocity=has_velocity(hdf5_files),
|
||||
dataset_config=dataset_config,
|
||||
)
|
||||
dataset = populate_dataset(
|
||||
dataset,
|
||||
hdf5_files,
|
||||
task=task,
|
||||
episodes=episodes,
|
||||
)
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(port_aloha)
|
||||
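# Illustrative sketch (not part of the original script) of the raw episode layout the functions
# above expect. Each raw_dir/episode_*.hdf5 file should contain roughly:
#   /action                        (T, 14) float
#   /observations/qpos             (T, 14) float
#   /observations/qvel             (T, 14) float, optional
#   /observations/effort           (T, 14) float, optional
#   /observations/images/<cam>     (T, 480, 640, 3) uint8, or a compressed per-frame encoding
# for <cam> in {cam_high, cam_low, cam_left_wrist, cam_right_wrist}; depth streams, if present,
# are ignored.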
57
examples/aloha_real/env.py
Normal file
57
examples/aloha_real/env.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import List, Optional # noqa: UP035
|
||||
|
||||
import einops
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
from examples.aloha_real import real_env as _real_env
|
||||
|
||||
|
||||
class AlohaRealEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot on real hardware."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
|
||||
render_height: int = 224,
|
||||
render_width: int = 224,
|
||||
) -> None:
|
||||
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
|
||||
self._render_height = render_height
|
||||
self._render_width = render_width
|
||||
|
||||
self._ts = None
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._ts = self._env.reset()
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return False
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._ts is None:
|
||||
raise RuntimeError("Timestep is not set. Call reset() first.")
|
||||
|
||||
obs = self._ts.observation
|
||||
for k in list(obs["images"].keys()):
|
||||
if "_depth" in k:
|
||||
del obs["images"][k]
|
||||
|
||||
for cam_name in obs["images"]:
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
|
||||
)
|
||||
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
|
||||
|
||||
return {
|
||||
"state": obs["qpos"],
|
||||
"images": obs["images"],
|
||||
}
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
self._ts = self._env.step(action["actions"])
|
||||
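# Illustrative shape of the observation returned by get_observation() above (not part of the
# original file): depth images are dropped and each remaining camera is resized with padding and
# transposed to CHW, e.g.
#   {"state": <qpos array of shape (14,)>,
#    "images": {"cam_high": <uint8 array of shape (3, 224, 224)>, ...}}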
51
examples/aloha_real/main.py
Normal file
51
examples/aloha_real/main.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import tyro
|
||||
|
||||
from examples.aloha_real import env as _env
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
action_horizon: int = 25
|
||||
|
||||
num_episodes: int = 1
|
||||
max_episode_steps: int = 1000
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
|
||||
|
||||
metadata = ws_client_policy.get_server_metadata()
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=ws_client_policy,
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[],
|
||||
max_hz=50,
|
||||
num_episodes=args.num_episodes,
|
||||
max_episode_steps=args.max_episode_steps,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
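# Note (illustrative, not part of the original file): tyro exposes the Args fields above as
# command-line flags, so the policy server host/port and the action horizon can be overridden at
# launch. The ActionChunkBroker only queries the server once every `action_horizon` steps and
# replays the returned chunk one action at a time, which keeps the 50 Hz control loop running
# between (slower) server round trips.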
171
examples/aloha_real/real_env.py
Normal file
171
examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
from typing import Optional, List
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
# This is the reset position that is used by the standard Aloha runtime.
|
||||
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
||||
|
||||
|
||||
class RealEnv:
|
||||
"""
|
||||
Environment for real robot bi-manual manipulation
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
|
||||
# reset_position = START_ARM_POSE[:6]
|
||||
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
|
||||
|
||||
self.puppet_bot_left = InterbotixManipulatorXS(
|
||||
robot_model="vx300s",
|
||||
group_name="arm",
|
||||
gripper_name="gripper",
|
||||
robot_name="puppet_left",
|
||||
init_node=init_node,
|
||||
)
|
||||
self.puppet_bot_right = InterbotixManipulatorXS(
|
||||
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
||||
)
|
||||
if setup_robots:
|
||||
self.setup_robots()
|
||||
|
||||
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
||||
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
||||
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
||||
self.gripper_command = JointSingleCommand(name="gripper")
|
||||
|
||||
def setup_robots(self):
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
||||
|
||||
def get_qpos(self):
|
||||
left_qpos_raw = self.recorder_left.qpos
|
||||
right_qpos_raw = self.recorder_right.qpos
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
right_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
def get_qvel(self):
|
||||
left_qvel_raw = self.recorder_left.qvel
|
||||
right_qvel_raw = self.recorder_right.qvel
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
||||
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
def get_effort(self):
|
||||
left_effort_raw = self.recorder_left.effort
|
||||
right_effort_raw = self.recorder_right.effort
|
||||
left_robot_effort = left_effort_raw[:7]
|
||||
right_robot_effort = right_effort_raw[:7]
|
||||
return np.concatenate([left_robot_effort, right_robot_effort])
|
||||
|
||||
def get_images(self):
|
||||
return self.image_recorder.get_images()
|
||||
|
||||
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
||||
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
||||
self.gripper_command.cmd = left_gripper_desired_joint
|
||||
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
||||
right_gripper_desired_pos_normalized
|
||||
)
|
||||
self.gripper_command.cmd = right_gripper_desired_joint
|
||||
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
def _reset_joints(self):
|
||||
robot_utils.move_arms(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
|
||||
)
|
||||
|
||||
def _reset_gripper(self):
|
||||
"""Set to position mode and do position resets: first open then close. Then change back to PWM mode"""
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
||||
)
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
||||
)
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def get_reward(self):
|
||||
return 0
|
||||
|
||||
def reset(self, *, fake=False):
|
||||
if not fake:
|
||||
# Reboot puppet robot gripper motors
|
||||
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self._reset_joints()
|
||||
self._reset_gripper()
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
state_len = int(len(action) / 2)
|
||||
left_action = action[:state_len]
|
||||
right_action = action[state_len:]
|
||||
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
||||
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
||||
self.set_gripper_pose(left_action[-1], right_action[-1])
|
||||
time.sleep(constants.DT)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
||||
# Arm actions
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# Gripper actions
|
||||
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
||||
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
||||
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
||||
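# Illustrative action layout for RealEnv.step(), matching the docstring above (not part of the
# original file):
#   action = np.concatenate([
#       left_arm_qpos,        # 6 absolute joint positions
#       [left_gripper_pos],   # 1 normalized gripper position (0: close, 1: open)
#       right_arm_qpos,       # 6 absolute joint positions
#       [right_gripper_pos],  # 1 normalized gripper position (0: close, 1: open)
#   ])  # shape (14,)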
18
examples/aloha_real/requirements.in
Normal file
18
examples/aloha_real/requirements.in
Normal file
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets
156
examples/aloha_real/requirements.txt
Normal file
156
examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
catkin-pkg==1.0.0
|
||||
# via rospkg
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
distro==1.9.0
|
||||
# via rospkg
|
||||
dm-control==1.0.23
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
docutils==0.20.1
|
||||
# via catkin-pkg
|
||||
einops==0.8.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
h5py==3.11.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
idna==3.10
|
||||
# via requests
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.7.5
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
modern-robotics==1.1.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mujoco==3.2.3
|
||||
# via dm-control
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# h5py
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# modern-robotics
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# pyquaternion
|
||||
# scipy
|
||||
opencv-python==4.10.0.84
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
pexpect==4.9.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.1.4
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pyrealsense2==2.55.1.6486
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# matplotlib
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# rospkg
|
||||
requests==2.32.3
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
rospkg==1.5.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
scipy==1.10.1
|
||||
# via dm-control
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
zipp==3.20.2
|
||||
# via etils
|
||||
275
examples/aloha_real/robot_utils.py
Normal file
275
examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
|
||||
def __init__(self, init_node=True, is_debug=False):
|
||||
self.is_debug = is_debug
|
||||
self.bridge = CvBridge()
|
||||
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("image_recorder", anonymous=True)
|
||||
for cam_name in self.camera_names:
|
||||
setattr(self, f"{cam_name}_rgb_image", None)
|
||||
setattr(self, f"{cam_name}_depth_image", None)
|
||||
setattr(self, f"{cam_name}_timestamp", 0.0)
|
||||
if cam_name == "cam_high":
|
||||
callback_func = self.image_cb_cam_high
|
||||
elif cam_name == "cam_low":
|
||||
callback_func = self.image_cb_cam_low
|
||||
elif cam_name == "cam_left_wrist":
|
||||
callback_func = self.image_cb_cam_left_wrist
|
||||
elif cam_name == "cam_right_wrist":
|
||||
callback_func = self.image_cb_cam_right_wrist
|
||||
else:
|
||||
raise NotImplementedError
|
||||
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
||||
if self.is_debug:
|
||||
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
||||
|
||||
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
||||
time.sleep(0.5)
|
||||
|
||||
def image_cb(self, cam_name, data):
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_rgb_image",
|
||||
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
||||
)
|
||||
# setattr(
|
||||
# self,
|
||||
# f"{cam_name}_depth_image",
|
||||
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
||||
# )
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_timestamp",
|
||||
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
||||
)
|
||||
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
||||
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
||||
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
||||
if self.is_debug:
|
||||
getattr(self, f"{cam_name}_timestamps").append(
|
||||
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
||||
)
|
||||
|
||||
def image_cb_cam_high(self, data):
|
||||
cam_name = "cam_high"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_low(self, data):
|
||||
cam_name = "cam_low"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_left_wrist(self, data):
|
||||
cam_name = "cam_left_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_right_wrist(self, data):
|
||||
cam_name = "cam_right_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def get_images(self):
|
||||
image_dict = {}
|
||||
for cam_name in self.camera_names:
|
||||
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
||||
time.sleep(0.00001)
|
||||
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
||||
depth_image = getattr(self, f"{cam_name}_depth_image")
|
||||
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
||||
image_dict[cam_name] = rgb_image
|
||||
image_dict[f"{cam_name}_depth"] = depth_image
|
||||
return image_dict
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
||||
print(f"{cam_name} {image_freq=:.2f}")
|
||||
print()
|
||||
|
||||
|
||||
class Recorder:
|
||||
def __init__(self, side, init_node=True, is_debug=False):
|
||||
self.secs = None
|
||||
self.nsecs = None
|
||||
self.qpos = None
|
||||
self.effort = None
|
||||
self.arm_command = None
|
||||
self.gripper_command = None
|
||||
self.is_debug = is_debug
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("recorder", anonymous=True)
|
||||
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_group",
|
||||
JointGroupCommand,
|
||||
self.puppet_arm_commands_cb,
|
||||
)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_single",
|
||||
JointSingleCommand,
|
||||
self.puppet_gripper_commands_cb,
|
||||
)
|
||||
if self.is_debug:
|
||||
self.joint_timestamps = deque(maxlen=50)
|
||||
self.arm_command_timestamps = deque(maxlen=50)
|
||||
self.gripper_command_timestamps = deque(maxlen=50)
|
||||
time.sleep(0.1)
|
||||
|
||||
def puppet_state_cb(self, data):
|
||||
self.qpos = data.position
|
||||
self.qvel = data.velocity
|
||||
self.effort = data.effort
|
||||
self.data = data
|
||||
if self.is_debug:
|
||||
self.joint_timestamps.append(time.time())
|
||||
|
||||
def puppet_arm_commands_cb(self, data):
|
||||
self.arm_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.arm_command_timestamps.append(time.time())
|
||||
|
||||
def puppet_gripper_commands_cb(self, data):
|
||||
self.gripper_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.gripper_command_timestamps.append(time.time())
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
||||
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
||||
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
||||
|
||||
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
|
||||
return bot.arm.core.joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
|
||||
return bot.gripper.core.joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
for t in range(num_steps):
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
|
||||
print(f"Moving grippers to {target_pose_list=}")
|
||||
gripper_command = JointSingleCommand(name="gripper")
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
|
||||
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
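# Illustrative usage (not part of the original file): move_arms and move_grippers both build a
# linear trajectory of move_time / constants.DT waypoints from the current pose to the target and
# publish one waypoint per DT. For example, opening both puppet grippers over half a second:
#   move_grippers([puppet_bot_left, puppet_bot_right],
#                 [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)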
36
examples/aloha_real/video_display.py
Normal file
36
examples/aloha_real/video_display.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
|
||||
"""Displays video frames."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ax: plt.Axes | None = None
|
||||
self._plt_img: plt.Image | None = None
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
plt.ion()
|
||||
self._ax = plt.subplot()
|
||||
self._plt_img = None
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
assert self._ax is not None
|
||||
|
||||
im = observation["image"][0] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
|
||||
if self._plt_img is None:
|
||||
self._plt_img = self._ax.imshow(im)
|
||||
else:
|
||||
self._plt_img.set_data(im)
|
||||
plt.pause(0.001)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
plt.ioff()
|
||||
plt.close()
|
||||
41
examples/aloha_sim/Dockerfile
Normal file
41
examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for the Aloha simulation environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
|
||||
|
||||
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev
|
||||
ENV MUJOCO_GL=egl
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
|
||||
36
examples/aloha_sim/README.md
Normal file
36
examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
# Run Aloha Sim

## With Docker

```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```

## Without Docker

Terminal window 1:

```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client

# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```

Note: If you are seeing EGL errors, you may need to install the following dependencies:

```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```

Terminal window 2:

```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```
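The runtime is configured through the `Args` dataclass in `examples/aloha_sim/main.py` (exposed as command-line flags by `tyro`): `task` selects the gym-aloha task (default `gym_aloha/AlohaTransferCube-v0`), `action_horizon` sets how many actions are executed per policy query, and `out_dir` sets where episode videos are written (default `data/aloha_sim/videos`).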
42
examples/aloha_sim/compose.yml
Normal file
42
examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
  runtime:
    image: aloha_sim
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/aloha_sim/Dockerfile
    init: true
    tty: true
    network_mode: host
    privileged: true
    volumes:
      - $PWD:/app
      - ../../data:/data

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true

    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]
56
examples/aloha_sim/env.py
Normal file
56
examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import gym_aloha # noqa: F401
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class AlohaSimEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot in simulation."""
|
||||
|
||||
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
|
||||
np.random.seed(seed)
|
||||
self._rng = np.random.default_rng(seed)
|
||||
|
||||
self._gym = gymnasium.make(task, obs_type=obs_type)
|
||||
|
||||
self._last_obs = None
|
||||
self._done = True
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = False
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._last_obs is None:
|
||||
raise RuntimeError("Observation is not set. Call reset() first.")
|
||||
|
||||
return self._last_obs # type: ignore
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = terminated or truncated
|
||||
self._episode_reward = max(self._episode_reward, reward)
|
||||
|
||||
def _convert_observation(self, gym_obs: dict) -> dict:
|
||||
img = gym_obs["pixels"]["top"]
|
||||
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
|
||||
# Convert axis order from [H, W, C] --> [C, H, W]
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
|
||||
return {
|
||||
"state": gym_obs["agent_pos"],
|
||||
"images": {"cam_high": img},
|
||||
}
|
||||
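# Illustrative shape of the converted observation (based on _convert_observation above, not part
# of the original file):
#   {"state": <agent_pos array>, "images": {"cam_high": <uint8 array of shape (3, 224, 224)>}}
# This matches the CHW uint8 layout produced by the aloha_real environment.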
55
examples/aloha_sim/main.py
Normal file
55
examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import env as _env
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import saver as _saver
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
|
||||
|
||||
task: str = "gym_aloha/AlohaTransferCube-v0"
|
||||
seed: int = 0
|
||||
|
||||
action_horizon: int = 10
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
display: bool = False
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaSimEnvironment(
|
||||
task=args.task,
|
||||
seed=args.seed,
|
||||
),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=_websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
),
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[
|
||||
_saver.VideoSaver(args.out_dir),
|
||||
],
|
||||
max_hz=50,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
8
examples/aloha_sim/requirements.in
Normal file
8
examples/aloha_sim/requirements.in
Normal file
@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy
typing-extensions
tyro
websockets
132
examples/aloha_sim/requirements.txt
Normal file
132
examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cloudpickle==3.1.0
|
||||
# via gymnasium
|
||||
contourpy==1.3.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
gym-aloha==0.1.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
gymnasium==1.0.0
|
||||
# via gym-aloha
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.36.1
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gym-aloha
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.9.3
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# gymnasium
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# scipy
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==11.0.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.0
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
requests==2.32.3
|
||||
# via dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
scipy==1.14.1
|
||||
# via dm-control
|
||||
setuptools==75.6.0
|
||||
# via
|
||||
# dm-control
|
||||
# imageio-ffmpeg
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.1
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gymnasium
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
40
examples/aloha_sim/saver.py
Normal file
40
examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoSaver(_subscriber.Subscriber):
|
||||
"""Saves episode data."""
|
||||
|
||||
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._out_dir = out_dir
|
||||
self._images: list[np.ndarray] = []
|
||||
self._subsample = subsample
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
self._images = []
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
im = observation["images"]["cam_high"] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
self._images.append(im)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
|
||||
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
|
||||
out_path = self._out_dir / f"out_{next_idx}.mp4"
|
||||
|
||||
logging.info(f"Saving video to {out_path}")
|
||||
imageio.mimwrite(
|
||||
out_path,
|
||||
[np.asarray(x) for x in self._images[:: self._subsample]],
|
||||
fps=50 // max(1, self._subsample),
|
||||
)
|
||||
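# Illustrative note (not part of the original file): VideoSaver numbers its outputs by scanning
# the output directory for existing out_<N>.mp4 files, so repeated runs append new videos rather
# than overwriting earlier ones. With subsample=1 and the 50 Hz runtime loop, the saved video
# plays back in real time (fps=50).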
46
examples/droid/README.md
Normal file
46
examples/droid/README.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# Run DROID
|
||||
|
||||
This example shows how to run the fine-tuned $\pi_0$-FAST-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). We also offer a $\pi_0$-DROID model that is fine-tuned from $\pi_0$ and uses flow action decoding. You can use it by replacing `pi0_fast_droid` with `pi0_droid` in the commands below. In practice, we find that, out of the box, the $\pi_0$-FAST-DROID model is better at following language commands, so we recommend it as the default checkpoint for DROID evaluation. If you want to fine-tune on a DROID task that requires fast inference, you may still want to consider the $\pi_0$-DROID model, since it decodes faster. For more details, please see the [FAST paper](https://pi.website/research/fast).
|
||||
|
||||
|
||||
## Step 1: Start a policy server
|
||||
|
||||
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
|
||||
|
||||
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
|
||||
2. Start the OpenPI server via the following command:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=s3://openpi-assets/checkpoints/pi0_fast_droid
|
||||
```
|
||||
|
||||
You can also run the equivalent command below:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env=DROID
|
||||
```
|
||||
|
||||
## Step 2: Run the DROID robot
|
||||
|
||||
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
|
||||
2. On the control laptop, activate your DROID conda environment.
|
||||
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
|
||||
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
|
||||
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
|
||||
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explore` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
|
||||
7. Run the `main.py` file. Make sure to point the remote host and port arguments at the policy server. (To check that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop, or use the connectivity-check sketch below.) Also make sure to specify which external camera to feed to the policy (we only input one external camera), choosing from `left` or `right`.
|
||||
|
||||
```bash
|
||||
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
|
||||
```
|
||||
|
||||
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
|
||||
|
||||
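Before running on the robot, you can sanity-check the connection to the policy server with a minimal sketch like the one below. It uses the same observation keys as `main.py`, but with dummy arrays standing in for real camera frames and joint readings; `<server_ip>` and the placeholder shapes are assumptions you should adapt to your setup.

```python
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy

client = websocket_client_policy.WebsocketClientPolicy(host="<server_ip>", port=8000)

# Dummy stand-ins for the camera frames and robot state that main.py reads from the DROID env.
dummy_image = np.zeros((720, 1280, 3), dtype=np.uint8)
request_data = {
    "observation/exterior_image_1_left": image_tools.resize_with_pad(dummy_image, 224, 224),
    "observation/wrist_image_left": image_tools.resize_with_pad(dummy_image, 224, 224),
    "observation/joint_position": np.zeros(7),    # 7 joint angles
    "observation/gripper_position": np.zeros(1),  # 1 gripper position
    "prompt": "do nothing",
}

action_chunk = client.infer(request_data)["actions"]
print(action_chunk.shape)  # expected: (10, 8) -- 10 steps of 7 joint velocities + 1 gripper position
```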
# Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
|
||||
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
|
||||
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
|
||||
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
|
||||
237
examples/droid/main.py
Normal file
237
examples/droid/main.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# ruff: noqa
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import datetime
|
||||
import faulthandler
|
||||
import os
|
||||
import signal
|
||||
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from droid.robot_env import RobotEnv
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
# Hardware parameters
|
||||
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
|
||||
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
|
||||
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
|
||||
|
||||
# Policy parameters
|
||||
external_camera: str | None = (
|
||||
None # which external camera should be fed to the policy, choose from ["left", "right"]
|
||||
)
|
||||
|
||||
# Rollout parameters
|
||||
max_timesteps: int = 600
|
||||
# How many actions to execute from a predicted action chunk before querying policy server again
|
||||
# 8 is usually a good default (equals 0.5 seconds of action execution).
|
||||
open_loop_horizon: int = 8
|
||||
|
||||
# Remote server parameters
|
||||
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
|
||||
remote_port: int = (
|
||||
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
|
||||
)
|
||||
|
||||
|
||||
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
|
||||
# waiting for a new action chunk, it will raise an exception and the server connection dies.
|
||||
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
|
||||
@contextlib.contextmanager
|
||||
def prevent_keyboard_interrupt():
|
||||
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
|
||||
interrupted = False
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def handler(signum, frame):
|
||||
nonlocal interrupted
|
||||
interrupted = True
|
||||
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
if interrupted:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
# Make sure external camera is specified by user -- we only use one external camera for the policy
|
||||
assert (
|
||||
args.external_camera is not None and args.external_camera in ["left", "right"]
|
||||
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
|
||||
|
||||
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
|
||||
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
|
||||
print("Created the droid env!")
|
||||
|
||||
# Connect to the policy server
|
||||
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
|
||||
|
||||
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
|
||||
|
||||
while True:
|
||||
instruction = input("Enter instruction: ")
|
||||
|
||||
# Rollout parameters
|
||||
actions_from_chunk_completed = 0
|
||||
pred_action_chunk = None
|
||||
|
||||
# Prepare to save video of rollout
|
||||
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
|
||||
video = []
|
||||
bar = tqdm.tqdm(range(args.max_timesteps))
|
||||
print("Running rollout... press Ctrl+C to stop early.")
|
||||
for t_step in bar:
|
||||
try:
|
||||
# Get the current observation
|
||||
curr_obs = _extract_observation(
|
||||
args,
|
||||
env.get_observation(),
|
||||
# Save the first observation to disk
|
||||
save_to_disk=t_step == 0,
|
||||
)
|
||||
|
||||
video.append(curr_obs[f"{args.external_camera}_image"])
|
||||
|
||||
# Send websocket request to policy server if it's time to predict a new chunk
|
||||
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
|
||||
actions_from_chunk_completed = 0
|
||||
|
||||
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
|
||||
# and improve latency.
|
||||
request_data = {
|
||||
"observation/exterior_image_1_left": image_tools.resize_with_pad(
|
||||
curr_obs[f"{args.external_camera}_image"], 224, 224
|
||||
),
|
||||
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
|
||||
"observation/joint_position": curr_obs["joint_position"],
|
||||
"observation/gripper_position": curr_obs["gripper_position"],
|
||||
"prompt": instruction,
|
||||
}
|
||||
|
||||
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
|
||||
# Ctrl+C will be handled after the server call is complete
|
||||
with prevent_keyboard_interrupt():
|
||||
# The server returns an action chunk of shape [10, 8]: 10 timesteps, each with 7 joint velocities + 1 gripper position.
|
||||
pred_action_chunk = policy_client.infer(request_data)["actions"]
|
||||
assert pred_action_chunk.shape == (10, 8)
|
||||
|
||||
# Select current action to execute from chunk
|
||||
action = pred_action_chunk[actions_from_chunk_completed]
|
||||
actions_from_chunk_completed += 1
|
||||
|
||||
# Binarize gripper action
|
||||
if action[-1].item() > 0.5:
|
||||
# action[-1] = 1.0
|
||||
action = np.concatenate([action[:-1], np.ones((1,))])
|
||||
else:
|
||||
# action[-1] = 0.0
|
||||
action = np.concatenate([action[:-1], np.zeros((1,))])
|
||||
|
||||
# clip all dimensions of action to [-1, 1]
|
||||
action = np.clip(action, -1, 1)
|
||||
|
||||
env.step(action)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
video = np.stack(video)
|
||||
save_filename = "video_" + timestamp
|
||||
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
|
||||
|
||||
        success: str | float | None = None
        while not isinstance(success, float):
            success = input(
                "Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
            )
            if success == "y":
                success = 1.0
            elif success == "n":
                success = 0.0
            else:
                # Interpret any other input as a percentage in [0, 100].
                try:
                    success = float(success) / 100
                except ValueError:
                    print(f"Success must be y, n, or a number in [0, 100] but got: {success}")
                    success = None
                    continue
                if not (0 <= success <= 1):
                    print(f"Success must be a number in [0, 100] but got: {success * 100}")
                    success = None

        # `DataFrame.append` was removed in pandas 2.0; build a one-row frame and concatenate instead.
        df = pd.concat(
            [
                df,
                pd.DataFrame(
                    [
                        {
                            "success": success,
                            "duration": t_step,
                            "video_filename": save_filename,
                        }
                    ]
                ),
            ],
            ignore_index=True,
        )
|
||||
|
||||
if input("Do one more eval? (enter y or n) ").lower() != "y":
|
||||
break
|
||||
env.reset()
|
||||
|
||||
os.makedirs("results", exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
|
||||
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
|
||||
df.to_csv(csv_filename)
|
||||
print(f"Results saved to {csv_filename}")
|
||||
|
||||
|
||||
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
|
||||
image_observations = obs_dict["image"]
|
||||
left_image, right_image, wrist_image = None, None, None
|
||||
for key in image_observations:
|
||||
# Note the "left" below refers to the left camera in the stereo pair.
|
||||
# The model is only trained on left stereo cams, so we only feed those.
|
||||
if args.left_camera_id in key and "left" in key:
|
||||
left_image = image_observations[key]
|
||||
elif args.right_camera_id in key and "left" in key:
|
||||
right_image = image_observations[key]
|
||||
elif args.wrist_camera_id in key and "left" in key:
|
||||
wrist_image = image_observations[key]
|
||||
|
||||
# Drop the alpha dimension
|
||||
left_image = left_image[..., :3]
|
||||
right_image = right_image[..., :3]
|
||||
wrist_image = wrist_image[..., :3]
|
||||
|
||||
# Convert to RGB
|
||||
left_image = left_image[..., ::-1]
|
||||
right_image = right_image[..., ::-1]
|
||||
wrist_image = wrist_image[..., ::-1]
|
||||
|
||||
# In addition to image observations, also capture the proprioceptive state
|
||||
robot_state = obs_dict["robot_state"]
|
||||
cartesian_position = np.array(robot_state["cartesian_position"])
|
||||
joint_position = np.array(robot_state["joint_positions"])
|
||||
gripper_position = np.array([robot_state["gripper_position"]])
|
||||
|
||||
# Save the images to disk so that they can be viewed live while the robot is running
|
||||
# Create one combined image to make live viewing easy
|
||||
if save_to_disk:
|
||||
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
|
||||
combined_image = Image.fromarray(combined_image)
|
||||
combined_image.save("robot_camera_views.png")
|
||||
|
||||
return {
|
||||
"left_image": left_image,
|
||||
"right_image": right_image,
|
||||
"wrist_image": wrist_image,
|
||||
"cartesian_position": cartesian_position,
|
||||
"joint_position": joint_position,
|
||||
"gripper_position": gripper_position,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args: Args = tyro.cli(Args)
|
||||
main(args)
|
||||
137
examples/inference.ipynb
Normal file
137
examples/inference.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import dataclasses\n",
|
||||
"\n",
|
||||
"import jax\n",
|
||||
"\n",
|
||||
"from openpi.models import model as _model\n",
|
||||
"from openpi.policies import droid_policy\n",
|
||||
"from openpi.policies import policy_config as _policy_config\n",
|
||||
"from openpi.shared import download\n",
|
||||
"from openpi.training import config as _config\n",
|
||||
"from openpi.training import data_loader as _data_loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Policy inference\n",
|
||||
"\n",
|
||||
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_fast_droid\")\n",
|
||||
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
|
||||
"\n",
|
||||
"# Create a trained policy.\n",
|
||||
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
|
||||
"\n",
|
||||
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
|
||||
"example = droid_policy.make_droid_example()\n",
|
||||
"result = policy.infer(example)\n",
|
||||
"\n",
|
||||
"# Delete the policy to free up memory.\n",
|
||||
"del policy\n",
|
||||
"\n",
|
||||
"print(\"Actions shape:\", result[\"actions\"].shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Working with a live model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_aloha_sim\")\n",
|
||||
"\n",
|
||||
"checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
|
||||
"key = jax.random.key(0)\n",
|
||||
"\n",
|
||||
"# Create a model from the checkpoint.\n",
|
||||
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
|
||||
"\n",
|
||||
"# We can create fake observations and actions to test the model.\n",
|
||||
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reduce the batch size to reduce memory usage.\n",
|
||||
"config = dataclasses.replace(config, batch_size=2)\n",
|
||||
"\n",
|
||||
"# Load a single batch of data. This is the same data that will be used during training.\n",
|
||||
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
|
||||
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
|
||||
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
|
||||
"obs, act = next(iter(loader))\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"\n",
|
||||
"# Delete the model to free up memory.\n",
|
||||
"del model\n",
|
||||
"\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
59
examples/libero/Dockerfile
Normal file
59
examples/libero/Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# Dockerfile for the LIBERO benchmark.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t libero -f examples/libero/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
make \
|
||||
g++ \
|
||||
clang \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxrender1 \
|
||||
libxext6
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
|
||||
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
|
||||
|
||||
# Create a default config file to avoid an input prompt from LIBERO's init script.
|
||||
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
|
||||
ENV LIBERO_CONFIG_PATH=/tmp/libero
|
||||
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
|
||||
benchmark_root: /app/third_party/libero/libero/libero
|
||||
bddl_files: /app/third_party/libero/libero/libero/bddl_files
|
||||
init_states: /app/third_party/libero/libero/libero/init_files
|
||||
datasets: /app/third_party/libero/libero/datasets
|
||||
assets: /app/third_party/libero/libero/libero/assets
|
||||
EOF
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py"]
|
||||
56
examples/libero/README.md
Normal file
56
examples/libero/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# LIBERO Benchmark
|
||||
|
||||
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
|
||||
|
||||
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
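For reference, here is a sketch of the full compile command; the base invocation mirrors the one recorded at the top of `requirements.txt`, with the extra index flag added:

```bash
uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt \
    --python-version 3.8 \
    --extra-index-url https://download.pytorch.org/whl/cu113 \
    --index-strategy=unsafe-best-match
```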
|
||||
|
||||
This example requires git submodules to be initialized. Don't forget to run:
|
||||
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
# Grant access to the X11 server:
|
||||
sudo xhost +local:docker
|
||||
|
||||
export SERVER_ARGS="--env LIBERO"
|
||||
docker compose -f examples/libero/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.8 examples/libero/.venv
|
||||
source examples/libero/.venv/bin/activate
|
||||
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
uv pip install -e packages/openpi-client
|
||||
uv pip install -e third_party/libero
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
|
||||
|
||||
# Run the simulation
|
||||
python examples/libero/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env LIBERO
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
If you follow the training instructions and hyperparameters in the `pi0_libero` and `pi0_fast_libero` configs, you should get results similar to the following:
|
||||
|
||||
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|
||||
|-------|---------------|---------------|-------------|-----------|---------|
|
||||
| π0-FAST @ 30k (finetuned) | 96.4 | 96.8 | 88.6 | 60.2 | 85.5 |
|
||||
| π0 @ 30k (finetuned) | 96.8 | 98.8 | 95.8 | 85.2 | 94.15 |
|
||||
|
||||
Note that the hyperparameters for these runs are not tuned and $\pi_0$-FAST does not use a FAST tokenizer optimized for Libero. The results could likely be improved with more tuning; we mainly provide them as an example of how to use openpi to fine-tune $\pi_0$ models on a new dataset.
|
||||
52
examples/libero/compose.yml
Normal file
52
examples/libero/compose.yml
Normal file
@@ -0,0 +1,52 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/libero/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: libero
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/libero/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
- /tmp/.X11-unix:/tmp/.X11-unix:ro
|
||||
environment:
|
||||
- DISPLAY=$DISPLAY
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
106
examples/libero/convert_libero_data_to_lerobot.py
Normal file
106
examples/libero/convert_libero_data_to_lerobot.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Minimal example script for converting a dataset to LeRobot format.
|
||||
|
||||
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
|
||||
modified for any other data you have saved in a custom format.
|
||||
|
||||
Usage:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
|
||||
|
||||
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
||||
|
||||
Note: to run the script, you need to install tensorflow_datasets:
|
||||
`uv pip install tensorflow tensorflow_datasets`
|
||||
|
||||
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
|
||||
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
||||
Running this conversion script will take approximately 30 minutes.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import tensorflow_datasets as tfds
|
||||
import tyro
|
||||
|
||||
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
|
||||
RAW_DATASET_NAMES = [
|
||||
"libero_10_no_noops",
|
||||
"libero_goal_no_noops",
|
||||
"libero_object_no_noops",
|
||||
"libero_spatial_no_noops",
|
||||
] # For simplicity we will combine multiple Libero datasets into one training dataset
|
||||
|
||||
|
||||
def main(data_dir: str, *, push_to_hub: bool = False):
|
||||
# Clean up any existing dataset in the output directory
|
||||
output_path = LEROBOT_HOME / REPO_NAME
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
# Create LeRobot dataset, define features to store
|
||||
# OpenPi assumes that proprio is stored in `state` and actions in `action`
|
||||
# LeRobot assumes that dtype of image data is `image`
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_NAME,
|
||||
robot_type="panda",
|
||||
fps=10,
|
||||
features={
|
||||
"image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"wrist_image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": ["state"],
|
||||
},
|
||||
"actions": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": ["actions"],
|
||||
},
|
||||
},
|
||||
image_writer_threads=10,
|
||||
image_writer_processes=5,
|
||||
)
|
||||
|
||||
# Loop over raw Libero datasets and write episodes to the LeRobot dataset
|
||||
# You can modify this for your own data format
|
||||
for raw_dataset_name in RAW_DATASET_NAMES:
|
||||
raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
|
||||
for episode in raw_dataset:
|
||||
for step in episode["steps"].as_numpy_iterator():
|
||||
dataset.add_frame(
|
||||
{
|
||||
"image": step["observation"]["image"],
|
||||
"wrist_image": step["observation"]["wrist_image"],
|
||||
"state": step["observation"]["state"],
|
||||
"actions": step["action"],
|
||||
}
|
||||
)
|
||||
dataset.save_episode(task=step["language_instruction"].decode())
|
||||
|
||||
# Consolidate the dataset, skip computing stats since we will do that later
|
||||
dataset.consolidate(run_compute_stats=False)
|
||||
|
||||
# Optionally push to the Hugging Face Hub
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(
|
||||
tags=["libero", "panda", "rlds"],
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
219
examples/libero/main.py
Normal file
219
examples/libero/main.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
from libero.libero import benchmark
|
||||
from libero.libero import get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
|
||||
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
#################################################################################################################
|
||||
# Model server parameters
|
||||
#################################################################################################################
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
resize_size: int = 224
|
||||
replan_steps: int = 5
|
||||
|
||||
#################################################################################################################
|
||||
# LIBERO environment-specific parameters
|
||||
#################################################################################################################
|
||||
task_suite_name: str = (
|
||||
"libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
|
||||
)
|
||||
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize in sim
|
||||
num_trials_per_task: int = 50 # Number of rollouts per task
|
||||
|
||||
#################################################################################################################
|
||||
# Utils
|
||||
#################################################################################################################
|
||||
video_out_path: str = "data/libero/videos" # Path to save videos
|
||||
|
||||
seed: int = 7 # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def eval_libero(args: Args) -> None:
|
||||
# Set random seed
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Initialize LIBERO task suite
|
||||
benchmark_dict = benchmark.get_benchmark_dict()
|
||||
task_suite = benchmark_dict[args.task_suite_name]()
|
||||
num_tasks_in_suite = task_suite.n_tasks
|
||||
logging.info(f"Task suite: {args.task_suite_name}")
|
||||
|
||||
pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.task_suite_name == "libero_spatial":
|
||||
max_steps = 220 # longest training demo has 193 steps
|
||||
elif args.task_suite_name == "libero_object":
|
||||
max_steps = 280 # longest training demo has 254 steps
|
||||
elif args.task_suite_name == "libero_goal":
|
||||
max_steps = 300 # longest training demo has 270 steps
|
||||
elif args.task_suite_name == "libero_10":
|
||||
max_steps = 520 # longest training demo has 505 steps
|
||||
elif args.task_suite_name == "libero_90":
|
||||
max_steps = 400 # longest training demo has 373 steps
|
||||
else:
|
||||
raise ValueError(f"Unknown task suite: {args.task_suite_name}")
|
||||
|
||||
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
|
||||
|
||||
# Start evaluation
|
||||
total_episodes, total_successes = 0, 0
|
||||
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
|
||||
# Get task
|
||||
task = task_suite.get_task(task_id)
|
||||
|
||||
# Get default LIBERO initial states
|
||||
initial_states = task_suite.get_task_init_states(task_id)
|
||||
|
||||
# Initialize LIBERO environment and task description
|
||||
env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
|
||||
|
||||
# Start episodes
|
||||
task_episodes, task_successes = 0, 0
|
||||
for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
|
||||
logging.info(f"\nTask: {task_description}")
|
||||
|
||||
# Reset environment
|
||||
env.reset()
|
||||
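            # Queue of not-yet-executed actions from the most recent policy chunk.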
action_plan = collections.deque()
|
||||
|
||||
# Set initial states
|
||||
obs = env.set_init_state(initial_states[episode_idx])
|
||||
|
||||
# Setup
|
||||
t = 0
|
||||
replay_images = []
|
||||
|
||||
logging.info(f"Starting episode {task_episodes+1}...")
|
||||
while t < max_steps + args.num_steps_wait:
|
||||
try:
|
||||
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
|
||||
# and we need to wait for them to fall
|
||||
if t < args.num_steps_wait:
|
||||
obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# Get preprocessed image
|
||||
# IMPORTANT: rotate 180 degrees to match train preprocessing
|
||||
img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
|
||||
wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
|
||||
)
|
||||
wrist_img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
|
||||
)
|
||||
|
||||
# Save preprocessed image for replay video
|
||||
replay_images.append(img)
|
||||
|
||||
if not action_plan:
|
||||
# Finished executing previous action chunk -- compute new chunk
|
||||
# Prepare observations dict
|
||||
element = {
|
||||
"observation/image": img,
|
||||
"observation/wrist_image": wrist_img,
|
||||
"observation/state": np.concatenate(
|
||||
(
|
||||
obs["robot0_eef_pos"],
|
||||
_quat2axisangle(obs["robot0_eef_quat"]),
|
||||
obs["robot0_gripper_qpos"],
|
||||
)
|
||||
),
|
||||
"prompt": str(task_description),
|
||||
}
|
||||
|
||||
# Query model to get action
|
||||
action_chunk = client.infer(element)["actions"]
|
||||
assert (
|
||||
len(action_chunk) >= args.replan_steps
|
||||
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
|
||||
action_plan.extend(action_chunk[: args.replan_steps])
|
||||
|
||||
action = action_plan.popleft()
|
||||
|
||||
# Execute action in environment
|
||||
obs, reward, done, info = env.step(action.tolist())
|
||||
if done:
|
||||
task_successes += 1
|
||||
total_successes += 1
|
||||
break
|
||||
t += 1
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Caught exception: {e}")
|
||||
break
|
||||
|
||||
task_episodes += 1
|
||||
total_episodes += 1
|
||||
|
||||
# Save a replay video of the episode
|
||||
suffix = "success" if done else "failure"
|
||||
task_segment = task_description.replace(" ", "_")
|
||||
imageio.mimwrite(
|
||||
pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
|
||||
[np.asarray(x) for x in replay_images],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
# Log current results
|
||||
logging.info(f"Success: {done}")
|
||||
logging.info(f"# episodes completed so far: {total_episodes}")
|
||||
logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
|
||||
|
||||
# Log final results
|
||||
logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
|
||||
logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
|
||||
logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
logging.info(f"Total episodes: {total_episodes}")
|
||||
|
||||
|
||||
def _get_libero_env(task, resolution, seed):
|
||||
"""Initializes and returns the LIBERO environment, along with the task description."""
|
||||
task_description = task.language
|
||||
task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
|
||||
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
|
||||
env = OffScreenRenderEnv(**env_args)
|
||||
env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
|
||||
return env, task_description
|
||||
|
||||
|
||||
def _quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
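    # quat is (x, y, z, w) with w = cos(theta/2); den = |sin(theta/2)|, so quat[:3] / den is the unit rotation axis.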
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(eval_libero)
|
||||
11
examples/libero/requirements.in
Normal file
11
examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
|
||||
imageio[ffmpeg]
|
||||
numpy==1.22.4
|
||||
tqdm
|
||||
tyro
|
||||
PyYaml
|
||||
opencv-python==4.6.0.66
|
||||
torch==1.11.0+cu113
|
||||
torchvision==0.12.0+cu113
|
||||
torchaudio==0.11.0+cu113
|
||||
robosuite==1.4.1
|
||||
matplotlib==3.5.3
|
||||
136
examples/libero/requirements.txt
Normal file
136
examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
|
||||
absl-py==2.1.0
|
||||
# via mujoco
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
eval-type-backport==0.2.0
|
||||
# via tyro
|
||||
evdev==1.7.1
|
||||
# via pynput
|
||||
fonttools==4.55.3
|
||||
# via matplotlib
|
||||
glfw==1.12.0
|
||||
# via mujoco
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.35.1
|
||||
# via -r examples/libero/requirements.in
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
importlib-metadata==8.5.0
|
||||
# via typeguard
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
llvmlite==0.36.0
|
||||
# via numba
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.5.3
|
||||
# via -r examples/libero/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mujoco==3.2.3
|
||||
# via robosuite
|
||||
numba==0.53.1
|
||||
# via robosuite
|
||||
numpy==1.22.4
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# imageio
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# robosuite
|
||||
# scipy
|
||||
# torchvision
|
||||
opencv-python==4.6.0.66
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# robosuite
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
# robosuite
|
||||
# torchvision
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pynput==1.7.7
|
||||
# via robosuite
|
||||
pyopengl==3.1.7
|
||||
# via mujoco
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pyyaml==6.0.2
|
||||
# via -r examples/libero/requirements.in
|
||||
requests==2.32.3
|
||||
# via torchvision
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
robosuite==1.4.1
|
||||
# via -r examples/libero/requirements.in
|
||||
scipy==1.10.1
|
||||
# via robosuite
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# imageio-ffmpeg
|
||||
# numba
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
termcolor==2.4.0
|
||||
# via robosuite
|
||||
torch==1.11.0+cu113
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# torchaudio
|
||||
# torchvision
|
||||
torchaudio==0.11.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
torchvision==0.12.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
tqdm==4.67.1
|
||||
# via -r examples/libero/requirements.in
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# torch
|
||||
# torchvision
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/libero/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
zipp==3.20.2
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
134
examples/policy_records.ipynb
Normal file
134
examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"record_path = pathlib.Path(\"../policy_records\")\n",
|
||||
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
|
||||
"\n",
|
||||
"records = []\n",
|
||||
"for i in range(num_steps):\n",
|
||||
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
|
||||
" records.append(record)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"length of records\", len(records))\n",
|
||||
"print(\"keys in records\", records[0].keys())\n",
|
||||
"\n",
|
||||
"for k in records[0]:\n",
|
||||
" print(f\"{k} shape: {records[0][k].shape}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_image(step: int, idx: int = 0):\n",
|
||||
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
|
||||
" return img[idx].transpose(1, 2, 0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def show_image(step: int, idx_lst: list[int]):\n",
|
||||
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
|
||||
" return Image.fromarray(np.hstack(imgs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(2):\n",
|
||||
" display(show_image(i, [0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_axis(name, axis):\n",
|
||||
" return np.array([record[name][axis] for record in records])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# qpos is [..., 14] of type float:\n",
|
||||
"# 0-5: left arm joint angles\n",
|
||||
"# 6: left arm gripper\n",
|
||||
"# 7-12: right arm joint angles\n",
|
||||
"# 13: right arm gripper\n",
|
||||
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_data():\n",
|
||||
" cur_dim = 0\n",
|
||||
" in_data = {}\n",
|
||||
" out_data = {}\n",
|
||||
" for name, dim_size in names:\n",
|
||||
" for i in range(dim_size):\n",
|
||||
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
|
||||
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
|
||||
" cur_dim += 1\n",
|
||||
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"in_data, out_data = make_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in in_data.columns:\n",
|
||||
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
|
||||
" data.plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
32
examples/simple_client/Dockerfile
Normal file
32
examples/simple_client/Dockerfile
Normal file
@@ -0,0 +1,32 @@
|
||||
# Dockerfile for the simple client.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t simple_client -f examples/simple_client/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
|
||||
|
||||
FROM python:3.7-slim
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.7 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
|
||||
30
examples/simple_client/README.md
Normal file
30
examples/simple_client/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Simple Client
|
||||
|
||||
A minimal client that sends observations to the server and prints the inference rate.
|
||||
|
||||
You can specify which runtime environment to use with the `--env` flag. You can see the available options by running:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --help
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/simple_client/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --env DROID
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env DROID
|
||||
```
|
||||
42
examples/simple_client/compose.yml
Normal file
42
examples/simple_client/compose.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/simple_client/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: simple_client
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/simple_client/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
89
examples/simple_client/main.py
Normal file
89
examples/simple_client/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tyro
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
num_steps: int = 10
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
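    # Pick the random-observation generator that matches the chosen environment.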
obs_fn = {
|
||||
EnvMode.ALOHA: _random_observation_aloha,
|
||||
EnvMode.ALOHA_SIM: _random_observation_aloha,
|
||||
EnvMode.DROID: _random_observation_droid,
|
||||
EnvMode.LIBERO: _random_observation_libero,
|
||||
}[args.env]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {policy.get_server_metadata()}")
|
||||
|
||||
# Send 1 observation to make sure the model is loaded.
|
||||
policy.infer(obs_fn())
|
||||
|
||||
start = time.time()
|
||||
for _ in range(args.num_steps):
|
||||
policy.infer(obs_fn())
|
||||
end = time.time()
|
||||
|
||||
print(f"Total time taken: {end - start:.2f} s")
|
||||
print(f"Average inference time: {1000 * (end - start) / args.num_steps:.2f} ms")
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"state": np.ones((14,)),
|
||||
"images": {
|
||||
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
},
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main(tyro.cli(Args))
|
||||
2
examples/simple_client/requirements.in
Normal file
2
examples/simple_client/requirements.in
Normal file
@@ -0,0 +1,2 @@
|
||||
numpy
|
||||
tyro
|
||||
27
examples/simple_client/requirements.txt
Normal file
27
examples/simple_client/requirements.txt
Normal file
@@ -0,0 +1,27 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.7
|
||||
backports-cached-property==1.0.2
|
||||
# via tyro
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
eval-type-backport==0.1.3
|
||||
# via tyro
|
||||
markdown-it-py==2.2.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.21.6
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.17.2
|
||||
# via rich
|
||||
rich==13.8.1
|
||||
# via tyro
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
typing-extensions==4.7.1
|
||||
# via
|
||||
# markdown-it-py
|
||||
# rich
|
||||
# tyro
|
||||
tyro==0.9.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
25
packages/openpi-client/pyproject.toml
Normal file
25
packages/openpi-client/pyproject.toml
Normal file
@@ -0,0 +1,25 @@
|
||||
[project]
|
||||
name = "openpi-client"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.7"
|
||||
dependencies = [
|
||||
"dm-tree>=0.1.8",
|
||||
"msgpack>=1.0.5",
|
||||
"numpy>=1.21.6",
|
||||
"pillow>=9.0.0",
|
||||
"tree>=0.2.4",
|
||||
"websockets>=11.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"pytest>=8.3.4",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py37"
|
||||
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
1
packages/openpi-client/src/openpi_client/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
|
||||
|
||||
class ActionChunkBroker(_base_policy.BasePolicy):
|
||||
"""Wraps a policy to return action chunks one-at-a-time.
|
||||
|
||||
Assumes that the first dimension of all action fields is the chunk size.
|
||||
|
||||
A new inference call to the inner policy is only made when the current
|
||||
list of chunks is exhausted.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
||||
self._policy = policy
|
||||
|
||||
self._action_horizon = action_horizon
|
||||
self._cur_step: int = 0
|
||||
|
||||
self._last_results: Dict[str, np.ndarray] | None = None
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
if self._last_results is None:
|
||||
self._last_results = self._policy.infer(obs)
|
||||
self._cur_step = 0
|
||||
|
||||
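        # Slice out the current step from every array in the cached action chunk.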
results = tree.map_structure(lambda x: x[self._cur_step, ...], self._last_results)
|
||||
self._cur_step += 1
|
||||
|
||||
if self._cur_step >= self._action_horizon:
|
||||
self._last_results = None
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
self._last_results = None
|
||||
self._cur_step = 0
|
||||
12
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
12
packages/openpi-client/src/openpi_client/base_policy.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BasePolicy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def infer(self, obs: Dict) -> Dict:
|
||||
"""Infer actions from observations."""
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy to its initial state."""
|
||||
pass
|
||||
58
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
58
packages/openpi-client/src/openpi_client/image_tools.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
||||
"""Converts an image to uint8 if it is a float image.
|
||||
|
||||
This is important for reducing the size of the image when sending it over the network.
|
||||
"""
|
||||
if np.issubdtype(img.dtype, np.floating):
|
||||
img = (255 * img).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
||||
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
||||
|
||||
Args:
|
||||
images: A batch of images in [..., height, width, channel] format.
|
||||
height: The target height of the image.
|
||||
width: The target width of the image.
|
||||
method: The interpolation method to use. Default is bilinear.
|
||||
|
||||
Returns:
|
||||
The resized images in [..., height, width, channel].
|
||||
"""
|
||||
# If the images are already the correct size, return them as is.
|
||||
if images.shape[-3:-1] == (height, width):
|
||||
return images
|
||||
|
||||
original_shape = images.shape
|
||||
|
||||
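    # Flatten any leading batch dimensions so each image can be resized individually with PIL.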
images = images.reshape(-1, *original_shape[-3:])
|
||||
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
||||
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
||||
|
||||
|
||||
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
||||
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
||||
width without distortion by padding with zeros.
|
||||
|
||||
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
||||
"""
|
||||
cur_width, cur_height = image.size
|
||||
if cur_width == width and cur_height == height:
|
||||
return image # No need to resize if the image is already the correct size.
|
||||
|
||||
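    # Scale by the larger ratio so the resized image fits entirely within the (width, height) target.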
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_image = image.resize((resized_width, resized_height), resample=method)
|
||||
|
||||
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
||||
pad_height = max(0, int((height - resized_height) / 2))
|
||||
pad_width = max(0, int((width - resized_width) / 2))
|
||||
zero_image.paste(resized_image, (pad_width, pad_height))
|
||||
assert zero_image.size == (width, height)
|
||||
return zero_image
|
||||
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
37
packages/openpi-client/src/openpi_client/image_tools_test.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi_client.image_tools as image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
|
||||
# Test case 1: Resize image with larger dimensions
|
||||
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
|
||||
height = 20
|
||||
width = 20
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (2, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 2: Resize image with smaller dimensions
|
||||
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
|
||||
height = 15
|
||||
width = 15
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (3, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with the same dimensions
|
||||
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
|
||||
height = 50
|
||||
width = 50
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with odd-numbered padding
|
||||
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
|
||||
height = 60
|
||||
width = 80
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
57
packages/openpi-client/src/openpi_client/msgpack_numpy.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Adds NumPy array support to msgpack.
|
||||
|
||||
msgpack is good for (de)serializing data over a network for multiple reasons:
|
||||
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
||||
- msgpack is widely used and has good cross-language support
|
||||
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
||||
languages like Python and JavaScript
|
||||
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
||||
than pickle for serializing large arrays using the below strategy
|
||||
|
||||
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
||||
that it falls back to pickle for object arrays.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
|
||||
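    # Refuse void ('V'), object ('O'), and complex ('c') dtypes; only plain numeric/bool/string arrays are supported.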
    if isinstance(obj, (np.ndarray, np.generic)) and obj.dtype.kind in ("V", "O", "c"):
|
||||
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
||||
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
b"__ndarray__": True,
|
||||
b"data": obj.tobytes(),
|
||||
b"dtype": obj.dtype.str,
|
||||
b"shape": obj.shape,
|
||||
}
|
||||
|
||||
if isinstance(obj, np.generic):
|
||||
return {
|
||||
b"__npgeneric__": True,
|
||||
b"data": obj.item(),
|
||||
b"dtype": obj.dtype.str,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def unpack_array(obj):
|
||||
if b"__ndarray__" in obj:
|
||||
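        # Reconstruct the array directly from the transmitted bytes, dtype string, and shape.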
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
||||
|
||||
if b"__npgeneric__" in obj:
|
||||
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
||||
packb = functools.partial(msgpack.packb, default=pack_array)
|
||||
|
||||
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
||||
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
||||
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tree
|
||||
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
def _check(expected, actual):
|
||||
if isinstance(expected, np.ndarray):
|
||||
assert expected.shape == actual.shape
|
||||
assert expected.dtype == actual.dtype
|
||||
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
|
||||
else:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
1, # int
|
||||
1.0, # float
|
||||
"hello", # string
|
||||
np.bool_(True), # boolean scalar
|
||||
np.array([1, 2, 3])[0], # int scalar
|
||||
np.str_("asdf"), # string scalar
|
||||
[1, 2, 3], # list
|
||||
{"key": "value"}, # dict
|
||||
{"key": [1, 2, 3]}, # nested dict
|
||||
np.array(1.0), # 0D array
|
||||
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
|
||||
np.array(["asdf", "qwer"]), # string array
|
||||
np.array([True, False]), # boolean array
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
|
||||
np.array([np.nan, np.inf, -np.inf]), # special float values
|
||||
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
|
||||
[np.array([1, 2]), np.array([3, 4])], # list of arrays
|
||||
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
|
||||
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
|
||||
],
|
||||
)
|
||||
def test_pack_unpack(data):
|
||||
packed = msgpack_numpy.packb(data)
|
||||
unpacked = msgpack_numpy.unpackb(packed)
|
||||
tree.map_structure(_check, data, unpacked)
|
||||
17
packages/openpi-client/src/openpi_client/runtime/agent.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
|
||||
|
||||
Agents receive observations about the state of the world, and return actions
|
||||
to take in response.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
"""Query the agent for the next action."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the agent to its initial state."""
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client.runtime import agent as _agent
|
||||
|
||||
|
||||
class PolicyAgent(_agent.Agent):
|
||||
"""An agent that uses a policy to determine actions."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy) -> None:
|
||||
self._policy = policy
|
||||
|
||||
@override
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
return self._policy.infer(observation)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
@@ -0,0 +1,32 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
"""An Environment represents the robot and the environment it inhabits.
|
||||
|
||||
The primary contract of environments is that they can be queried for observations
|
||||
about their state, and have actions applied to them to change that state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the environment to its initial state.
|
||||
|
||||
This will be called once before starting each episode.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_episode_complete(self) -> bool:
|
||||
"""Allow the environment to signal that the episode is complete.
|
||||
|
||||
This will be called after each step. It should return `True` if the episode is
|
||||
complete (either successfully or unsuccessfully), and `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict:
|
||||
"""Query the environment for the current state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_action(self, action: dict) -> None:
|
||||
"""Take an action in the environment."""
|
||||
92
packages/openpi-client/src/openpi_client/runtime/runtime.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The core module orchestrating interactions between key components of the system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: _environment.Environment,
|
||||
agent: _agent.Agent,
|
||||
subscribers: list[_subscriber.Subscriber],
|
||||
max_hz: float = 0,
|
||||
num_episodes: int = 1,
|
||||
max_episode_steps: int = 0,
|
||||
) -> None:
|
||||
self._environment = environment
|
||||
self._agent = agent
|
||||
self._subscribers = subscribers
|
||||
self._max_hz = max_hz
|
||||
self._num_episodes = num_episodes
|
||||
self._max_episode_steps = max_episode_steps
|
||||
|
||||
self._in_episode = False
|
||||
self._episode_steps = 0
|
||||
|
||||
def run(self) -> None:
|
||||
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
|
||||
for _ in range(self._num_episodes):
|
||||
self._run_episode()
|
||||
|
||||
# Final reset; this is important for real environments to move the robot to its home position.
|
||||
self._environment.reset()
|
||||
|
||||
def run_in_new_thread(self) -> threading.Thread:
|
||||
"""Runs the runtime loop in a new thread."""
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def mark_episode_complete(self) -> None:
|
||||
"""Marks the end of an episode."""
|
||||
self._in_episode = False
|
||||
|
||||
def _run_episode(self) -> None:
|
||||
"""Runs a single episode."""
|
||||
logging.info("Starting episode...")
|
||||
self._environment.reset()
|
||||
self._agent.reset()
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_start()
|
||||
|
||||
self._in_episode = True
|
||||
self._episode_steps = 0
|
||||
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
|
||||
last_step_time = time.time()
|
||||
|
||||
while self._in_episode:
|
||||
self._step()
|
||||
self._episode_steps += 1
|
||||
|
||||
# Sleep to maintain the desired frame rate
|
||||
now = time.time()
|
||||
dt = now - last_step_time
|
||||
if dt < step_time:
|
||||
time.sleep(step_time - dt)
|
||||
last_step_time = time.time()
|
||||
else:
|
||||
last_step_time = now
|
||||
|
||||
logging.info("Episode completed.")
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_end()
|
||||
|
||||
def _step(self) -> None:
|
||||
"""A single step of the runtime loop."""
|
||||
observation = self._environment.get_observation()
|
||||
action = self._agent.get_action(observation)
|
||||
self._environment.apply_action(action)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_step(observation, action)
|
||||
|
||||
if self._environment.is_episode_complete() or (
|
||||
self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
|
||||
):
|
||||
self.mark_episode_complete()
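A minimal end-to-end sketch of how Environment, Agent, and Runtime compose (both classes below are hypothetical stand-ins written only to illustrate the interfaces):

import numpy as np

from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import runtime as _runtime


class CountingEnvironment(_environment.Environment):
    """Hypothetical environment that finishes after 10 steps."""

    def __init__(self) -> None:
        self._steps = 0

    def reset(self) -> None:
        self._steps = 0

    def is_episode_complete(self) -> bool:
        return self._steps >= 10

    def get_observation(self) -> dict:
        return {"state": np.zeros(7, dtype=np.float32)}

    def apply_action(self, action: dict) -> None:
        self._steps += 1


class ZeroAgent(_agent.Agent):
    """Hypothetical agent that always returns a zero action."""

    def get_action(self, observation: dict) -> dict:
        return {"actions": np.zeros(7, dtype=np.float32)}

    def reset(self) -> None:
        pass


_runtime.Runtime(
    environment=CountingEnvironment(),
    agent=ZeroAgent(),
    subscribers=[],
    max_hz=10,
).run()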
|
||||
@@ -0,0 +1,20 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Subscriber(abc.ABC):
|
||||
"""Subscribes to events in the runtime.
|
||||
|
||||
Subscribers can be used to save data, visualize, etc.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_start(self) -> None:
|
||||
"""Called when an episode starts."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
"""Append a step to the episode."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_end(self) -> None:
|
||||
"""Called when an episode ends."""
|
||||
@@ -0,0 +1,49 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import websockets.sync.client
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
class WebsocketClientPolicy(_base_policy.BasePolicy):
|
||||
"""Implements the Policy interface by communicating with a server over websocket.
|
||||
|
||||
See WebsocketPolicyServer for a corresponding server implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000) -> None:
|
||||
self._uri = f"ws://{host}:{port}"
|
||||
self._packer = msgpack_numpy.Packer()
|
||||
self._ws, self._server_metadata = self._wait_for_server()
|
||||
|
||||
def get_server_metadata(self) -> Dict:
|
||||
return self._server_metadata
|
||||
|
||||
def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
|
||||
logging.info(f"Waiting for server at {self._uri}...")
|
||||
while True:
|
||||
try:
|
||||
conn = websockets.sync.client.connect(self._uri, compression=None, max_size=None)
|
||||
metadata = msgpack_numpy.unpackb(conn.recv())
|
||||
return conn, metadata
|
||||
except ConnectionRefusedError:
|
||||
logging.info("Still waiting for server...")
|
||||
time.sleep(5)
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
data = self._packer.pack(obs)
|
||||
self._ws.send(data)
|
||||
response = self._ws.recv()
|
||||
if isinstance(response, str):
|
||||
# we're expecting bytes; if the server sends a string, it's an error.
|
||||
raise RuntimeError(f"Error in inference server:\n{response}")
|
||||
return msgpack_numpy.unpackb(response)
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
pass
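Typical client-side usage, assuming a policy server is already listening. The module path and observation keys below are assumptions, since they depend on where this file lives and on the policy's input transforms:

import numpy as np

from openpi_client import websocket_client_policy  # assumed module path for the class above

client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
print(client.get_server_metadata())

# The exact observation keys depend on the policy being served; these are placeholders.
action = client.infer({"state": np.zeros(7, dtype=np.float32), "prompt": "pick up the block"})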
|
||||
130
pyproject.toml
Normal file
@@ -0,0 +1,130 @@
|
||||
[project]
|
||||
name = "openpi"
|
||||
version = "0.1.0"
|
||||
description = "Physical Intelligence open source repo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"augmax>=0.3.4",
|
||||
"dm-tree>=0.1.8",
|
||||
"einops>=0.8.0",
|
||||
"equinox>=0.11.8",
|
||||
"flatbuffers>=24.3.25",
|
||||
"flax==0.10.2",
|
||||
"fsspec[gcs]>=2024.6.0",
|
||||
"gym-aloha>=0.1.1",
|
||||
"imageio>=2.36.1",
|
||||
"jax[cuda12]==0.5.0",
|
||||
"jaxtyping==0.2.36",
|
||||
"lerobot",
|
||||
"ml_collections==1.0.0",
|
||||
"numpy>=1.26.4",
|
||||
"numpydantic>=1.6.6",
|
||||
"opencv-python>=4.10.0.84",
|
||||
"openpi-client",
|
||||
"orbax-checkpoint==0.11.1",
|
||||
"pillow>=11.0.0",
|
||||
"s3fs>=2024.9.0",
|
||||
"sentencepiece>=0.2.0",
|
||||
"torch>=2.5.1",
|
||||
"tqdm-loggable>=0.2",
|
||||
"typing-extensions>=4.12.2",
|
||||
"tyro>=0.9.5",
|
||||
"wandb>=0.19.1",
|
||||
"boto3>=1.35.7",
|
||||
"types-boto3[boto3,s3]>=1.35.7",
|
||||
"filelock>=3.16.1",
|
||||
"beartype>=0.19.0",
|
||||
"treescope>=0.1.7",
|
||||
"transformers==4.48.1",
|
||||
]
|
||||
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/Physical-Intelligence/openpi"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.3.4",
|
||||
"ruff>=0.8.6",
|
||||
"pre-commit>=4.0.1",
|
||||
"ipykernel>=6.29.5",
|
||||
"ipywidgets>=8.1.5",
|
||||
"matplotlib>=3.10.0",
|
||||
"pynvml>=12.0.0",
|
||||
]
|
||||
|
||||
|
||||
[tool.uv.sources]
|
||||
openpi-client = { workspace = true }
|
||||
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "6674e368249472c91382eb54bb8501c94c7f0c56" }
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/*"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py311"
|
||||
extend-exclude = ["docker", "third_party"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
# https://docs.astral.sh/ruff/rules/
|
||||
select = [
|
||||
"B",
|
||||
"C4",
|
||||
"DTZ",
|
||||
"E4",
|
||||
"E7",
|
||||
"E9",
|
||||
"F",
|
||||
"FBT",
|
||||
"FURB",
|
||||
"I",
|
||||
"ICN",
|
||||
"ISC",
|
||||
"LOG",
|
||||
"N",
|
||||
"PD",
|
||||
"PERF",
|
||||
"PIE",
|
||||
"PLC",
|
||||
"PLE",
|
||||
"PLR1",
|
||||
"PLR5",
|
||||
"PLW",
|
||||
"PT",
|
||||
"PTH",
|
||||
"Q",
|
||||
"RET",
|
||||
"RUF",
|
||||
"SIM",
|
||||
"SLF",
|
||||
"T10",
|
||||
"T20",
|
||||
"UP",
|
||||
"W",
|
||||
]
|
||||
ignore = [
|
||||
"F722", # Conflicts with array typing.
|
||||
"T201", # We use print statements.
|
||||
"PD008", # Lots of false positives.
|
||||
"ISC001", # Disabling to support ruff format.
|
||||
]
|
||||
unfixable = [
|
||||
"B905", # Fix defaults to strict=False, which is not what we want.
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
force-single-line = true
|
||||
force-sort-within-sections = true
|
||||
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
|
||||
known-third-party = ["wandb"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = ["manual: should be run manually."]
|
||||
testpaths = ["src", "scripts", "packages"]
|
||||
0
scripts/__init__.py
Normal file
75
scripts/compute_norm_stats.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Compute normalization statistics for a config.
|
||||
|
||||
This script is used to compute the normalization statistics for a given config. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config assets directory.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
|
||||
def __call__(self, x: dict) -> dict:
|
||||
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
|
||||
|
||||
|
||||
def create_dataset(config: _config.TrainConfig) -> tuple[_config.DataConfig, _data_loader.Dataset]:
|
||||
data_config = config.data.create(config.assets_dirs, config.model)
|
||||
if data_config.repo_id is None:
|
||||
raise ValueError("Data config must have a repo_id")
|
||||
dataset = _data_loader.create_dataset(data_config, config.model)
|
||||
dataset = _data_loader.TransformedDataset(
|
||||
dataset,
|
||||
[
|
||||
*data_config.repack_transforms.inputs,
|
||||
*data_config.data_transforms.inputs,
|
||||
# Remove strings since they are not supported by JAX and are not needed to compute norm stats.
|
||||
RemoveStrings(),
|
||||
],
|
||||
)
|
||||
return data_config, dataset
|
||||
|
||||
|
||||
def main(config_name: str, max_frames: int | None = None):
|
||||
config = _config.get_config(config_name)
|
||||
data_config, dataset = create_dataset(config)
|
||||
|
||||
num_frames = len(dataset)
|
||||
shuffle = False
|
||||
|
||||
if max_frames is not None and max_frames < num_frames:
|
||||
num_frames = max_frames
|
||||
shuffle = True
|
||||
|
||||
data_loader = _data_loader.TorchDataLoader(
|
||||
dataset,
|
||||
local_batch_size=1,
|
||||
num_workers=8,
|
||||
shuffle=shuffle,
|
||||
num_batches=num_frames,
|
||||
)
|
||||
|
||||
keys = ["state", "actions"]
|
||||
stats = {key: normalize.RunningStats() for key in keys}
|
||||
|
||||
for batch in tqdm.tqdm(data_loader, total=num_frames, desc="Computing stats"):
|
||||
for key in keys:
|
||||
values = np.asarray(batch[key][0])
|
||||
stats[key].update(values.reshape(-1, values.shape[-1]))
|
||||
|
||||
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
||||
|
||||
output_path = config.assets_dirs / data_config.repo_id
|
||||
print(f"Writing stats to: {output_path}")
|
||||
normalize.save(output_path, norm_stats)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
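Conceptually, the running-stats loop above reduces to a per-dimension mean and standard deviation over all frames. A plain-NumPy equivalent (sketch only, with stand-in data; the real RunningStats streams the values instead of materializing them):

import numpy as np

frames = np.random.rand(1000, 7).astype(np.float32)  # stand-in for all "state" vectors in the dataset
mean = frames.mean(axis=0)  # per-dimension mean
std = frames.std(axis=0)    # per-dimension standard deviation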
|
||||
29
scripts/docker/compose.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Run with:
|
||||
# docker compose -f scripts/docker/compose.yml up --build
|
||||
services:
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
# Populate configured openpi data home to /openpi_assets inside the container.
|
||||
# Populate aws credential inside the container.
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
37
scripts/docker/install_docker_ubuntu22.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Add Docker's official GPG key:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y ca-certificates curl
|
||||
sudo install -m 0755 -d /etc/apt/keyrings
|
||||
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
|
||||
sudo chmod a+r /etc/apt/keyrings/docker.asc
|
||||
|
||||
# Add the repository to Apt sources:
|
||||
echo \
|
||||
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
|
||||
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
|
||||
sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
|
||||
sudo apt-get update
|
||||
|
||||
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
|
||||
|
||||
# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
|
||||
# See https://docs.docker.com/engine/install/linux-postinstall/
|
||||
username=$(whoami)
|
||||
sudo usermod -aG docker $username
|
||||
|
||||
# Configure docker to start automatically on system boot.
|
||||
sudo systemctl enable docker.service
|
||||
sudo systemctl enable containerd.service
|
||||
|
||||
# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
|
||||
if [ -f ~/.docker/config.json ]; then
|
||||
sed -i 's/credsStore/credStore/g' ~/.docker/config.json
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "********************************************************************"
|
||||
echo "**** Restart to allow Docker permission changes to take effect. ****"
|
||||
echo "********************************************************************"
|
||||
echo ""
|
||||
17
scripts/docker/install_nvidia_container_toolkit.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
|
||||
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
|
||||
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
|
||||
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
|
||||
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
|
||||
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y nvidia-container-toolkit
|
||||
|
||||
sudo nvidia-ctk runtime configure --runtime=docker
|
||||
sudo systemctl restart docker
|
||||
34
scripts/docker/serve_policy.Dockerfile
Normal file
@@ -0,0 +1,34 @@
|
||||
# Dockerfile for serving a PI policy.
|
||||
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Needed because LeRobot uses git-lfs.
|
||||
RUN apt-get update && apt-get install -y git git-lfs
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Install the project's dependencies using the lockfile and settings
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
|
||||
GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
|
||||
|
||||
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
|
||||
122
scripts/serve_policy.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import tyro
|
||||
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.serving import websocket_policy_server
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Checkpoint:
|
||||
"""Load a policy from a trained checkpoint."""
|
||||
|
||||
# Training config name (e.g., "pi0_aloha_sim").
|
||||
config: str
|
||||
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
|
||||
dir: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Default:
|
||||
"""Use the default policy for the given environment."""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
"""Arguments for the serve_policy script."""
|
||||
|
||||
# Environment to serve the policy for. This is only used when serving default policies.
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
|
||||
# If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
|
||||
# prompt.
|
||||
default_prompt: str | None = None
|
||||
|
||||
# Port to serve the policy on.
|
||||
port: int = 8000
|
||||
# Record the policy's behavior for debugging.
|
||||
record: bool = False
|
||||
|
||||
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
|
||||
policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
|
||||
|
||||
|
||||
# Default checkpoints that should be used for each environment.
|
||||
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
|
||||
EnvMode.ALOHA: Checkpoint(
|
||||
config="pi0_aloha",
|
||||
dir="s3://openpi-assets/checkpoints/pi0_base",
|
||||
),
|
||||
EnvMode.ALOHA_SIM: Checkpoint(
|
||||
config="pi0_aloha_sim",
|
||||
dir="s3://openpi-assets/checkpoints/pi0_aloha_sim",
|
||||
),
|
||||
EnvMode.DROID: Checkpoint(
|
||||
config="pi0_fast_droid",
|
||||
dir="s3://openpi-assets/checkpoints/pi0_fast_droid",
|
||||
),
|
||||
EnvMode.LIBERO: Checkpoint(
|
||||
config="pi0_fast_libero",
|
||||
dir="s3://openpi-assets/checkpoints/pi0_fast_libero",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
|
||||
"""Create a default policy for the given environment."""
|
||||
if checkpoint := DEFAULT_CHECKPOINT.get(env):
|
||||
return _policy_config.create_trained_policy(
|
||||
_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
|
||||
)
|
||||
raise ValueError(f"Unsupported environment mode: {env}")
|
||||
|
||||
|
||||
def create_policy(args: Args) -> _policy.Policy:
|
||||
"""Create a policy from the given arguments."""
|
||||
match args.policy:
|
||||
case Checkpoint():
|
||||
return _policy_config.create_trained_policy(
|
||||
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
|
||||
)
|
||||
case Default():
|
||||
return create_default_policy(args.env, default_prompt=args.default_prompt)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
policy = create_policy(args)
|
||||
policy_metadata = policy.metadata
|
||||
|
||||
# Record the policy's behavior.
|
||||
if args.record:
|
||||
policy = _policy.PolicyRecorder(policy, "policy_records")
|
||||
|
||||
hostname = socket.gethostname()
|
||||
local_ip = socket.gethostbyname(hostname)
|
||||
logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
|
||||
|
||||
server = websocket_policy_server.WebsocketPolicyServer(
|
||||
policy=policy,
|
||||
host="0.0.0.0",
|
||||
port=args.port,
|
||||
metadata=policy_metadata,
|
||||
)
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
main(tyro.cli(Args))
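The same wiring can also be driven programmatically; a sketch using the dataclasses defined above (the checkpoint path follows the example given in the Checkpoint docstring and is hypothetical):

args = Args(
    policy=Checkpoint(config="pi0_aloha_sim", dir="checkpoints/pi0_aloha_sim/exp/10000"),
    port=8000,
)
policy = create_policy(args)  # loads the trained checkpoint via policy_config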
|
||||
274
scripts/train.py
Normal file
@@ -0,0 +1,274 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
import flax.nnx as nnx
|
||||
from flax.training import common_utils
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
|
||||
|
||||
def init_logging():
|
||||
"""Custom logging format for better readability."""
|
||||
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
record.levelname = level_mapping.get(record.levelname, record.levelname)
|
||||
return super().format(record)
|
||||
|
||||
formatter = CustomFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
|
||||
if not enabled:
|
||||
wandb.init(mode="disabled")
|
||||
return
|
||||
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
if not ckpt_dir.exists():
|
||||
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
|
||||
if resuming:
|
||||
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
|
||||
wandb.init(id=run_id, resume="must", project=config.project_name)
|
||||
else:
|
||||
wandb.init(
|
||||
name=config.exp_name,
|
||||
config=dataclasses.asdict(config),
|
||||
project=config.project_name,
|
||||
)
|
||||
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
|
||||
|
||||
if log_code:
|
||||
wandb.run.log_code(epath.Path(__file__).parent.parent)
|
||||
|
||||
|
||||
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
|
||||
"""Loads and validates the weights. Returns a loaded subset of the weights."""
|
||||
loaded_params = loader.load(params_shape)
|
||||
at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
|
||||
|
||||
# Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
|
||||
return traverse_util.unflatten_dict(
|
||||
{k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
|
||||
)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def init_train_state(
|
||||
config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
|
||||
) -> tuple[training_utils.TrainState, Any]:
|
||||
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
|
||||
|
||||
def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
|
||||
rng, model_rng = jax.random.split(rng)
|
||||
# initialize the model (and its parameters).
|
||||
model = config.model.create(model_rng)
|
||||
|
||||
# Merge the partial params into the model.
|
||||
if partial_params is not None:
|
||||
graphdef, state = nnx.split(model)
|
||||
# This will produce an error if the partial params are not a subset of the state.
|
||||
state.replace_by_pure_dict(partial_params)
|
||||
model = nnx.merge(graphdef, state)
|
||||
|
||||
params = nnx.state(model)
|
||||
# Convert frozen params to bfloat16.
|
||||
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
|
||||
|
||||
return training_utils.TrainState(
|
||||
step=0,
|
||||
params=params,
|
||||
model_def=nnx.graphdef(model),
|
||||
tx=tx,
|
||||
opt_state=tx.init(params.filter(config.trainable_filter)),
|
||||
ema_decay=config.ema_decay,
|
||||
ema_params=None if config.ema_decay is None else params,
|
||||
)
|
||||
|
||||
train_state_shape = jax.eval_shape(init, init_rng)
|
||||
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
|
||||
|
||||
if resume:
|
||||
return train_state_shape, state_sharding
|
||||
|
||||
partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
# Initialize the train state and mix in the partial params.
|
||||
train_state = jax.jit(
|
||||
init,
|
||||
donate_argnums=(1,), # donate the partial params buffer.
|
||||
in_shardings=replicated_sharding,
|
||||
out_shardings=state_sharding,
|
||||
)(init_rng, partial_params)
|
||||
|
||||
return train_state, state_sharding
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def train_step(
|
||||
config: _config.TrainConfig,
|
||||
rng: at.KeyArrayLike,
|
||||
state: training_utils.TrainState,
|
||||
batch: tuple[_model.Observation, _model.Actions],
|
||||
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
|
||||
model = nnx.merge(state.model_def, state.params)
|
||||
model.train()
|
||||
|
||||
@at.typecheck
|
||||
def loss_fn(
|
||||
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
|
||||
):
|
||||
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
|
||||
return jnp.mean(chunked_loss)
|
||||
|
||||
train_rng = jax.random.fold_in(rng, state.step)
|
||||
observation, actions = batch
|
||||
|
||||
# Filter out frozen params.
|
||||
diff_state = nnx.DiffState(0, config.trainable_filter)
|
||||
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
|
||||
|
||||
params = state.params.filter(config.trainable_filter)
|
||||
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
|
||||
new_params = optax.apply_updates(params, updates)
|
||||
|
||||
# Update the model in place and return the new full state.
|
||||
nnx.update(model, new_params)
|
||||
new_params = nnx.state(model)
|
||||
|
||||
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
|
||||
if state.ema_decay is not None:
|
||||
new_state = dataclasses.replace(
|
||||
new_state,
|
||||
ema_params=jax.tree.map(
|
||||
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
|
||||
),
|
||||
)
|
||||
|
||||
# Filter out params that aren't kernels.
|
||||
kernel_params = nnx.state(
|
||||
model,
|
||||
nnx.All(
|
||||
nnx.Param,
|
||||
nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
|
||||
lambda _, x: x.value.ndim > 1,
|
||||
),
|
||||
)
|
||||
info = {
|
||||
"loss": loss,
|
||||
"grad_norm": optax.global_norm(grads),
|
||||
"param_norm": optax.global_norm(kernel_params),
|
||||
}
|
||||
return new_state, info
|
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
|
||||
init_logging()
|
||||
logging.info(f"Running on: {platform.node()}")
|
||||
|
||||
if config.batch_size % jax.device_count() != 0:
|
||||
raise ValueError(
|
||||
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
|
||||
)
|
||||
|
||||
jax.config.update("jax_threefry_partitionable", True) # noqa: FBT003
|
||||
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
|
||||
|
||||
rng = jax.random.key(config.seed)
|
||||
train_rng, init_rng = jax.random.split(rng)
|
||||
|
||||
mesh = sharding.make_mesh(config.fsdp_devices)
|
||||
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
|
||||
config.checkpoint_dir,
|
||||
keep_period=config.keep_period,
|
||||
overwrite=config.overwrite,
|
||||
resume=config.resume,
|
||||
)
|
||||
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
|
||||
|
||||
data_loader = _data_loader.create_data_loader(
|
||||
config,
|
||||
sharding=data_sharding,
|
||||
num_workers=config.num_workers,
|
||||
shuffle=True,
|
||||
)
|
||||
data_iter = iter(data_loader)
|
||||
batch = next(data_iter)
|
||||
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
|
||||
|
||||
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
|
||||
jax.block_until_ready(train_state)
|
||||
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
|
||||
|
||||
if resuming:
|
||||
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
|
||||
|
||||
ptrain_step = jax.jit(
|
||||
functools.partial(train_step, config),
|
||||
in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
|
||||
out_shardings=(train_state_sharding, replicated_sharding),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
start_step = int(train_state.step)
|
||||
pbar = tqdm.tqdm(
|
||||
range(start_step, config.num_train_steps),
|
||||
initial=start_step,
|
||||
total=config.num_train_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
infos = []
|
||||
for step in pbar:
|
||||
with sharding.set_mesh(mesh):
|
||||
train_state, info = ptrain_step(train_rng, train_state, batch)
|
||||
infos.append(info)
|
||||
if step % config.log_interval == 0:
|
||||
stacked_infos = common_utils.stack_forest(infos)
|
||||
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
|
||||
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
|
||||
pbar.write(f"Step {step}: {info_str}")
|
||||
wandb.log(reduced_info, step=step)
|
||||
infos = []
|
||||
batch = next(data_iter)
|
||||
|
||||
if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
|
||||
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
|
||||
|
||||
logging.info("Waiting for checkpoint manager to finish")
|
||||
checkpoint_manager.wait_until_finished()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(_config.cli())
|
||||
30
scripts/train_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
from openpi.training import config as _config
|
||||
|
||||
from . import train
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["debug"])
|
||||
def test_train(tmp_path: pathlib.Path, config_name: str):
|
||||
config = dataclasses.replace(
|
||||
_config._CONFIGS_DICT[config_name], # noqa: SLF001
|
||||
batch_size=2,
|
||||
checkpoint_base_dir=tmp_path / "checkpoint",
|
||||
exp_name="test",
|
||||
overwrite=False,
|
||||
resume=False,
|
||||
num_train_steps=2,
|
||||
log_interval=1,
|
||||
)
|
||||
train.main(config)
|
||||
|
||||
# test resuming
|
||||
config = dataclasses.replace(config, resume=True, num_train_steps=4)
|
||||
train.main(config)
|
||||
0
src/openpi/__init__.py
Normal file
17
src/openpi/conftest.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
|
||||
import pynvml
|
||||
import pytest
|
||||
|
||||
|
||||
def set_jax_cpu_backend_if_no_gpu() -> None:
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
pynvml.nvmlShutdown()
|
||||
except pynvml.NVMLError:
|
||||
# No GPU found.
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
set_jax_cpu_backend_if_no_gpu()
|
||||
0
src/openpi/models/__init__.py
Normal file
426
src/openpi/models/gemma.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gemma adaptation for Pi, taken from big_vision.
|
||||
|
||||
We follow this einsum axis naming convention:
|
||||
B: batch
|
||||
T: query length
|
||||
S: k/v length
|
||||
N: num query heads
|
||||
K: num k/v heads
|
||||
G: num query heads per k/v head
|
||||
H: head dim
|
||||
D: d_model ("features")
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257_152
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
Variant = Literal["dummy", "gemma_300m", "gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant: Variant) -> Config:
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "dummy":
|
||||
return Config(
|
||||
width=64,
|
||||
depth=4,
|
||||
mlp_dim=128,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=16,
|
||||
)
|
||||
if variant == "gemma_300m":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b_lora":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
|
||||
)
|
||||
if variant == "gemma_300m_lora":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class RMSNorm(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
|
||||
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
|
||||
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
|
||||
normed_inputs = normed_inputs * (
|
||||
1 + scale
|
||||
) # scale by learned parameter in float32 (matches Flax implementation)
|
||||
return normed_inputs.astype(dtype) # return in original dtype
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Embedder(nn.Module):
|
||||
"""Embedder module."""
|
||||
|
||||
vocab_size: int
|
||||
embed_dim: int
|
||||
|
||||
def setup(self):
|
||||
self.input_embedding_table = self.param(
|
||||
"input_embedding",
|
||||
nn.initializers.normal(),
|
||||
(self.vocab_size, self.embed_dim),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.input_embedding_table[(x,)]
|
||||
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Attention(nn.Module):
|
||||
"""Attention module."""
|
||||
|
||||
configs: Sequence[Config]
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, positions, attn_mask, kv_cache):
|
||||
# all experts must share the same head dim, num heads, and num kv heads for self-attention to work
|
||||
assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
|
||||
assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
|
||||
assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
|
||||
|
||||
dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
|
||||
|
||||
qkvs = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is None:
|
||||
continue
|
||||
if config.num_kv_heads == config.num_heads:
|
||||
qkv_einsum = lora.Einsum(
|
||||
shape=(3, config.num_heads, config.width, config.head_dim),
|
||||
name=_name("qkv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
|
||||
else:
|
||||
q_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.width, config.head_dim),
|
||||
name=_name("q_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
q = q_einsum("BTD,NDH->BTNH", x)
|
||||
kv_einsum = lora.Einsum(
|
||||
shape=(2, config.num_kv_heads, config.width, config.head_dim),
|
||||
name=_name("kv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
k, v = kv_einsum("BSD,2KDH->2BSKH", x)
|
||||
qkvs.append((q, k, v))
|
||||
|
||||
q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
|
||||
|
||||
q = _apply_rope(q, positions=positions)
|
||||
q *= self.configs[0].head_dim ** -0.5
|
||||
|
||||
k = _apply_rope(k, positions=positions)
|
||||
|
||||
# should still be half-precision here (if input was half-precision)
|
||||
assert q.dtype == k.dtype == v.dtype == dtype
|
||||
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
k = jnp.concatenate([cache_k, k], axis=1)
|
||||
v = jnp.concatenate([cache_v, v], axis=1)
|
||||
|
||||
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
|
||||
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
|
||||
|
||||
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
|
||||
raise ValueError(
|
||||
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
|
||||
)
|
||||
|
||||
# big_neg = jnp.finfo(logits.dtype).min
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
|
||||
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
|
||||
|
||||
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
|
||||
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
|
||||
|
||||
out = []
|
||||
start = 0
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
end = start + x.shape[1]
|
||||
out_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.head_dim, config.width),
|
||||
name=_name("attn_vec_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
|
||||
start = end
|
||||
else:
|
||||
out.append(None)
|
||||
|
||||
return out, (k, v)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
(2, self.features, self.hidden_dim),
|
||||
).astype(dtype)
|
||||
ff_gate = jnp.dot(x, w_gating[0])
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = jnp.dot(x, w_gating[1])
|
||||
activations = gate_value * ff1
|
||||
|
||||
w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
|
||||
(self.hidden_dim, self.features),
|
||||
).astype(dtype)
|
||||
outputs = jnp.dot(activations, w_linear)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
|
||||
configs: Sequence[Config]
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
|
||||
|
||||
attn = Attention(configs=self.configs, name="attn")
|
||||
|
||||
pre_attn = []
|
||||
for i, x in enumerate(xs):
|
||||
if x is not None:
|
||||
x = RMSNorm(name=_name("pre_attention_norm", i))(x) # noqa: PLW2901
|
||||
pre_attn.append(x)
|
||||
|
||||
pre_attn = sharding.activation_sharding_constraint(pre_attn)
|
||||
post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
|
||||
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
|
||||
post_attn = sharding.activation_sharding_constraint(post_attn)
|
||||
xs = jax.tree.map(lambda x, y: x + y, xs, post_attn)
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
out = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
x = RMSNorm(name=_name("pre_ffw_norm", i))(x) # noqa: PLW2901
|
||||
x = lora.FeedForward( # noqa: PLW2901
|
||||
features=config.width,
|
||||
hidden_dim=config.mlp_dim,
|
||||
name=_name("mlp", i),
|
||||
lora_config=config.lora_configs.get("ffn"),
|
||||
)(x)
|
||||
out.append(x)
|
||||
|
||||
out = sharding.activation_sharding_constraint(out)
|
||||
|
||||
out = jax.tree.map(lambda x: drop(x, deterministic), out)
|
||||
xs = jax.tree.map(lambda x, y: x + y, xs, out)
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
return xs, kv_cache
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""Transformer model, supporting a mixture of different weights for different tokens."""
|
||||
|
||||
configs: Sequence[Config] # list of configs, one for each expert
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
|
||||
def setup(self):
|
||||
# all experts must have the same depth
|
||||
assert all(config.depth == self.configs[0].depth for config in self.configs)
|
||||
|
||||
self.embedder = Embedder(
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
embed_dim=self.configs[0].width, # embedder for first expert only
|
||||
name="embedder",
|
||||
)
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=False,
|
||||
static_argnums=(5,), # 0=self, 5=deterministic
|
||||
policy=jax.checkpoint_policies.nothing_saveable,
|
||||
)
|
||||
self.layers = nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask, 3=decode
|
||||
length=self.configs[0].depth,
|
||||
)(
|
||||
configs=self.configs,
|
||||
dropout=self.dropout,
|
||||
dropout_bdims=self.dropout_bdims,
|
||||
)
|
||||
self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
|
||||
|
||||
@at.typecheck
|
||||
def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
|
||||
return self.embedder.encode(tokens).astype(self.embed_dtype)
|
||||
|
||||
@at.typecheck
|
||||
def __call__(
|
||||
self,
|
||||
# list of token arrays, one for each expert, or None if that expert should not be run
|
||||
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
|
||||
positions: at.Int[at.Array, "b t"],
|
||||
mask: at.Bool[at.Array, "b t s"],
|
||||
*,
|
||||
kv_cache: KVCache | None = None,
|
||||
deterministic: bool = True,
|
||||
) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
|
||||
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
|
||||
mask = jnp.asarray(mask)[:, None, :, :]
|
||||
|
||||
embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, deterministic)
|
||||
|
||||
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
|
||||
|
||||
return [f(e) if e is not None else e for f, e in zip(self.final_norms, embedded, strict=True)], kv_cache
|
||||
|
||||
def init(self):
|
||||
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
||||
self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
|
||||
self(
|
||||
[jnp.zeros((1, 1, c.width)) for c in self.configs],
|
||||
jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
|
||||
jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
|
||||
)
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
|
||||
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
|
||||
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
|
||||
# here.
|
||||
return res.astype(x.dtype)
|
||||
|
||||
|
||||
def _name(name, i):
|
||||
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
|
||||
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
|
||||
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
|
||||
# and the action expert.
|
||||
if i == 0:
|
||||
return name
|
||||
return f"{name}_{i}"
|
||||
427
src/openpi/models/gemma_fast.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
|
||||
Used for FAST autoregressive policies.
|
||||
"""
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_collections
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
Variant = Literal["gemma_2b"]
|
||||
|
||||
|
||||
def get_config(variant):
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "gemma_2b":
|
||||
return ml_collections.ConfigDict(
|
||||
{
|
||||
"variant": variant,
|
||||
"width": 2048,
|
||||
"depth": 18,
|
||||
"mlp_dim": 16_384,
|
||||
"num_heads": 8,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 256,
|
||||
"norm_eps": 1e-6,
|
||||
"vocab_size": 257_152,
|
||||
"scan": True,
|
||||
"remat_policy": "nothing_saveable",
|
||||
}
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Einsum(nn.Module):
|
||||
shape: tuple[int, ...]
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, eqn, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
|
||||
return jnp.einsum(eqn, x, w)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class RMSNorm(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
|
||||
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
|
||||
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
|
||||
normed_inputs = normed_inputs * (
|
||||
1 + scale
|
||||
) # scale by learned parameter in float32 (matches Flax implementation)
|
||||
return normed_inputs.astype(dtype) # return in original dtype
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Embedder(nn.Module):
|
||||
"""Embedder module."""
|
||||
|
||||
vocab_size: int
|
||||
embed_dim: int
|
||||
|
||||
def setup(self):
|
||||
self.input_embedding_table = self.param(
|
||||
"input_embedding",
|
||||
nn.initializers.zeros_init(),
|
||||
(self.vocab_size, self.embed_dim),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.input_embedding_table[(x,)]
|
||||
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Attention(nn.Module):
|
||||
"""Attention module."""
|
||||
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
features: int
|
||||
head_dim: int
|
||||
|
||||
cache_dtype: str | None = None
|
||||
|
||||
def setup(self):
|
||||
if self.num_kv_heads == self.num_heads:
|
||||
self.qkv_einsum = Einsum(
|
||||
shape=(3, self.num_heads, self.features, self.head_dim),
|
||||
)
|
||||
else:
|
||||
# MQA
|
||||
self.q_einsum = Einsum(
|
||||
shape=(self.num_heads, self.features, self.head_dim),
|
||||
)
|
||||
self.kv_einsum = Einsum(
|
||||
shape=(2, self.num_kv_heads, self.features, self.head_dim),
|
||||
)
|
||||
self.attn_vec_einsum = Einsum(
|
||||
shape=(self.num_heads, self.head_dim, self.features),
|
||||
)
|
||||
|
||||
def _init_cache(self, k, v, cache_size):
|
||||
"""Initialize KV cache"""
|
||||
prefill_len = k.shape[1]
|
||||
pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
|
||||
cache_dtype = self.cache_dtype or k.dtype
|
||||
k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
|
||||
v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
|
||||
idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
|
||||
return idx, k_cache, v_cache
|
||||
|
||||
def _update_cache(self, k, v, idx, k_cache, v_cache):
|
||||
"""Update KV cache with new values"""
|
||||
assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
|
||||
indices = (0, idx[0], 0, 0)
|
||||
cache_dtype = self.cache_dtype or k.dtype
|
||||
k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
|
||||
v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
|
||||
idx_new = idx + 1
|
||||
return idx_new, k_new, v_new
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
if self.num_kv_heads == self.num_heads:
|
||||
q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
|
||||
else:
|
||||
q = self.q_einsum("BTD,NDH->BTNH", x)
|
||||
k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)
|
||||
|
||||
q = _apply_rope(q, positions=positions) # promotes to float32
|
||||
q *= self.head_dim**-0.5
|
||||
|
||||
k = _apply_rope(k, positions=positions) # promotes to float32
|
||||
|
||||
if kv_cache is None:
|
||||
idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
|
||||
else:
|
||||
idx, k_cache, v_cache = kv_cache
|
||||
idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
|
||||
|
||||
k, v = k_cache, v_cache
|
||||
kv_cache = (idx, k_cache, v_cache)
|
||||
|
||||
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
|
||||
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
|
||||
|
||||
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
|
||||
raise ValueError(
|
||||
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
|
||||
)
|
||||
|
||||
# big_neg = jnp.finfo(logits.dtype).min
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
|
||||
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
|
||||
|
||||
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
|
||||
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
|
||||
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
||||
|
||||
|
||||
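# Shape sketch (illustrative, toy sizes) of the KV-cache logic above: _init_cache right-pads
# the prefill to the full cache size, and each decode step writes one token at position `idx`.
prefill_len, cache_size = 3, 8
k = jnp.zeros((2, prefill_len, 1, 256))                        # (batch, prefill, num_kv_heads, head_dim)
pad = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
k_cache = jnp.pad(k, pad)                                      # (2, 8, 1, 256)
idx = jnp.zeros((2,), dtype=jnp.int32) + prefill_len           # next write position == 3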
@at.typecheck
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.zeros_init(),
|
||||
((2, self.features, self.hidden_dim)),
|
||||
).astype(dtype)
|
||||
ff_gate = jnp.dot(x, w_gating[0])
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = jnp.dot(x, w_gating[1])
|
||||
activations = gate_value * ff1
|
||||
|
||||
w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.zeros_init(),
|
||||
(self.hidden_dim, self.features),
|
||||
).astype(dtype)
|
||||
outputs = jnp.dot(activations, w_linear)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
embed_dim: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
cache_dtype: str | None = None
|
||||
|
||||
def setup(self):
|
||||
self.pre_attention_norm = RMSNorm()
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
features=self.embed_dim,
|
||||
head_dim=self.head_dim,
|
||||
cache_dtype=self.cache_dtype,
|
||||
)
|
||||
self.pre_ffw_norm = RMSNorm()
|
||||
self.mlp = FeedForward(features=self.embed_dim, hidden_dim=self.hidden_dim)
|
||||
if self.dropout:
|
||||
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
|
||||
else:
|
||||
self.drop = lambda x, _: x
|
||||
|
||||
def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
|
||||
x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
|
||||
inputs_normalized = self.pre_attention_norm(x)
|
||||
attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
|
||||
attn_output = self.drop(attn_output, deterministic)
|
||||
attn_output += x
|
||||
residual = attn_output
|
||||
attn_output = self.pre_ffw_norm(attn_output)
|
||||
outputs = self.mlp(attn_output)
|
||||
outputs = self.drop(outputs, deterministic)
|
||||
outputs = residual + outputs
|
||||
return outputs, kv_cache
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""gemma model."""
|
||||
|
||||
variant: str
|
||||
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
cache_dtype: str | None = None
|
||||
|
||||
scan: bool = False
|
||||
remat_policy: str = "none"
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self,
|
||||
tokens=None,
|
||||
embedded_prefix=None,
|
||||
embed_only=False, # noqa: FBT002
|
||||
pre_logits=None,
|
||||
positions=None,
|
||||
mask=None,
|
||||
decode=False, # noqa: FBT002
|
||||
kv_cache=None,
|
||||
deterministic=True, # noqa: FBT002
|
||||
return_prelogits=False, # noqa: FBT002
|
||||
):
|
||||
"""Embed only, or complete forward pass.
|
||||
|
||||
Args:
|
||||
tokens: Embedded and then appended to `embedded_prefix`. Can be None.
|
||||
embedded_prefix: Optional prefix that is already embedded.
|
||||
embed_only: Whether to compute embeddings only.
|
||||
pre_logits: If present, computes logits from pre_logits and returns them.
|
||||
positions: Optional `[B, T]` specifying the absolute position of
|
||||
the tokens.
|
||||
mask: Optional attention mask `[B, T, S]`.
|
||||
decode: Whether to use kv-cache. Caller must pass masks and positions.
|
||||
deterministic: Forwarded to all dropout layers.
|
||||
return_prelogits: Whether to return the pre-logits.
|
||||
|
||||
Returns:
|
||||
If `embed_only=False`, then `(logits, out)` will be returned.
|
||||
If `embed_only=True`, then the embeddings will be returned.
|
||||
If `return_prelogits=True`, then the pre-logits will be returned.
|
||||
"""
|
||||
out = {}
|
||||
|
||||
embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
|
||||
|
||||
if pre_logits is not None:
|
||||
x = out["pre_logits"] = pre_logits
|
||||
logits = out["logits"] = embedder.decode(x)
|
||||
return logits, out
|
||||
|
||||
x = []
|
||||
if embedded_prefix is not None:
|
||||
x.append(embedded_prefix)
|
||||
if tokens is not None:
|
||||
x.append(embedder.encode(tokens))
|
||||
|
||||
x = jnp.concatenate(x, axis=-2)
|
||||
x = x.astype(self.embed_dtype)
|
||||
batch_size, seq_len, width = x.shape
|
||||
|
||||
if embed_only:
|
||||
return x
|
||||
|
||||
if decode:
|
||||
assert positions is not None and mask is not None, ( # noqa: PT018
|
||||
"Must explicitly pass positions and mask for decoding."
|
||||
)
|
||||
|
||||
if positions is None:
|
||||
positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
|
||||
assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
|
||||
|
||||
if mask is None:
|
||||
mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
|
||||
if mask.ndim == 3:
|
||||
mask = mask[:, None, :, :]
|
||||
cache_size = max(seq_len, mask.shape[-1])
|
||||
assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
|
||||
|
||||
if self.remat_policy == "none":
|
||||
block_cls = Block
|
||||
else:
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=not self.scan,
|
||||
static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
|
||||
policy=getattr(jax.checkpoint_policies, self.remat_policy),
|
||||
)
|
||||
|
||||
block_kw = {
|
||||
"num_heads": self.num_heads,
|
||||
"head_dim": self.head_dim,
|
||||
"num_kv_heads": self.num_kv_heads,
|
||||
"embed_dim": width,
|
||||
"hidden_dim": self.mlp_dim,
|
||||
"dropout": self.dropout,
|
||||
"dropout_bdims": self.dropout_bdims,
|
||||
"cache_dtype": self.cache_dtype,
|
||||
}
|
||||
layers = self.scope.push("layers")
|
||||
blocks = [
|
||||
nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask
|
||||
length=self.depth,
|
||||
)(parent=layers, **block_kw)
|
||||
]
|
||||
for block in blocks:
|
||||
x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
|
||||
|
||||
assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
|
||||
out["encoded"] = x
|
||||
|
||||
x = RMSNorm(name="final_norm")(x)
|
||||
out["pre_logits"] = x
|
||||
if return_prelogits:
|
||||
return x, kv_cache, out
|
||||
|
||||
x = embedder.decode(x)
|
||||
out["logits"] = x
|
||||
|
||||
return x, kv_cache, out
|
||||
|
||||
def init(self):
|
||||
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
||||
self(jnp.zeros((1, 1), dtype=jnp.int32))
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
return res
|
||||
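A small shape check (illustrative) of the `_apply_rope` helper above; unlike the gemma.py variant, this version returns float32 without downcasting back to the input dtype.

import jax.numpy as jnp

x = jnp.ones((1, 4, 2, 8), dtype=jnp.bfloat16)       # [B, L, H, D]
pos = jnp.arange(4, dtype=jnp.float32)[None, :]      # [B, L]
rot = _apply_rope(x, positions=pos)
assert rot.shape == x.shape and rot.dtype == jnp.float32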
148
src/openpi/models/lora.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import math
|
||||
import re
|
||||
|
||||
import flax.linen as nn
|
||||
import flax.struct as struct
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@struct.dataclass
|
||||
class LoRAConfig:
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
# LoRA rank.
|
||||
rank: int
|
||||
# LoRA scaling factor.
|
||||
alpha: float = 1.0
|
||||
# Initialization function for LoRA parameters.
|
||||
init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
|
||||
# Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
|
||||
rslora: bool = False
|
||||
# Axes in the weight to apply LoRA to. Should typically be the last two axes.
|
||||
axes: tuple[int, int] = (-2, -1)
|
||||
# Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
|
||||
label: str = "L"
|
||||
|
||||
@property
|
||||
def scaling_value(self) -> float:
|
||||
return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
|
||||
|
||||
|
||||
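# Numeric sketch (illustrative) of scaling_value: with the default alpha=1.0 and rank=4,
# standard LoRA scales the update by 1/4, while rank-stabilized LoRA scales by 1/sqrt(4).
assert LoRAConfig(rank=4).scaling_value == 0.25
assert LoRAConfig(rank=4, rslora=True).scaling_value == 0.5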
class Einsum(nn.Module):
|
||||
"""Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
|
||||
|
||||
# Shape of the weight.
|
||||
shape: tuple[int, ...]
|
||||
# Initialization function for the weight.
|
||||
init_fn: nn.initializers.Initializer = nn.initializers.zeros
|
||||
# If not None, apply LoRA to the weight.
|
||||
lora_config: LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
self.w = self.param("w", self.init_fn, self.shape)
|
||||
|
||||
if config := self.lora_config:
|
||||
# Setup LoRA parameters.
|
||||
shape_a, shape_b = list(self.shape), list(self.shape)
|
||||
shape_a[config.axes[1]] = config.rank
|
||||
shape_b[config.axes[0]] = config.rank
|
||||
self.w_a = self.param("lora_a", config.init_fn, shape_a)
|
||||
self.w_b = self.param("lora_b", config.init_fn, shape_b)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, eqn: str, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
result = jnp.einsum(eqn, x, self.w.astype(dtype))
|
||||
|
||||
if config := self.lora_config:
|
||||
eqn_a, eqn_b = self._make_lora_eqns(eqn)
|
||||
lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
|
||||
lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
|
||||
result = result + lora * config.scaling_value
|
||||
|
||||
return result
|
||||
|
||||
def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
|
||||
if "L" in eqn:
|
||||
raise ValueError(f"L already in eqn: {eqn}")
|
||||
if not (m := re.match("(.*),(.*)->(.*)", eqn)):
|
||||
raise ValueError(f"Unsupported einsum eqn: {eqn}")
|
||||
lhs, rhs, out = m.groups()
|
||||
|
||||
assert self.lora_config is not None
|
||||
a_label, b_label = (rhs[x] for x in self.lora_config.axes)
|
||||
label = self.lora_config.label
|
||||
|
||||
a_rhs = rhs.replace(b_label, label)
|
||||
a_out = out.replace(b_label, label)
|
||||
eqn_a = f"{lhs},{a_rhs}->{a_out}"
|
||||
|
||||
b_rhs = rhs.replace(a_label, label)
|
||||
eqn_b = f"{a_out},{b_rhs}->{out}"
|
||||
|
||||
return eqn_a, eqn_b
|
||||
|
||||
|
||||
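# Illustrative sketch of _make_lora_eqns with the Gemma QKV einsum used elsewhere in this
# repo: the original contraction is factored through the low-rank axis "L". (Assumes the
# helper can be called on an unbound module, which works since it only inspects the config.)
lora_einsum = Einsum(shape=(3, 8, 32, 4), lora_config=LoRAConfig(rank=2))
eqn_a, eqn_b = lora_einsum._make_lora_eqns("BSD,3KDH->3BSKH")
assert eqn_a == "BSD,3KDL->3BSKL"    # project the D axis down to rank L
assert eqn_b == "3BSKL,3KLH->3BSKH"  # project rank L back up to H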
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
# If not None, apply LoRA to the weight.
|
||||
lora_config: LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
self.w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
(2, self.features, self.hidden_dim),
|
||||
)
|
||||
self.w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
|
||||
(self.hidden_dim, self.features),
|
||||
)
|
||||
self.w_gating_lora = None
|
||||
self.w_linear_lora = None
|
||||
if self.lora_config:
|
||||
# Setup LoRA parameters.
|
||||
# TODO: follow up with a simplified init_fn api.
|
||||
self.w_gating_lora = (
|
||||
self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
|
||||
self.param(
|
||||
"gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
|
||||
),
|
||||
)
|
||||
self.w_linear_lora = (
|
||||
self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
|
||||
self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
|
||||
)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
ff_gate = self._dot(
|
||||
x,
|
||||
self.w_gating[0],
|
||||
None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
|
||||
)
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = self._dot(
|
||||
x,
|
||||
self.w_gating[1],
|
||||
None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
|
||||
)
|
||||
activations = gate_value * ff1
|
||||
|
||||
outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
|
||||
base = jnp.dot(x, w.astype(x.dtype))
|
||||
if lora_weights is None:
|
||||
return base
|
||||
return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
|
||||
94
src/openpi/models/lora_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
|
||||
|
||||
def test_lora_einsum_params_shape():
|
||||
shape = (3, 8, 32, 4) # (3KDH)
|
||||
einsum = lora.Einsum(shape)
|
||||
lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
|
||||
lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
|
||||
eqn = "BSD,3KDH->3BSKH"
|
||||
|
||||
# Ensure that lora parameters are not initialized when LoRA is not used.
|
||||
params = einsum.init(key, eqn, x)
|
||||
assert "lora_a" not in params["params"]
|
||||
assert "lora_b" not in params["params"]
|
||||
|
||||
# Check that default axes work.
|
||||
params_lora0 = lora0.init(key, eqn, x)
|
||||
assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
|
||||
assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
|
||||
|
||||
# Check that user provided axes work.
|
||||
params_lora1 = lora1.init(key, eqn, x)
|
||||
assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
|
||||
assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
|
||||
|
||||
|
||||
def test_lora_einsum_same_output():
|
||||
shape = (3, 8, 32, 4) # (3KDH)
|
||||
einsum = lora.Einsum(shape)
|
||||
einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
|
||||
eqn = "BSD,3KDH->3BSKH"
|
||||
|
||||
params = einsum.init(key, eqn, x)
|
||||
output = einsum.apply(params, eqn, x)
|
||||
|
||||
params_lora = einsum_lora.init(key, eqn, x)
|
||||
output_lora = einsum_lora.apply(params_lora, eqn, x)
|
||||
|
||||
# Results are the same since the LoRA parameters are initialized to zeros.
|
||||
assert jnp.allclose(output, output_lora)
|
||||
|
||||
|
||||
def test_lora_ffn_params_shape():
|
||||
ffn = lora.FeedForward(features=8, hidden_dim=32)
|
||||
ffn_lora = lora.FeedForward(
|
||||
features=8,
|
||||
hidden_dim=32,
|
||||
lora_config=lora.LoRAConfig(rank=2),
|
||||
)
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (2, 8))
|
||||
|
||||
params = ffn.init(key, x)
|
||||
assert params["params"]["gating_einsum"].shape == (2, 8, 32)
|
||||
assert params["params"]["linear"].shape == (32, 8)
|
||||
|
||||
params_lora = ffn_lora.init(key, x)
|
||||
assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
|
||||
assert params_lora["params"]["linear"].shape == (32, 8)
|
||||
assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
|
||||
assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
|
||||
assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
|
||||
assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
|
||||
|
||||
|
||||
def test_lora_ffn_same_output():
|
||||
ffn = lora.FeedForward(features=8, hidden_dim=32)
|
||||
ffn_lora = lora.FeedForward(
|
||||
features=8,
|
||||
hidden_dim=32,
|
||||
lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
|
||||
)
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (2, 8))
|
||||
|
||||
params = ffn.init(key, x)
|
||||
output = ffn.apply(params, x)
|
||||
|
||||
params_lora = ffn_lora.init(key, x)
|
||||
output_lora = ffn_lora.apply(params_lora, x)
|
||||
|
||||
assert jnp.allclose(output, output_lora)
|
||||
321
src/openpi/models/model.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import abc
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import augmax
|
||||
from flax import nnx
|
||||
from flax import struct
|
||||
from flax import traverse_util
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
|
||||
from openpi.shared import image_tools
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
ArrayT = TypeVar("ArrayT", at.Array, jax.ShapeDtypeStruct)
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
"""Supported model types."""
|
||||
|
||||
PI0 = "pi0"
|
||||
PI0_FAST = "pi0_fast"
|
||||
|
||||
|
||||
# The model always expects these images
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
|
||||
# This may need to change if we release a small model.
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
# Data format
|
||||
#
|
||||
# Data transforms produce the model input as a nested dictionary which is later converted
|
||||
# into `Observation` and `Actions` objects. See below.
|
||||
#
|
||||
# In the dictionary form, this data should look like:
|
||||
# {
|
||||
# # Observation data.
|
||||
# "image": {
|
||||
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
|
||||
# ... # Additional camera views
|
||||
# },
|
||||
# "image_mask": {
|
||||
# "base_0_rgb": bool[*b], # True if image is valid
|
||||
# ... # Masks for additional views
|
||||
# },
|
||||
# "state": float32[*b, s], # Low-dimensional robot state
|
||||
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
|
||||
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
|
||||
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
|
||||
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
|
||||
#
|
||||
# # Actions data.
|
||||
# "actions": float32[*b ah ad]
|
||||
# }
|
||||
# where:
|
||||
# *b = batch dimensions
|
||||
# h,w = image height/width
|
||||
# s = state dimension
|
||||
# l = sequence length
|
||||
#
|
||||
@at.typecheck
|
||||
@struct.dataclass
|
||||
class Observation(Generic[ArrayT]):
|
||||
"""Holds observations, i.e., inputs to the model.
|
||||
|
||||
See `Observation.from_dict` to see the expected dictionary form. This is the format
|
||||
that should be produced by the data transforms.
|
||||
"""
|
||||
|
||||
# Images, in [-1, 1] float32.
|
||||
images: dict[str, at.Float[ArrayT, "*b h w c"]]
|
||||
# Image masks, with same keys as images.
|
||||
image_masks: dict[str, at.Bool[ArrayT, "*b"]]
|
||||
# Low-dimensional robot state.
|
||||
state: at.Float[ArrayT, "*b s"]
|
||||
|
||||
# Tokenized prompt.
|
||||
tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
|
||||
# Tokenized prompt mask.
|
||||
tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
|
||||
|
||||
# pi0-fast model specific fields.
|
||||
|
||||
# Token auto-regressive mask (for FAST autoregressive model).
|
||||
token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
|
||||
# Token loss mask (for FAST autoregressive model).
|
||||
token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
|
||||
"""This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
|
||||
# Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
|
||||
if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
|
||||
raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
|
||||
# If images are uint8, convert them to [-1, 1] float32.
|
||||
for key in data["image"]:
|
||||
if data["image"][key].dtype == np.uint8:
|
||||
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
|
||||
return cls(
|
||||
images=data["image"],
|
||||
image_masks=data["image_mask"],
|
||||
state=data["state"],
|
||||
tokenized_prompt=data.get("tokenized_prompt"),
|
||||
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
|
||||
token_ar_mask=data.get("token_ar_mask"),
|
||||
token_loss_mask=data.get("token_loss_mask"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> at.PyTree[ArrayT]:
|
||||
"""Convert the Observation to a nested dict."""
|
||||
result = dataclasses.asdict(self)
|
||||
result["image"] = result.pop("images")
|
||||
result["image_mask"] = result.pop("image_masks")
|
||||
return result
|
||||
|
||||
|
||||
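# Minimal sketch (illustrative shapes; the state dimension is chosen arbitrarily) of the
# dictionary format documented above being converted into an Observation:
raw = {
    "image": {k: jnp.zeros((1, 224, 224, 3), dtype=jnp.uint8) for k in IMAGE_KEYS},
    "image_mask": {k: jnp.ones((1,), dtype=bool) for k in IMAGE_KEYS},
    "state": jnp.zeros((1, 32), dtype=jnp.float32),
}
obs = Observation.from_dict(raw)  # uint8 images are rescaled to float32 in [-1, 1]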
# Defines the format of the actions. This field is included as "actions" inside the dictionary
|
||||
# produced by the data transforms.
|
||||
Actions = at.Float[ArrayT, "*b ah ad"]
|
||||
|
||||
|
||||
def preprocess_observation(
|
||||
rng: at.KeyArrayLike | None,
|
||||
observation: Observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
) -> Observation:
|
||||
"""Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
|
||||
filling in a default image mask (if necessary).
|
||||
"""
|
||||
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for augmax.
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
transforms = []
|
||||
if "wrist" not in key:
|
||||
height, width = image.shape[1:3]
|
||||
transforms += [
|
||||
augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
|
||||
augmax.Resize(width, height),
|
||||
augmax.Rotate((-5, 5)),
|
||||
]
|
||||
transforms += [
|
||||
augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
|
||||
]
|
||||
sub_rngs = jax.random.split(rng, image.shape[0])
|
||||
image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
|
||||
|
||||
# Back to [-1, 1].
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
|
||||
else:
|
||||
out_masks[key] = jnp.asarray(observation.image_masks[key])
|
||||
|
||||
return Observation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BaseModelConfig(abc.ABC):
|
||||
"""Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
|
||||
method to create the corresponding model.
|
||||
"""
|
||||
|
||||
# Action space dimension.
|
||||
action_dim: int
|
||||
# Action sequence length.
|
||||
action_horizon: int
|
||||
# Tokenized prompt maximum length.
|
||||
max_token_len: int
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_type(self) -> ModelType:
|
||||
"""The model type."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create(self, rng: at.KeyArrayLike) -> "BaseModel":
|
||||
"""Create a new model, initializing parameters."""
|
||||
|
||||
def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
|
||||
"""Create a model with the given parameters."""
|
||||
model = nnx.eval_shape(self.create, jax.random.key(0))
|
||||
graphdef, state = nnx.split(model)
|
||||
if remove_extra_params:
|
||||
params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
|
||||
at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
|
||||
state.replace_by_pure_dict(params)
|
||||
return nnx.merge(graphdef, state)
|
||||
|
||||
@abc.abstractmethod
|
||||
def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
|
||||
"""Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
|
||||
|
||||
def fake_obs(self, batch_size: int = 1) -> Observation:
|
||||
observation_spec, _ = self.inputs_spec(batch_size=batch_size)
|
||||
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)
|
||||
|
||||
def fake_act(self, batch_size: int = 1) -> Actions:
|
||||
_, action_spec = self.inputs_spec(batch_size=batch_size)
|
||||
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseModel(nnx.Module, abc.ABC):
|
||||
"""Base class for all model implementations. Specific models should inherit from this class. They should call
|
||||
super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
|
||||
"""
|
||||
|
||||
action_dim: int
|
||||
action_horizon: int
|
||||
max_token_len: int
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_loss(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: Observation,
|
||||
actions: Actions,
|
||||
*,
|
||||
train: bool = False,
|
||||
) -> at.Float[at.Array, "*b ah"]: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample_actions(self, rng: at.KeyArrayLike, observation: Observation) -> Actions: ...
|
||||
|
||||
|
||||
def restore_params(
|
||||
params_path: pathlib.Path | str,
|
||||
*,
|
||||
restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
|
||||
dtype: jnp.dtype | None = None,
|
||||
sharding: jax.sharding.Sharding | None = None,
|
||||
) -> at.Params:
|
||||
"""Restores unstructured params PyTree from a checkpoint.
|
||||
|
||||
This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
|
||||
well as pre-trained checkpoints released for openpi.
|
||||
|
||||
Args:
|
||||
params_path: The local path to the checkpoint directory.
|
||||
restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
|
||||
dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
|
||||
sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
|
||||
|
||||
Returns:
|
||||
The restored params.
|
||||
"""
|
||||
params_path = pathlib.Path(params_path).resolve()
|
||||
if not params_path.exists():
|
||||
raise FileNotFoundError(f"Model params not found at: {params_path}")
|
||||
|
||||
if restore_type is jax.Array and sharding is None:
|
||||
mesh = jax.sharding.Mesh(jax.devices(), ("x",))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
with ocp.PyTreeCheckpointer() as ckptr:
|
||||
metadata = ckptr.metadata(params_path)
|
||||
item = {"params": metadata["params"]}
|
||||
|
||||
params = ckptr.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree.map(
|
||||
lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
|
||||
),
|
||||
),
|
||||
)["params"]
|
||||
|
||||
# If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
|
||||
# added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
if all(kp[-1] == "value" for kp in flat_params):
|
||||
flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
|
||||
return traverse_util.unflatten_dict(flat_params)
|
||||
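A toy sketch of the "value"-suffix handling described in the comment above: `nnx.State` nests every leaf under a trailing "value" key, which is stripped so callers always get a plain pure dict.

from flax import traverse_util

saved = {"encoder": {"kernel": {"value": 1.0}, "bias": {"value": 2.0}}}
flat = traverse_util.flatten_dict(saved)
flat = {kp[:-1]: v for kp, v in flat.items()}  # drop the trailing "value" from every key path
assert traverse_util.unflatten_dict(flat) == {"encoder": {"kernel": 1.0, "bias": 2.0}}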
72
src/openpi/models/model_test.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import jax
|
||||
import pytest
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.models import pi0
|
||||
from openpi.models import pi0_fast
|
||||
from openpi.shared import download
|
||||
from openpi.shared import nnx_utils
|
||||
|
||||
|
||||
def test_pi0_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0.Pi0Config()
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_lora_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_fast_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_fast.Pi0FASTConfig()
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size,)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
|
||||
assert actions.shape == (batch_size, 256)
|
||||
|
||||
|
||||
@pytest.mark.manual
|
||||
def test_model_restore():
|
||||
key = jax.random.key(0)
|
||||
config = pi0.Pi0Config()
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
model = config.load(
|
||||
_model.restore_params(download.maybe_download("s3://openpi-assets/checkpoints/pi0_base/params"))
|
||||
)
|
||||
|
||||
loss = model.compute_loss(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = model.sample_actions(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
325
src/openpi/models/pi0.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
import einops
|
||||
import flax.nnx as nnx
|
||||
import flax.nnx.bridge as nnx_bridge
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import model as _model
|
||||
import openpi.models.gemma as _gemma
|
||||
import openpi.models.siglip as _siglip
|
||||
from openpi.shared import array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
|
||||
def make_attn_mask(input_mask, mask_ar):
|
||||
"""Adapted from big_vision.
|
||||
|
||||
Tokens can attend to valid input tokens which have a cumulative mask_ar
|
||||
smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
|
||||
set up several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if it's part of the input, false if padding.
|
||||
mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
|
||||
it and false where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
|
||||
cumsum = jnp.cumsum(mask_ar, axis=1)
|
||||
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
|
||||
return jnp.logical_and(attn_mask, valid_mask)
|
||||
|
||||
|
||||
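# Worked example (illustrative) of the prefix-lm case from the docstring above: the first
# three tokens attend bidirectionally among themselves, the last three are causal.
input_mask = jnp.ones((1, 6), dtype=bool)
mask_ar = jnp.array([[0, 0, 0, 1, 1, 1]])
attn = make_attn_mask(input_mask, mask_ar)     # bool[1, 6, 6]
assert bool(attn[0, 0, 2])                     # prefix token 0 can see prefix token 2
assert not bool(attn[0, 3, 4])                 # suffix token 3 cannot see future token 4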
@at.typecheck
|
||||
def posemb_sincos(
|
||||
pos: at.Real[at.Array, " b"], embedding_dim: int, min_period: float, max_period: float
|
||||
) -> at.Float[at.Array, "b {embedding_dim}"]:
|
||||
"""Computes sine-cosine positional embedding vectors for scalar positions."""
|
||||
if embedding_dim % 2 != 0:
|
||||
raise ValueError(f"embedding_dim ({embedding_dim}) must be divisible by 2")
|
||||
|
||||
fraction = jnp.linspace(0.0, 1.0, embedding_dim // 2)
|
||||
period = min_period * (max_period / min_period) ** fraction
|
||||
sinusoid_input = jnp.einsum(
|
||||
"i,j->ij",
|
||||
pos,
|
||||
1.0 / period * 2 * jnp.pi,
|
||||
precision=jax.lax.Precision.HIGHEST,
|
||||
)
|
||||
return jnp.concatenate([jnp.sin(sinusoid_input), jnp.cos(sinusoid_input)], axis=-1)
|
||||
|
||||
|
||||
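# Quick shape check (illustrative) of the timestep embedding used for the action-expert
# suffix tokens further below.
t = jnp.array([0.0, 0.5, 1.0])
emb = posemb_sincos(t, embedding_dim=64, min_period=4e-3, max_period=4.0)
assert emb.shape == (3, 64)  # one sin and one cos feature per frequency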
@dataclasses.dataclass(frozen=True)
|
||||
class Pi0Config(_model.BaseModelConfig):
|
||||
dtype: str = "bfloat16"
|
||||
paligemma_variant: _gemma.Variant = "gemma_2b"
|
||||
action_expert_variant: _gemma.Variant = "gemma_300m"
|
||||
|
||||
# Set the model specific defaults.
|
||||
action_dim: int = 32
|
||||
action_horizon: int = 50
|
||||
max_token_len: int = 48
|
||||
|
||||
@property
|
||||
@override
|
||||
def model_type(self) -> _model.ModelType:
|
||||
return _model.ModelType.PI0
|
||||
|
||||
@override
|
||||
def create(self, rng: at.KeyArrayLike) -> "Pi0":
|
||||
return Pi0(self, rngs=nnx.Rngs(rng))
|
||||
|
||||
@override
|
||||
def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
|
||||
image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
|
||||
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
|
||||
|
||||
with at.disable_typechecking():
|
||||
observation_spec = _model.Observation(
|
||||
images={
|
||||
"base_0_rgb": image_spec,
|
||||
"left_wrist_0_rgb": image_spec,
|
||||
"right_wrist_0_rgb": image_spec,
|
||||
},
|
||||
image_masks={
|
||||
"base_0_rgb": image_mask_spec,
|
||||
"left_wrist_0_rgb": image_mask_spec,
|
||||
"right_wrist_0_rgb": image_mask_spec,
|
||||
},
|
||||
state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
|
||||
tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
|
||||
tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
|
||||
)
|
||||
action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
|
||||
|
||||
return observation_spec, action_spec
|
||||
|
||||
def get_freeze_filter(self) -> nnx.filterlib.Filter:
|
||||
"""Returns the freeze filter based on the model config."""
|
||||
filters = []
|
||||
has_lora = False
|
||||
gemma_params_filter = nnx_utils.PathRegex(".*llm.*")
|
||||
action_expert_params_filter = nnx_utils.PathRegex(".*llm.*_1.*")
|
||||
if "lora" in self.paligemma_variant:
|
||||
filters.append(
|
||||
gemma_params_filter,
|
||||
)
|
||||
if "lora" not in self.action_expert_variant:
|
||||
# If we only freeze the Gemma params, exclude the action expert params.
|
||||
filters.append(
|
||||
nnx.Not(action_expert_params_filter),
|
||||
)
|
||||
has_lora = True
|
||||
elif "lora" in self.action_expert_variant:
|
||||
filters.append(
|
||||
action_expert_params_filter,
|
||||
)
|
||||
has_lora = True
|
||||
|
||||
if has_lora:
|
||||
# If any lora is used, exclude all lora params.
|
||||
filters.append(
|
||||
nnx.Not(nnx_utils.PathRegex(".*lora.*")),
|
||||
)
|
||||
if not filters:
|
||||
return nnx.Nothing
|
||||
return nnx.All(*filters)
|
||||
|
||||
|
||||
class Pi0(_model.BaseModel):
|
||||
def __init__(self, config: Pi0Config, rngs: nnx.Rngs):
|
||||
super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
action_expert_config = _gemma.get_config(config.action_expert_variant)
|
||||
# TODO: rewrite gemma in NNX. For now, use bridge.
|
||||
llm = nnx_bridge.ToNNX(
|
||||
_gemma.Module(
|
||||
configs=[paligemma_config, action_expert_config],
|
||||
embed_dtype=config.dtype,
|
||||
)
|
||||
)
|
||||
llm.lazy_init(rngs=rngs, method="init")
|
||||
img = nnx_bridge.ToNNX(
|
||||
_siglip.Module(
|
||||
num_classes=paligemma_config.width,
|
||||
variant="So400m/14",
|
||||
pool_type="none",
|
||||
scan=True,
|
||||
dtype_mm=config.dtype,
|
||||
)
|
||||
)
|
||||
img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
|
||||
self.PaliGemma = nnx.Dict(llm=llm, img=img)
|
||||
self.state_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
|
||||
self.action_in_proj = nnx.Linear(config.action_dim, action_expert_config.width, rngs=rngs)
|
||||
self.action_time_mlp_in = nnx.Linear(2 * action_expert_config.width, action_expert_config.width, rngs=rngs)
|
||||
self.action_time_mlp_out = nnx.Linear(action_expert_config.width, action_expert_config.width, rngs=rngs)
|
||||
self.action_out_proj = nnx.Linear(action_expert_config.width, config.action_dim, rngs=rngs)
|
||||
|
||||
@at.typecheck
|
||||
def embed_prefix(
|
||||
self, obs: _model.Observation
|
||||
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
|
||||
input_mask = []
|
||||
ar_mask = []
|
||||
tokens = []
|
||||
# embed images
|
||||
for name in obs.images:
|
||||
image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False)
|
||||
|
||||
tokens.append(image_tokens)
|
||||
input_mask.append(
|
||||
einops.repeat(
|
||||
obs.image_masks[name],
|
||||
"b -> b s",
|
||||
s=image_tokens.shape[1],
|
||||
)
|
||||
)
|
||||
# image tokens attend to each other
|
||||
ar_mask += [False] * image_tokens.shape[1]
|
||||
|
||||
# add language (aka tokenized inputs)
|
||||
if obs.tokenized_prompt is not None:
|
||||
tokenized_inputs = self.PaliGemma.llm(obs.tokenized_prompt, method="embed")
|
||||
tokens.append(tokenized_inputs)
|
||||
input_mask.append(obs.tokenized_prompt_mask)
|
||||
# full attention between image and language inputs
|
||||
ar_mask += [False] * tokenized_inputs.shape[1]
|
||||
tokens = jnp.concatenate(tokens, axis=1)
|
||||
input_mask = jnp.concatenate(input_mask, axis=1)
|
||||
ar_mask = jnp.array(ar_mask)
|
||||
return tokens, input_mask, ar_mask
|
||||
|
||||
@at.typecheck
|
||||
def embed_suffix(
|
||||
self, obs: _model.Observation, noisy_actions: _model.Actions, timestep: at.Float[at.Array, " b"]
|
||||
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Bool[at.Array, " s"]]:
|
||||
input_mask = []
|
||||
ar_mask = []
|
||||
tokens = []
|
||||
# add a single state token
|
||||
state_token = self.state_proj(obs.state)[:, None, :]
|
||||
tokens.append(state_token)
|
||||
input_mask.append(jnp.ones((obs.state.shape[0], 1), dtype=jnp.bool_))
|
||||
# image/language inputs do not attend to state or actions
|
||||
ar_mask += [True]
|
||||
|
||||
# embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
|
||||
time_emb = posemb_sincos(timestep, self.action_in_proj.out_features, min_period=4e-3, max_period=4.0)
|
||||
# mix timestep + action information using an MLP
|
||||
action_tokens = self.action_in_proj(noisy_actions)
|
||||
time_tokens = einops.repeat(time_emb, "b emb -> b s emb", s=self.action_horizon)
|
||||
action_time_tokens = jnp.concatenate([action_tokens, time_tokens], axis=-1)
|
||||
action_time_tokens = self.action_time_mlp_in(action_time_tokens)
|
||||
action_time_tokens = nnx.swish(action_time_tokens)
|
||||
action_time_tokens = self.action_time_mlp_out(action_time_tokens)
|
||||
tokens.append(action_time_tokens)
|
||||
input_mask.append(jnp.ones(action_time_tokens.shape[:2], dtype=jnp.bool_))
|
||||
# image/language/state inputs do not attend to action tokens
|
||||
ar_mask += [True] + ([False] * (self.action_horizon - 1))
|
||||
tokens = jnp.concatenate(tokens, axis=1)
|
||||
input_mask = jnp.concatenate(input_mask, axis=1)
|
||||
ar_mask = jnp.array(ar_mask)
|
||||
return tokens, input_mask, ar_mask
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
|
||||
) -> at.Float[at.Array, "*b ah"]:
|
||||
preprocess_rng, noise_rng, time_rng = jax.random.split(rng, 3)
|
||||
observation = _model.preprocess_observation(preprocess_rng, observation, train=train)
|
||||
|
||||
batch_shape = actions.shape[:-2]
|
||||
noise = jax.random.normal(noise_rng, actions.shape)
|
||||
time = jax.random.beta(time_rng, 1.5, 1, batch_shape) * 0.999 + 0.001
|
||||
time_expanded = time[..., None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
u_t = noise - actions
|
||||
|
||||
# one big forward pass of prefix + suffix at once
|
||||
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
|
||||
suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(observation, x_t, time)
|
||||
input_mask = jnp.concatenate([prefix_mask, suffix_mask], axis=1)
|
||||
ar_mask = jnp.concatenate([prefix_ar_mask, suffix_ar_mask], axis=0)
|
||||
attn_mask = make_attn_mask(input_mask, ar_mask)
|
||||
positions = jnp.cumsum(input_mask, axis=1) - 1
|
||||
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
|
||||
[prefix_tokens, suffix_tokens], mask=attn_mask, positions=positions
|
||||
)
|
||||
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
|
||||
|
||||
return jnp.mean(jnp.square(v_t - u_t), axis=-1)
|
||||
|
||||
@override
|
||||
def sample_actions(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: _model.Observation,
|
||||
*,
|
||||
num_steps: int | at.Int[at.Array, ""] = 10,
|
||||
) -> _model.Actions:
|
||||
observation = _model.preprocess_observation(None, observation, train=False)
|
||||
# note that we use the convention more common in diffusion literature, where t=1 is noise and t=0 is the target
|
||||
# distribution. yes, this is the opposite of the pi0 paper, and I'm sorry.
|
||||
dt = -1.0 / num_steps
|
||||
batch_size = observation.state.shape[0]
|
||||
noise = jax.random.normal(rng, (batch_size, self.action_horizon, self.action_dim))
|
||||
|
||||
# first fill KV cache with a forward pass of the prefix
|
||||
prefix_tokens, prefix_mask, prefix_ar_mask = self.embed_prefix(observation)
|
||||
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
|
||||
positions = jnp.cumsum(prefix_mask, axis=1) - 1
|
||||
_, kv_cache = self.PaliGemma.llm([prefix_tokens, None], mask=prefix_attn_mask, positions=positions)
|
||||
|
||||
def step(carry):
|
||||
x_t, time = carry
|
||||
suffix_tokens, suffix_mask, suffix_ar_mask = self.embed_suffix(
|
||||
observation, x_t, jnp.broadcast_to(time, batch_size)
|
||||
)
|
||||
# `suffix_attn_mask` is shape (b, suffix_len, suffix_len) indicating how the suffix tokens can attend to each
|
||||
# other
|
||||
suffix_attn_mask = make_attn_mask(suffix_mask, suffix_ar_mask)
|
||||
# `prefix_attn_mask` is shape (b, suffix_len, prefix_len) indicating how the suffix tokens can attend to the
|
||||
# prefix tokens
|
||||
prefix_attn_mask = einops.repeat(prefix_mask, "b p -> b s p", s=suffix_tokens.shape[1])
|
||||
# `combined_mask` is shape (b, suffix_len, prefix_len + suffix_len) indicating how the suffix tokens (which
|
||||
# generate the queries) can attend to the full prefix + suffix sequence (which generates the keys and values)
|
||||
full_attn_mask = jnp.concatenate([prefix_attn_mask, suffix_attn_mask], axis=-1)
|
||||
assert full_attn_mask.shape == (
|
||||
batch_size,
|
||||
suffix_tokens.shape[1],
|
||||
prefix_tokens.shape[1] + suffix_tokens.shape[1],
|
||||
)
|
||||
# `positions` is shape (b, suffix_len) indicating the positions of the suffix tokens
|
||||
positions = jnp.sum(prefix_mask, axis=-1)[:, None] + jnp.cumsum(suffix_mask, axis=-1) - 1
|
||||
|
||||
(prefix_out, suffix_out), _ = self.PaliGemma.llm(
|
||||
[None, suffix_tokens], mask=full_attn_mask, positions=positions, kv_cache=kv_cache
|
||||
)
|
||||
assert prefix_out is None
|
||||
v_t = self.action_out_proj(suffix_out[:, -self.action_horizon :])
|
||||
|
||||
return x_t + dt * v_t, time + dt
|
||||
|
||||
def cond(carry):
|
||||
x_t, time = carry
|
||||
# robust to floating-point error
|
||||
return time >= -dt / 2
|
||||
|
||||
x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0))
|
||||
return x_0
|
||||
295
src/openpi/models/pi0_fast.py
Normal file
@@ -0,0 +1,295 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
import einops
|
||||
import flax.nnx as nnx
|
||||
import flax.nnx.bridge as nnx_bridge
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi.models import model as _model
|
||||
import openpi.models.gemma_fast as _gemma
|
||||
import openpi.models.siglip as _siglip
|
||||
from openpi.shared import array_typing as at
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
PALIGEMMA_EOS_TOKEN = 1
|
||||
|
||||
|
||||
def make_attn_mask(input_mask, mask_ar):
|
||||
"""Adapted from big_vision.
|
||||
|
||||
Tokens can attend to valid input tokens which have a cumulative mask_ar
|
||||
smaller than or equal to theirs. This way `mask_ar` bool[?B, N] can be used to
|
||||
set up several types of attention, for example:
|
||||
|
||||
[[1 1 1 1 1 1]]: pure causal attention.
|
||||
|
||||
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
|
||||
themselves and the last 3 tokens have a causal attention. The first
|
||||
entry could also be a 1 without changing behaviour.
|
||||
|
||||
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
|
||||
block can attend all previous blocks and all tokens on the same block.
|
||||
|
||||
Args:
|
||||
input_mask: bool[B, N] true if it's part of the input, false if padding.
|
||||
mask_ar: bool[?B, N] mask that's true where previous tokens cannot depend on
|
||||
it and false where it shares the same attention mask as the previous token.
|
||||
"""
|
||||
mask_ar = jnp.broadcast_to(mask_ar, input_mask.shape)
|
||||
cumsum = jnp.cumsum(mask_ar, axis=1)
|
||||
attn_mask = cumsum[:, None, :] <= cumsum[:, :, None]
|
||||
valid_mask = input_mask[:, None, :] * input_mask[:, :, None]
|
||||
return jnp.logical_and(attn_mask, valid_mask)
|
||||
|
||||
|
||||
@jax.vmap
|
||||
def left_to_right_align(x, input_mask, attn_mask):
|
||||
"""Converts input from left-align to right-aligned."""
|
||||
# Due to vmap, this is operating in a single example (not batch level).
|
||||
assert x.ndim == 2
|
||||
assert input_mask.ndim == 1
|
||||
assert attn_mask.ndim == 2
|
||||
assert x.shape[0] == input_mask.shape[0]
|
||||
assert attn_mask.shape[0] == attn_mask.shape[1], attn_mask.shape
|
||||
seqlen = jnp.max(input_mask * jnp.arange(input_mask.shape[0])) + 1
|
||||
x = jnp.roll(x, -seqlen, axis=0)
|
||||
input_mask = jnp.roll(input_mask, -seqlen, axis=0)
|
||||
attn_mask = jnp.roll(attn_mask, -seqlen, axis=(0, 1))
|
||||
return x, input_mask, attn_mask
|
||||
|
||||
|
||||
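# Toy example (illustrative) of the right-alignment above: three valid tokens on the left
# are rolled to the right edge of the padded sequence.
x = jnp.arange(1.0, 6.0).reshape(1, 5, 1)          # batch of 1, seq len 5, embed dim 1
input_mask = jnp.array([[1, 1, 1, 0, 0]])          # 3 valid tokens, left-aligned
attn_mask = jnp.ones((1, 5, 5), dtype=bool)
_, mask_right, _ = left_to_right_align(x, input_mask, attn_mask)
# mask_right is now [[0, 0, 1, 1, 1]]: the valid tokens sit at the right edge.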
def put_along_last_axis(arr, indices, values):
|
||||
"""Like np.put_along_axis(..., axis=-1), since jax is missing it."""
|
||||
assert arr.ndim == indices.ndim == values.ndim, (arr.ndim, indices.ndim, values.ndim)
|
||||
onehot = jax.nn.one_hot(indices, arr.shape[-1], dtype=values.dtype)
|
||||
put_mask = jnp.einsum("...i,...in->...n", jnp.ones(values.shape, jnp.int32), onehot)
|
||||
put_values = jnp.einsum("...i,...in->...n", values, onehot)
|
||||
return jnp.where(put_mask, put_values, arr)
|
||||
|
||||
|
||||
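# Toy example (illustrative): write the value 7 at index 2 of the last axis.
arr = jnp.zeros((1, 5))
out = put_along_last_axis(arr, jnp.array([[2]]), jnp.array([[7.0]]))
# out == [[0., 0., 7., 0., 0.]]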
@dataclasses.dataclass(frozen=True)
|
||||
class Pi0FASTConfig(_model.BaseModelConfig):
|
||||
dtype: str = "bfloat16"
|
||||
paligemma_variant: _gemma.Variant = "gemma_2b"
|
||||
|
||||
# Set the model specific defaults.
|
||||
action_dim: int = 32
|
||||
action_horizon: int = 32
|
||||
max_token_len: int = 250
|
||||
|
||||
@property
|
||||
@override
|
||||
def model_type(self) -> _model.ModelType:
|
||||
return _model.ModelType.PI0_FAST
|
||||
|
||||
@override
|
||||
def create(self, rng: at.KeyArrayLike) -> "Pi0FAST":
|
||||
return Pi0FAST(self, rngs=nnx.Rngs(rng))
|
||||
|
||||
@override
|
||||
def inputs_spec(self, *, batch_size: int = 1) -> tuple[_model.Observation, _model.Actions]:
|
||||
image_spec = jax.ShapeDtypeStruct([batch_size, *_model.IMAGE_RESOLUTION, 3], jnp.float32)
|
||||
image_mask_spec = jax.ShapeDtypeStruct([batch_size], jnp.bool_)
|
||||
|
||||
with at.disable_typechecking():
|
||||
observation_spec = _model.Observation(
|
||||
images={
|
||||
"base_0_rgb": image_spec,
|
||||
"base_1_rgb": image_spec,
|
||||
"wrist_0_rgb": image_spec,
|
||||
},
|
||||
image_masks={
|
||||
"base_0_rgb": image_mask_spec,
|
||||
"base_1_rgb": image_mask_spec,
|
||||
"wrist_0_rgb": image_mask_spec,
|
||||
},
|
||||
state=jax.ShapeDtypeStruct([batch_size, self.action_dim], jnp.float32),
|
||||
tokenized_prompt=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
|
||||
tokenized_prompt_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], bool),
|
||||
token_ar_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.int32),
|
||||
token_loss_mask=jax.ShapeDtypeStruct([batch_size, self.max_token_len], jnp.bool_),
|
||||
)
|
||||
action_spec = jax.ShapeDtypeStruct([batch_size, self.action_horizon, self.action_dim], jnp.float32)
|
||||
|
||||
return observation_spec, action_spec
|
||||
|
||||
|
||||
class Pi0FAST(_model.BaseModel):
|
||||
def __init__(self, config: Pi0FASTConfig, rngs: nnx.Rngs):
|
||||
super().__init__(config.action_dim, config.action_horizon, config.max_token_len)
|
||||
paligemma_config = _gemma.get_config(config.paligemma_variant)
|
||||
# TODO: rewrite gemma in NNX. For now, use bridge.
|
||||
llm = nnx_bridge.ToNNX(
|
||||
_gemma.Module(
|
||||
**paligemma_config,
|
||||
embed_dtype=config.dtype,
|
||||
cache_dtype=config.dtype,
|
||||
)
|
||||
)
|
||||
llm.lazy_init(rngs=rngs, method="init")
|
||||
img = nnx_bridge.ToNNX(
|
||||
_siglip.Module(
|
||||
num_classes=paligemma_config.width,
|
||||
variant="So400m/14",
|
||||
pool_type="none",
|
||||
scan=True,
|
||||
dtype_mm=config.dtype,
|
||||
)
|
||||
)
|
||||
img.lazy_init(next(iter(config.fake_obs().images.values())), train=False, rngs=rngs)
|
||||
self.PaliGemma = nnx.Dict(llm=llm, img=img)
|
||||
|
||||
@at.typecheck
|
||||
def embed_inputs(
|
||||
self, obs: _model.Observation
|
||||
) -> tuple[at.Float[at.Array, "b s emb"], at.Bool[at.Array, "b s"], at.Int[at.Array, "b s"]]:
|
||||
input_mask = []
|
||||
ar_mask = []
|
||||
token_embeddings = []
|
||||
# embed images
|
||||
for name in obs.images:
|
||||
image_token_embeddings, _ = self.PaliGemma.img(obs.images[name], train=False)
|
||||
|
||||
token_embeddings.append(image_token_embeddings)
|
||||
input_mask.append(
|
||||
einops.repeat(
|
||||
obs.image_masks[name],
|
||||
"b -> b s",
|
||||
s=image_token_embeddings.shape[1],
|
||||
)
|
||||
)
|
||||
# image tokens attend to each other --> AR mask = 0
|
||||
ar_mask.append(0 * input_mask[-1])
|
||||
|
||||
# add tokenized inputs
|
||||
assert obs.tokenized_prompt is not None, "Tokenized prompt is required"
|
||||
assert obs.tokenized_prompt_mask is not None, "Tokenized prompt mask is required"
|
||||
assert obs.token_ar_mask is not None, "Token auto-regressive mask is required"
|
||||
tokenized_inputs_embeddings = self.PaliGemma.llm(obs.tokenized_prompt, embed_only=True)
|
||||
token_embeddings.append(tokenized_inputs_embeddings)
|
||||
input_mask.append(obs.tokenized_prompt_mask)
|
||||
ar_mask.append(obs.token_ar_mask)
|
||||
|
||||
# return embeddings, input mask, and ar mask
|
||||
return (
|
||||
jnp.concatenate(token_embeddings, axis=1),
|
||||
jnp.concatenate(input_mask, axis=1),
|
||||
jnp.concatenate(ar_mask, axis=1),
|
||||
)
|
||||
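# Editor's sketch of the attention pattern implied by the AR-mask convention above:
# positions with ar_mask == 0 (image and prompt prefix) attend to each other
# bidirectionally, while positions with ar_mask == 1 attend causally. This uses the
# standard cumsum construction of a prefix-LM mask; it is an assumption about what
# make_attn_mask (defined earlier in this file) produces, not a copy of it.
def _prefix_lm_mask_example():
    input_mask = jnp.array([[1, 1, 1, 1]])
    ar_mask = jnp.array([[0, 0, 1, 1]])  # two prefix tokens, two causal tokens
    cumsum = jnp.cumsum(ar_mask, axis=1)
    attn = cumsum[:, None, :] <= cumsum[:, :, None]
    attn = attn & (input_mask[:, None, :] & input_mask[:, :, None]).astype(bool)
    # attn[0] == [[1, 1, 0, 0],
    #             [1, 1, 0, 0],
    #             [1, 1, 1, 0],
    #             [1, 1, 1, 1]]
    return attn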
|
||||
@override
|
||||
def compute_loss(
|
||||
self, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions, *, train: bool = False
|
||||
) -> at.Float[at.Array, "*b ah"]:
|
||||
observation = _model.preprocess_observation(
|
||||
rng, observation, train=train, image_keys=list(observation.images.keys())
|
||||
)
|
||||
|
||||
# Compute inputs: one big forward pass of prefix + suffix at once
|
||||
input_token_embeddings, input_mask, ar_mask = self.embed_inputs(observation)
|
||||
attn_mask = make_attn_mask(input_mask, ar_mask)
|
||||
|
||||
# Compute one-hot targets: we predict *next* token, so shift the input tokens by one.
|
||||
targets = jax.nn.one_hot(
|
||||
observation.tokenized_prompt[:, 1:],
|
||||
self.PaliGemma.llm.module.vocab_size,
|
||||
)
|
||||
|
||||
# Each input predicts *next* token, so we don't input the last token.
|
||||
pre_logits, _, _ = self.PaliGemma.llm(
|
||||
embedded_prefix=input_token_embeddings[:, :-1],
|
||||
mask=attn_mask[:, :-1, :-1],
|
||||
return_prelogits=True,
|
||||
)
|
||||
|
||||
# Only decode logits for the target tokens to save memory
|
||||
# (decoding matmul is large because it is a seq_len x vocab_size dense layer).
|
||||
logits, _ = self.PaliGemma.llm(
|
||||
pre_logits=pre_logits[:, -targets.shape[1] :],
|
||||
)
|
||||
logp = jax.nn.log_softmax(logits, axis=-1)
|
||||
|
||||
# Compute CE loss on token targets
|
||||
assert observation.token_loss_mask is not None, "Token loss mask is required"
|
||||
loss_mask = observation.token_loss_mask[:, 1:]
|
||||
token_pplx = jnp.sum(targets * logp, axis=-1)
|
||||
return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)
|
||||
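# Editor's sketch: the loss above is ordinary next-token cross-entropy, averaged over
# the positions selected by token_loss_mask. Minimal numeric illustration with a
# 3-token vocabulary (batch dimension dropped for brevity):
def _masked_ce_example():
    logp = jax.nn.log_softmax(jnp.array([[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]]), axis=-1)
    targets = jax.nn.one_hot(jnp.array([0, 2]), 3)  # second prediction misses its target
    loss_mask = jnp.array([1.0, 0.0])               # ...but only the first position counts
    token_pplx = jnp.sum(targets * logp, axis=-1)
    return -jnp.sum(token_pplx * loss_mask, axis=-1) / jnp.clip(jnp.sum(loss_mask, -1), 1)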
|
||||
@override
|
||||
def sample_actions(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: _model.Observation,
|
||||
*,
|
||||
max_decoding_steps: int | at.Int[at.Array, ""] = 256,
|
||||
temperature: float = 0.0,
|
||||
) -> _model.Actions:
|
||||
# TODO: this is a hack to get the image keys.
|
||||
observation = _model.preprocess_observation(
|
||||
None, observation, train=False, image_keys=list(observation.images.keys())
|
||||
)
|
||||
|
||||
# embed inputs
|
||||
prefix_token_embeddings, prefix_mask, prefix_ar_mask = self.embed_inputs(observation)
|
||||
prefix_attn_mask = make_attn_mask(prefix_mask, prefix_ar_mask)
|
||||
|
||||
# left to right align all input token sequences
|
||||
prefix_token_embeddings, prefix_mask, prefix_attn_mask = left_to_right_align(
|
||||
prefix_token_embeddings, prefix_mask, prefix_attn_mask
|
||||
)
|
||||
prefill_size = prefix_token_embeddings.shape[1]
|
||||
prefill_len = jnp.sum(prefix_mask, axis=-1)
|
||||
prefix_start = prefill_size - prefill_len
|
||||
|
||||
# first fill KV cache with a forward pass of the prefix
|
||||
# pad attention mask to set the size of the KV cache (prefill_size + max_decoding_steps)
|
||||
prefix_attn_mask = jnp.pad(prefix_attn_mask, ((0, 0), (0, 0), (0, max_decoding_steps)))
|
||||
prefix_positions = jnp.cumsum(prefix_mask, axis=-1) - 1
|
||||
prefix_logits, kv_cache, _ = self.PaliGemma.llm(
|
||||
embedded_prefix=prefix_token_embeddings, mask=prefix_attn_mask, positions=prefix_positions, decode=True
|
||||
)
|
||||
|
||||
# prepare decoding -- final logit decodes the first token
|
||||
last_logit = prefix_logits[:, -1:]
|
||||
output_tokens = jnp.zeros((last_logit.shape[0], max_decoding_steps))
|
||||
|
||||
def step(carry):
|
||||
last_logit, output_tokens, cache, _, step = carry
|
||||
|
||||
# Sample token from last logit
|
||||
if temperature > 0.0:
|
||||
last_logit = last_logit / temperature
|
||||
token = jax.random.categorical(rng, last_logit, axis=-1)
|
||||
else:
|
||||
token = jnp.argmax(last_logit, axis=-1)
|
||||
output_tokens = put_along_last_axis(output_tokens, jnp.broadcast_to(step, (token.shape[0], 1)), token)
|
||||
|
||||
# Check for early stopping --> stop if all batch elements have EOS token
|
||||
has_eos = jnp.any(token == PALIGEMMA_EOS_TOKEN, axis=-1)
|
||||
all_eos = jnp.all(has_eos)
|
||||
|
||||
# Decode one step
|
||||
token_embedding = self.PaliGemma.llm(token, embed_only=True)
|
||||
positions = prefill_len[:, None] + step + 1
|
||||
mask = jnp.logical_and(
|
||||
jnp.arange(prefill_size + max_decoding_steps)[None, None, :] >= prefix_start[:, None, None],
|
||||
jnp.arange(prefill_size + max_decoding_steps)[None, None, :]
|
||||
< (jnp.broadcast_to(prefill_size + step + 1, (prefix_start.shape[0], 1, 1))),
|
||||
)
|
||||
last_logit, kv_cache, _ = self.PaliGemma.llm(
|
||||
embedded_prefix=token_embedding, mask=mask, positions=positions, decode=True, kv_cache=cache
|
||||
)
|
||||
|
||||
return last_logit, output_tokens, kv_cache, all_eos, step + 1
|
||||
|
||||
def cond(carry):
|
||||
_, _, _, all_eos, step = carry
|
||||
return (~all_eos) & (step < max_decoding_steps)
|
||||
|
||||
# Use lax.while_loop so we can jit the full decoding loop.
|
||||
_, output_tokens, _, _, _ = jax.lax.while_loop(cond, step, (last_logit, output_tokens, kv_cache, False, 0))
|
||||
return output_tokens
|
||||
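# Usage sketch (editor addition): turning the decoded tokens back into an action chunk
# with the FAST tokenizer. Model and observation construction are elided, the imports
# are kept local to the sketch, and the argument values are illustrative rather than
# prescribed by this commit.
def _pi0_fast_decode_example(model: "Pi0FAST", observation: _model.Observation):
    import numpy as np
    from openpi.models import tokenizer as _tokenizer

    tokens = model.sample_actions(jax.random.key(0), observation, max_decoding_steps=256)
    fast_tokenizer = _tokenizer.FASTTokenizer(max_len=256)
    return fast_tokenizer.extract_actions(
        np.asarray(tokens[0], dtype=np.int32), action_horizon=32, action_dim=32
    )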
46
src/openpi/models/pi0_test.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import flax.nnx as nnx
|
||||
import jax
|
||||
|
||||
import openpi.models.pi0 as _pi0
|
||||
|
||||
|
||||
def _get_frozen_state(config: _pi0.Pi0Config) -> nnx.State:
|
||||
abstract_model = nnx.eval_shape(config.create, jax.random.key(0))
|
||||
|
||||
freeze_filter = config.get_freeze_filter()
|
||||
return nnx.state(abstract_model, nnx.All(nnx.Param, freeze_filter)).flat_state()
|
||||
|
||||
|
||||
def test_pi0_full_finetune():
|
||||
config = _pi0.Pi0Config()
|
||||
state = _get_frozen_state(config)
|
||||
assert len(state) == 0
|
||||
|
||||
|
||||
def test_pi0_gemma_lora():
|
||||
config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora")
|
||||
state = _get_frozen_state(config)
|
||||
assert len(state) == 9
|
||||
assert all("lora" not in p for p in state)
|
||||
assert all("llm" in p for p in state)
|
||||
assert all("_1" not in p for p in state)
|
||||
|
||||
|
||||
def test_pi0_action_expert_lora():
|
||||
config = _pi0.Pi0Config(action_expert_variant="gemma_300m_lora")
|
||||
state = _get_frozen_state(config)
|
||||
# excluding embedder, rest of the params should be same as gemma_lora.
|
||||
assert len(state) == 8
|
||||
assert all("lora" not in p for p in state)
|
||||
assert all("llm" in p for p in state)
|
||||
# all frozen params should have _1 in their path since it's the action expert.
|
||||
assert all(any("_1" in p for p in path) for path in state)
|
||||
|
||||
|
||||
def test_pi0_all_lora():
|
||||
config = _pi0.Pi0Config(paligemma_variant="gemma_2b_lora", action_expert_variant="gemma_300m_lora")
|
||||
state = _get_frozen_state(config)
|
||||
# sum of gemma_lora and action_expert_lora's frozen params.
|
||||
assert len(state) == 17
|
||||
assert all("lora" not in p for p in state)
|
||||
assert all("llm" in p for p in state)
|
||||
373
src/openpi/models/siglip.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A refactored and simplified ViT adoptation for Pi, taken from big_vision."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
|
||||
def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32):
|
||||
"""Follows the MoCo v3 logic."""
|
||||
y, x = jnp.mgrid[:h, :w]
|
||||
|
||||
assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
|
||||
omega = jnp.arange(width // 4) / (width // 4 - 1)
|
||||
omega = 1.0 / (temperature**omega)
|
||||
y = jnp.einsum("m,d->md", y.flatten(), omega)
|
||||
x = jnp.einsum("m,d->md", x.flatten(), omega)
|
||||
pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
|
||||
return jnp.asarray(pe, dtype)[None, :, :]
|
||||
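# Editor's sketch: one embedding per grid position, with the width split into
# sin(x)/cos(x)/sin(y)/cos(y) quarters.
def _posemb_sincos_2d_example():
    pe = posemb_sincos_2d(2, 2, 8)
    assert pe.shape == (1, 4, 8)  # (1, h * w, width)
    return pe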
|
||||
|
||||
def get_posemb(self, typ, seqshape, width, name, dtype=jnp.float32):
|
||||
if typ == "learn":
|
||||
return self.param(
|
||||
name,
|
||||
nn.initializers.normal(stddev=1 / np.sqrt(width)),
|
||||
(1, np.prod(seqshape), width),
|
||||
dtype,
|
||||
)
|
||||
if typ == "sincos2d":
|
||||
return posemb_sincos_2d(*seqshape, width, dtype=dtype)
|
||||
raise ValueError(f"Unknown posemb type: {typ}")
|
||||
|
||||
|
||||
class MlpBlock(nn.Module):
|
||||
"""Transformer MLP / feed-forward block."""
|
||||
|
||||
mlp_dim: int | None = None # Defaults to 4x input dim
|
||||
dropout: float = 0.0
|
||||
dtype_mm: str = "float32"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, deterministic=True): # noqa: FBT002
|
||||
"""Applies Transformer MlpBlock module."""
|
||||
inits = {
|
||||
"kernel_init": nn.initializers.xavier_uniform(),
|
||||
"bias_init": nn.initializers.normal(stddev=1e-6),
|
||||
}
|
||||
|
||||
_, _, d = x.shape # n,l,d
|
||||
x = nn.Dense(self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
|
||||
x = nn.gelu(x)
|
||||
x = nn.Dropout(rate=self.dropout)(x, deterministic)
|
||||
return nn.Dense(d, dtype=self.dtype_mm, **inits)(x)
|
||||
|
||||
|
||||
class Encoder1DBlock(nn.Module):
|
||||
"""Single transformer encoder block (MHSA + MLP)."""
|
||||
|
||||
mlp_dim: int | None = None # Defaults to 4x input dim
|
||||
num_heads: int = 12
|
||||
dropout: float = 0.0
|
||||
dtype_mm: str = "float32"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, deterministic=True): # noqa: FBT002
|
||||
out = {}
|
||||
x = sharding.activation_sharding_constraint(x)
|
||||
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
|
||||
y = out["sa"] = nn.MultiHeadDotProductAttention(
|
||||
num_heads=self.num_heads,
|
||||
kernel_init=nn.initializers.xavier_uniform(),
|
||||
deterministic=deterministic,
|
||||
dtype=self.dtype_mm,
|
||||
)(y, y)
|
||||
y = sharding.activation_sharding_constraint(y)
|
||||
y = nn.Dropout(rate=self.dropout)(y, deterministic)
|
||||
x = out["+sa"] = x + y
|
||||
|
||||
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
|
||||
y = out["mlp"] = MlpBlock(
|
||||
mlp_dim=self.mlp_dim,
|
||||
dropout=self.dropout,
|
||||
dtype_mm=self.dtype_mm,
|
||||
)(y, deterministic)
|
||||
y = sharding.activation_sharding_constraint(y)
|
||||
y = nn.Dropout(rate=self.dropout)(y, deterministic)
|
||||
x = out["+mlp"] = x + y
|
||||
x = sharding.activation_sharding_constraint(x)
|
||||
return x, out
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Transformer Model Encoder for sequence to sequence translation."""
|
||||
|
||||
depth: int
|
||||
mlp_dim: int | None = None # Defaults to 4x input dim
|
||||
num_heads: int = 12
|
||||
dropout: float = 0.0
|
||||
scan: bool = False
|
||||
remat_policy: str = "nothing_saveable"
|
||||
dtype_mm: str = "float32"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, deterministic=True): # noqa: FBT002
|
||||
out = {}
|
||||
|
||||
if self.scan:
|
||||
block = nn.remat(
|
||||
Encoder1DBlock,
|
||||
prevent_cse=False,
|
||||
static_argnums=(2,), # 0=self, 2=deterministic
|
||||
policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
|
||||
)
|
||||
x, scan_out = nn.scan(
|
||||
block,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=nn.broadcast,
|
||||
length=self.depth,
|
||||
)(
|
||||
name="encoderblock",
|
||||
dtype_mm=self.dtype_mm,
|
||||
mlp_dim=self.mlp_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
)(x, deterministic)
|
||||
for lyr in range(self.depth):
|
||||
out[f"block{lyr:02d}"] = jax.tree.map(lambda o, lyr=lyr: o[lyr], scan_out)
|
||||
else:
|
||||
# Input Encoder
|
||||
for lyr in range(self.depth):
|
||||
block_cur = Encoder1DBlock(
|
||||
name=f"encoderblock_{lyr}",
|
||||
dtype_mm=self.dtype_mm,
|
||||
mlp_dim=self.mlp_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
x, out[f"block{lyr:02d}"] = block_cur(x, deterministic)
|
||||
out["pre_ln"] = x # Alias for last block, but without the number in it.
|
||||
|
||||
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype_mm)(x), out
|
||||
|
||||
|
||||
class MAPHead(nn.Module):
|
||||
"""Multihead Attention Pooling."""
|
||||
|
||||
mlp_dim: int | None = None # Defaults to 4x input dim
|
||||
num_heads: int = 12
|
||||
dtype_mm: str = "float32"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
n, _, d = x.shape # n,l,d
|
||||
probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
|
||||
probe = jnp.tile(probe, [n, 1, 1])
|
||||
|
||||
x = nn.MultiHeadDotProductAttention(
|
||||
num_heads=self.num_heads,
|
||||
dtype=self.dtype_mm,
|
||||
kernel_init=nn.initializers.xavier_uniform(),
|
||||
)(probe, x)
|
||||
|
||||
y = nn.LayerNorm(dtype=self.dtype_mm)(x)
|
||||
x = x + MlpBlock(mlp_dim=self.mlp_dim, dtype_mm=self.dtype_mm)(y)
|
||||
return x[:, 0]
|
||||
|
||||
|
||||
class _Module(nn.Module):
|
||||
"""ViT model."""
|
||||
|
||||
num_classes: int | None = None
|
||||
patch_size: Sequence[int] = (16, 16)
|
||||
width: int = 768
|
||||
depth: int = 12
|
||||
mlp_dim: int | None = None # Defaults to 4x input dim
|
||||
num_heads: int = 12
|
||||
posemb: str = "learn" # Can also be "sincos2d"
|
||||
rep_size: int | bool = False
|
||||
dropout: float = 0.0
|
||||
pool_type: str = "gap" # Can also be "map" or "tok"
|
||||
head_zeroinit: bool = True
|
||||
scan: bool = False
|
||||
# or "dots_with_no_batch_dims_saveable" for more speed (memory costly)
|
||||
remat_policy: str = "nothing_saveable"
|
||||
dtype_mm: str = "float32"
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, image, *, train=False):
|
||||
out = {}
|
||||
|
||||
# Kevin edit: do patch extraction and posemb in float32,
|
||||
# because I feel like it's a bit safer.
|
||||
image = jnp.asarray(image, jnp.float32)
|
||||
|
||||
# Patch extraction
|
||||
x = out["stem"] = nn.Conv(
|
||||
self.width,
|
||||
self.patch_size,
|
||||
strides=self.patch_size,
|
||||
padding="VALID",
|
||||
name="embedding",
|
||||
dtype=jnp.float32,
|
||||
)(image)
|
||||
|
||||
n, h, w, c = x.shape
|
||||
x = jnp.reshape(x, [n, h * w, c])
|
||||
|
||||
# Add posemb before adding extra token.
|
||||
x = out["with_posemb"] = x + get_posemb(self, self.posemb, (h, w), c, "pos_embedding", jnp.float32)
|
||||
|
||||
if self.pool_type == "tok":
|
||||
cls = self.param("cls", nn.initializers.zeros, (1, 1, c), x.dtype)
|
||||
x = jnp.concatenate([jnp.tile(cls, [n, 1, 1]), x], axis=1)
|
||||
|
||||
n, _, c = x.shape # n,l,d
|
||||
x = nn.Dropout(rate=self.dropout)(x, not train)
|
||||
|
||||
# Kevin edit: now cast back to dtype_mm (potentially half precision)
|
||||
x = x.astype(self.dtype_mm)
|
||||
|
||||
x, out["encoder"] = Encoder(
|
||||
depth=self.depth,
|
||||
mlp_dim=self.mlp_dim,
|
||||
num_heads=self.num_heads,
|
||||
dropout=self.dropout,
|
||||
scan=self.scan,
|
||||
remat_policy=self.remat_policy,
|
||||
dtype_mm=self.dtype_mm,
|
||||
name="Transformer",
|
||||
)(x, deterministic=not train)
|
||||
encoded = out["encoded"] = x
|
||||
|
||||
if self.pool_type == "map":
|
||||
x = out["head_input"] = MAPHead(
|
||||
num_heads=self.num_heads,
|
||||
mlp_dim=self.mlp_dim,
|
||||
dtype_mm=self.dtype_mm,
|
||||
)(x)
|
||||
elif self.pool_type == "gap":
|
||||
x = out["head_input"] = jnp.mean(x, axis=1)
|
||||
elif self.pool_type == "0":
|
||||
x = out["head_input"] = x[:, 0]
|
||||
elif self.pool_type == "tok":
|
||||
x = out["head_input"] = x[:, 0]
|
||||
encoded = encoded[:, 1:]
|
||||
elif self.pool_type == "none":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Unknown pool type: '{self.pool_type}'")
|
||||
|
||||
x_2d = jnp.reshape(encoded, [n, h, w, -1])
|
||||
|
||||
if self.rep_size:
|
||||
rep_size = self.width if self.rep_size is True else self.rep_size
|
||||
hid = nn.Dense(rep_size, dtype=self.dtype_mm, name="pre_logits")
|
||||
# NOTE: In the past we did not include tanh in pre_logits.
|
||||
# For few-shot, it should not matter much, as it whitens anyways.
|
||||
x_2d = nn.tanh(hid(x_2d))
|
||||
x = nn.tanh(hid(x))
|
||||
|
||||
out["pre_logits_2d"] = x_2d
|
||||
out["pre_logits"] = x
|
||||
|
||||
if self.num_classes:
|
||||
kw = {"kernel_init": nn.initializers.zeros} if self.head_zeroinit else {}
|
||||
head = nn.Dense(self.num_classes, dtype=self.dtype_mm, name="head", **kw)
|
||||
x_2d = out["logits_2d"] = head(x_2d)
|
||||
x = out["logits"] = head(x)
|
||||
|
||||
return x, out
|
||||
|
||||
|
||||
def Module(num_classes=None, *, variant=None, **kw): # pylint: disable=invalid-name # noqa: N802
|
||||
"""Factory function, because linen really don't like what I'm doing!"""
|
||||
return _Module(num_classes, **{**decode_variant(variant), **kw})
|
||||
|
||||
|
||||
def decode_variant(variant):
|
||||
"""Converts a string like "B" or "B/32" into a params dict."""
|
||||
if variant is None:
|
||||
return {}
|
||||
|
||||
v, patch = variant, {}
|
||||
if "/" in variant:
|
||||
v, patch = variant.split("/")
|
||||
patch = {"patch_size": (int(patch), int(patch))}
|
||||
|
||||
return {
|
||||
# pylint:disable=line-too-long
|
||||
# Reference: Table 2 of https://arxiv.org/abs/2106.04560.
|
||||
"width": {
|
||||
"mu": 32,
|
||||
"Ti": 192,
|
||||
"S": 384,
|
||||
"M": 512,
|
||||
"B": 768,
|
||||
"L": 1024,
|
||||
"So400m": 1152,
|
||||
"H": 1280,
|
||||
"g": 1408,
|
||||
"g-opt": 1536,
|
||||
"G": 1664,
|
||||
"G-opt": 1536,
|
||||
"e": 1792,
|
||||
}[v],
|
||||
"depth": {
|
||||
"mu": 1,
|
||||
"Ti": 12,
|
||||
"S": 12,
|
||||
"M": 12,
|
||||
"B": 12,
|
||||
"L": 24,
|
||||
"So400m": 27,
|
||||
"H": 32,
|
||||
"g": 40,
|
||||
"g-opt": 40,
|
||||
"G": 48,
|
||||
"G-opt": 48,
|
||||
"e": 56,
|
||||
}[v],
|
||||
"mlp_dim": {
|
||||
"mu": 128,
|
||||
"Ti": 768,
|
||||
"S": 1536,
|
||||
"M": 2048,
|
||||
"B": 3072,
|
||||
"L": 4096,
|
||||
"So400m": 4304,
|
||||
"H": 5120,
|
||||
"g": 6144,
|
||||
"g-opt": 6144,
|
||||
"G": 8192,
|
||||
"G-opt": 8192,
|
||||
"e": 15360,
|
||||
}[v],
|
||||
"num_heads": {
|
||||
"mu": 2,
|
||||
"Ti": 3,
|
||||
"S": 6,
|
||||
"M": 8,
|
||||
"B": 12,
|
||||
"L": 16,
|
||||
"So400m": 16,
|
||||
"H": 16,
|
||||
"g": 16,
|
||||
"g-opt": 16,
|
||||
"G": 16,
|
||||
"G-opt": 16,
|
||||
"e": 16,
|
||||
}[v],
|
||||
# pylint:enable=line-too-long
|
||||
**patch,
|
||||
}
|
||||
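# Editor's sketch (illustrative check): the "So400m/14" variant used for the pi image
# tower decodes to the SoViT-400m/14 settings from the table above.
def _decode_variant_example():
    cfg = decode_variant("So400m/14")
    assert cfg == {
        "width": 1152,
        "depth": 27,
        "mlp_dim": 4304,
        "num_heads": 16,
        "patch_size": (14, 14),
    }
    return cfg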
127
src/openpi/models/tokenizer.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import sentencepiece
|
||||
from transformers import AutoProcessor
|
||||
|
||||
import openpi.shared.download as download
|
||||
|
||||
|
||||
class PaligemmaTokenizer:
|
||||
def __init__(self, max_len: int = 48):
|
||||
self._max_len = max_len
|
||||
|
||||
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
|
||||
with path.open("rb") as f:
|
||||
self._tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
|
||||
|
||||
def tokenize(self, prompt: str) -> tuple[np.ndarray, np.ndarray]:
|
||||
cleaned_text = prompt.strip().replace("_", " ").replace("\n", " ")
|
||||
# tokenize "\n" separately as the "start of answer" token
|
||||
tokens = self._tokenizer.encode(cleaned_text, add_bos=True) + self._tokenizer.encode("\n")
|
||||
tokens_len = len(tokens)
|
||||
if tokens_len < self._max_len:
|
||||
padding = [False] * (self._max_len - tokens_len)
|
||||
mask = [True] * tokens_len + padding
|
||||
tokens = tokens + padding
|
||||
else:
|
||||
if len(tokens) > self._max_len:
|
||||
logging.warning(
|
||||
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
|
||||
"Consider increasing the `max_token_len` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self._max_len]
|
||||
mask = [True] * self._max_len
|
||||
|
||||
return np.asarray(tokens), np.asarray(mask)
|
||||
|
||||
|
||||
class FASTTokenizer:
|
||||
def __init__(self, max_len: int = 256, fast_tokenizer_path: str = "physical-intelligence/fast"):
|
||||
self._max_len = max_len
|
||||
|
||||
# Download base PaliGemma tokenizer
|
||||
path = download.maybe_download("gs://big_vision/paligemma_tokenizer.model", gs={"token": "anon"})
|
||||
with path.open("rb") as f:
|
||||
self._paligemma_tokenizer = sentencepiece.SentencePieceProcessor(model_proto=f.read())
|
||||
|
||||
# Instantiate FAST tokenizer
|
||||
self._fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True)
|
||||
self._fast_skip_tokens = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens
|
||||
|
||||
def tokenize(
|
||||
self, prompt: str, state: np.ndarray, actions: np.ndarray | None
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
cleaned_text = prompt.lower().strip().replace("_", " ")
|
||||
|
||||
# Convention: state gets discretized into 256 discrete bins (assumed range after normalization: [-1, 1])
|
||||
discretized_state = np.digitize(state, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
||||
|
||||
# Convention: prefix includes prompt and string-representation of state, followed by ';'
|
||||
state_str = " ".join(map(str, discretized_state))
|
||||
prefix = f"Task: {cleaned_text}, State: {state_str};\n"
|
||||
prefix_tokens = self._paligemma_tokenizer.encode(prefix, add_bos=True)
|
||||
|
||||
if actions is not None:
|
||||
# Tokenize actions with FAST tokenizer --> map to last tokens in PaliGemma vocab
|
||||
action_tokens = self._fast_tokenizer(actions[None])[0]
|
||||
action_tokens_in_pg = self._act_tokens_to_paligemma_tokens(action_tokens)
|
||||
|
||||
# Convention: postfix contains 'Action:' followed by FAST tokens, followed by '|'
|
||||
postfix_tokens = (
|
||||
self._paligemma_tokenizer.encode("Action: ")
|
||||
+ action_tokens_in_pg.tolist()
|
||||
+ self._paligemma_tokenizer.encode("|")
|
||||
)
|
||||
else:
|
||||
postfix_tokens = []
|
||||
|
||||
# Create output token sequence & masks
|
||||
# AR mask is 0 on prefix (bidirectional attention) and 1 on postfix (causal attention to all previous tokens)
|
||||
tokens = prefix_tokens + postfix_tokens
|
||||
token_mask = [True] * len(tokens)
|
||||
ar_mask = [0] * len(prefix_tokens) + [1] * len(postfix_tokens)
|
||||
loss_mask = [False] * len(prefix_tokens) + [True] * len(postfix_tokens) # Loss on postfix only
|
||||
|
||||
# Pad tokens to max length
|
||||
tokens_len = len(tokens)
|
||||
if tokens_len < self._max_len:
|
||||
padding = [False] * (self._max_len - tokens_len)
|
||||
tokens = tokens + padding
|
||||
token_mask = token_mask + padding
|
||||
ar_mask = ar_mask + padding
|
||||
loss_mask = loss_mask + padding
|
||||
else:
|
||||
if len(tokens) > self._max_len:
|
||||
logging.warning(
|
||||
f"Token length ({len(tokens)}) exceeds max length ({self._max_len}), truncating. "
|
||||
"Consider increasing the `max_token_len` in your model config if this happens frequently."
|
||||
)
|
||||
tokens = tokens[: self._max_len]
|
||||
token_mask = token_mask[: self._max_len]
|
||||
ar_mask = ar_mask[: self._max_len]
|
||||
loss_mask = loss_mask[: self._max_len]
|
||||
|
||||
return np.asarray(tokens), np.asarray(token_mask), np.asarray(ar_mask), np.asarray(loss_mask)
|
||||
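# Editor's sketch of the state-discretization convention used above: after
# normalization to [-1, 1], each state dimension maps to one of 256 bins, so
# -1.0 lands in bin 0, 0.0 in bin 128, and values near 1.0 in bin 255.
def _discretize_state_example():
    state = np.array([-1.0, 0.0, 0.999])
    bins = np.linspace(-1, 1, 256 + 1)[:-1]
    discretized = np.digitize(state, bins=bins) - 1
    assert discretized.tolist() == [0, 128, 255]
    return discretized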
|
||||
def extract_actions(self, tokens: np.ndarray, action_horizon: int, action_dim: int) -> np.ndarray:
|
||||
# Decode predicted output tokens
|
||||
decoded_tokens = self._paligemma_tokenizer.decode(tokens.tolist())
|
||||
|
||||
# Extract actions from FAST model outputs
|
||||
if "Action: " not in decoded_tokens:
|
||||
return np.zeros((action_horizon, action_dim), dtype=np.float32)
|
||||
|
||||
# Extract actions from decoded tokens
|
||||
raw_action_tokens = np.array(
|
||||
self._paligemma_tokenizer.encode(decoded_tokens.split("Action: ")[1].split("|")[0].strip())
|
||||
)
|
||||
action_tokens = self._act_tokens_to_paligemma_tokens(raw_action_tokens)
|
||||
return self._fast_tokenizer.decode(
|
||||
[action_tokens.tolist()], time_horizon=action_horizon, action_dim=action_dim
|
||||
)[0]
|
||||
|
||||
def _act_tokens_to_paligemma_tokens(self, tokens: np.ndarray | list[int]) -> np.ndarray:
|
||||
if isinstance(tokens, list):
|
||||
tokens = np.array(tokens)
|
||||
return self._paligemma_tokenizer.vocab_size() - 1 - self._fast_skip_tokens - tokens
|
||||
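# Editor's sketch: the mapping above (t -> vocab_size - 1 - skip - t) is its own inverse,
# which is why the same helper serves both for writing action tokens into the prompt and
# for reading them back in extract_actions. The vocab size below is an assumed value for
# illustration; the real offset comes from self._paligemma_tokenizer.vocab_size().
def _act_token_mapping_example():
    vocab_size, skip = 257_152, 128  # assumed PaliGemma tokenizer vocab size

    def to_pg(t):
        return vocab_size - 1 - skip - t

    tokens = np.array([0, 1, 500])
    assert np.array_equal(to_pg(to_pg(tokens)), tokens)
    return to_pg(tokens)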
27
src/openpi/models/tokenizer_test.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import numpy as np
|
||||
|
||||
from openpi.models import tokenizer as _tokenizer
|
||||
|
||||
|
||||
def test_tokenize():
|
||||
tokenizer = _tokenizer.PaligemmaTokenizer(max_len=10)
|
||||
tokens, masks = tokenizer.tokenize("Hello, world!")
|
||||
|
||||
assert tokens.shape == (10,)
|
||||
assert masks.shape == (10,)
|
||||
|
||||
|
||||
def test_fast_tokenizer():
|
||||
prompt = "Hello, world!"
|
||||
state = np.random.rand(5).astype(np.float32)
|
||||
action = np.random.rand(3, 2).astype(np.float32)
|
||||
tokenizer = _tokenizer.FASTTokenizer(max_len=256)
|
||||
tokens, token_masks, ar_masks, loss_masks = tokenizer.tokenize(prompt, state, action)
|
||||
|
||||
assert tokens.shape == (256,)
|
||||
assert token_masks.shape == (256,)
|
||||
assert ar_masks.shape == (256,)
|
||||
assert loss_masks.shape == (256,)
|
||||
|
||||
act = tokenizer.extract_actions(tokens, 3, 2)
|
||||
assert act.shape == (3, 2)
|
||||
307
src/openpi/models/vit.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright 2024 Google LLC.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""ViT implementation adapted from https://github.com/google-research/vision_transformer/blob/main/vit_jax/models_vit.py."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from openpi.models import resnet as models_resnet
|
||||
|
||||
Array = Any
|
||||
PRNGKey = Any
|
||||
Shape = tuple[int]
|
||||
Dtype = Any
|
||||
|
||||
|
||||
class IdentityLayer(nn.Module):
|
||||
"""Identity layer, convenient for giving a name to an array."""
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class AddPositionEmbs(nn.Module):
|
||||
"""Adds learned positional embeddings to the inputs.
|
||||
|
||||
Attributes:
|
||||
posemb_init: positional embedding initializer.
|
||||
"""
|
||||
|
||||
posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]
|
||||
param_dtype: Dtype = jnp.float32
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs):
|
||||
"""Applies the AddPositionEmbs module.
|
||||
|
||||
Args:
|
||||
inputs: Inputs to the layer.
|
||||
|
||||
Returns:
|
||||
Output tensor with shape `(bs, timesteps, in_dim)`.
|
||||
"""
|
||||
# inputs.shape is (batch_size, seq_len, emb_dim).
|
||||
assert inputs.ndim == 3, f"Number of dimensions should be 3, but it is: {inputs.ndim}"
|
||||
pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
|
||||
pe = self.param("pos_embedding", self.posemb_init, pos_emb_shape, self.param_dtype)
|
||||
return inputs + pe
|
||||
|
||||
|
||||
class MlpBlock(nn.Module):
|
||||
"""Transformer MLP / feed-forward block."""
|
||||
|
||||
mlp_dim: int
|
||||
dtype: Dtype = jnp.float32
|
||||
param_dtype: Dtype = jnp.float32
|
||||
out_dim: int | None = None
|
||||
dropout_rate: float = 0.1
|
||||
kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.xavier_uniform()
|
||||
bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.normal(stddev=1e-6)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, *, deterministic):
|
||||
"""Applies Transformer MlpBlock module."""
|
||||
actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
|
||||
x = nn.Dense(
|
||||
features=self.mlp_dim,
|
||||
dtype=self.dtype,
|
||||
param_dtype=self.param_dtype,
|
||||
kernel_init=self.kernel_init,
|
||||
bias_init=self.bias_init,
|
||||
)( # pytype: disable=wrong-arg-types
|
||||
inputs
|
||||
)
|
||||
x = nn.gelu(x)
|
||||
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
|
||||
output = nn.Dense(
|
||||
features=actual_out_dim,
|
||||
dtype=self.dtype,
|
||||
param_dtype=self.param_dtype,
|
||||
kernel_init=self.kernel_init,
|
||||
bias_init=self.bias_init,
|
||||
)( # pytype: disable=wrong-arg-types
|
||||
x
|
||||
)
|
||||
return nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
|
||||
|
||||
|
||||
class Encoder1DBlock(nn.Module):
|
||||
"""Transformer encoder layer.
|
||||
|
||||
Attributes:
|
||||
inputs: input data.
|
||||
mlp_dim: dimension of the mlp on top of attention block.
|
||||
dtype: the dtype of the computation (default: float32).
|
||||
dropout_rate: dropout rate.
|
||||
attention_dropout_rate: dropout for attention heads.
|
||||
deterministic: bool, deterministic or not (to apply dropout).
|
||||
num_heads: Number of heads in nn.MultiHeadDotProductAttention
|
||||
"""
|
||||
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
dtype: Dtype = jnp.float32
|
||||
dropout_rate: float = 0.1
|
||||
attention_dropout_rate: float = 0.1
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, deterministic):
|
||||
"""Applies Encoder1DBlock module.
|
||||
|
||||
Args:
|
||||
inputs: Inputs to the layer.
|
||||
deterministic: Dropout will not be applied when set to true.
|
||||
|
||||
Returns:
|
||||
output after transformer encoder block.
|
||||
"""
|
||||
|
||||
# Attention block.
|
||||
assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
|
||||
x = nn.LayerNorm(dtype=self.dtype)(inputs)
|
||||
x = nn.MultiHeadDotProductAttention(
|
||||
dtype=self.dtype,
|
||||
kernel_init=nn.initializers.xavier_uniform(),
|
||||
broadcast_dropout=False,
|
||||
deterministic=deterministic,
|
||||
dropout_rate=self.attention_dropout_rate,
|
||||
num_heads=self.num_heads,
|
||||
# why isn't this true by default???
|
||||
force_fp32_for_softmax=True,
|
||||
)(x, x)
|
||||
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
|
||||
x = x + inputs
|
||||
|
||||
# MLP block.
|
||||
y = nn.LayerNorm(dtype=self.dtype)(x)
|
||||
y = MlpBlock(mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate)(
|
||||
y, deterministic=deterministic
|
||||
)
|
||||
|
||||
return x + y, None
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Transformer Model Encoder for sequence to sequence translation.
|
||||
|
||||
Attributes:
|
||||
num_layers: number of layers
|
||||
mlp_dim: dimension of the mlp on top of attention block
|
||||
num_heads: Number of heads in nn.MultiHeadDotProductAttention
|
||||
dropout_rate: dropout rate.
|
||||
attention_dropout_rate: dropout rate in self attention.
|
||||
"""
|
||||
|
||||
dtype: jax.typing.DTypeLike
|
||||
num_layers: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
dropout_rate: float = 0.1
|
||||
attention_dropout_rate: float = 0.1
|
||||
add_position_embedding: bool = True
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, *, train):
|
||||
"""Applies Transformer model on the inputs.
|
||||
|
||||
Args:
|
||||
x: Inputs to the layer.
|
||||
train: Set to `True` when training.
|
||||
|
||||
Returns:
|
||||
output of a transformer encoder.
|
||||
"""
|
||||
assert x.ndim == 3 # (batch, len, emb)
|
||||
|
||||
if self.add_position_embedding:
|
||||
x = AddPositionEmbs(
|
||||
posemb_init=nn.initializers.normal(stddev=0.02), # from BERT.
|
||||
name="posembed_input",
|
||||
)(x)
|
||||
x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
|
||||
|
||||
x = x.astype(self.dtype)
|
||||
# Input Encoder
|
||||
block = nn.remat(Encoder1DBlock, prevent_cse=False, static_argnums=(2,))
|
||||
x, _ = nn.scan(
|
||||
block,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=nn.broadcast,
|
||||
length=self.num_layers,
|
||||
)(
|
||||
name="encoderblock",
|
||||
mlp_dim=self.mlp_dim,
|
||||
dropout_rate=self.dropout_rate,
|
||||
attention_dropout_rate=self.attention_dropout_rate,
|
||||
dtype=self.dtype,
|
||||
num_heads=self.num_heads,
|
||||
)(x, not train)
|
||||
return nn.LayerNorm(name="encoder_norm", dtype=self.dtype)(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""VisionTransformer."""
|
||||
|
||||
dtype: jax.typing.DTypeLike
|
||||
num_classes: int
|
||||
patches: Any
|
||||
transformer: Any
|
||||
hidden_size: int
|
||||
resnet: Any | None = None
|
||||
representation_size: int | None = None
|
||||
classifier: str = "token"
|
||||
head_bias_init: float = 0.0
|
||||
encoder: type[nn.Module] = Encoder
|
||||
model_name: str | None = None
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, inputs, *, train):
|
||||
x = inputs
|
||||
# (Possibly partial) ResNet root.
|
||||
if self.resnet is not None:
|
||||
width = int(64 * self.resnet.width_factor)
|
||||
|
||||
# Root block.
|
||||
x = models_resnet.StdConv(
|
||||
features=width, kernel_size=(7, 7), strides=(2, 2), use_bias=False, name="conv_root"
|
||||
)(x)
|
||||
x = nn.GroupNorm(name="gn_root")(x)
|
||||
x = nn.relu(x)
|
||||
x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
|
||||
|
||||
# ResNet stages.
|
||||
if self.resnet.num_layers:
|
||||
x = models_resnet.ResNetStage(
|
||||
block_size=self.resnet.num_layers[0], nout=width, first_stride=(1, 1), name="block1"
|
||||
)(x)
|
||||
for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
|
||||
x = models_resnet.ResNetStage(
|
||||
block_size=block_size, nout=width * 2**i, first_stride=(2, 2), name=f"block{i + 1}"
|
||||
)(x)
|
||||
|
||||
n, h, w, c = x.shape
|
||||
|
||||
# We can merge s2d+emb into a single conv; it's the same.
|
||||
x = nn.Conv(
|
||||
features=self.hidden_size,
|
||||
kernel_size=self.patches.size,
|
||||
strides=self.patches.size,
|
||||
padding="VALID",
|
||||
name="embedding",
|
||||
)(x)
|
||||
|
||||
# Here, x is a grid of embeddings.
|
||||
|
||||
# (Possibly partial) Transformer.
|
||||
if self.transformer is not None:
|
||||
n, h, w, c = x.shape
|
||||
x = jnp.reshape(x, [n, h * w, c])
|
||||
|
||||
# If we want to add a class token, add it here.
|
||||
if self.classifier in ["token", "token_unpooled"]:
|
||||
cls = self.param("cls", nn.initializers.zeros, (1, 1, c))
|
||||
cls = jnp.tile(cls, [n, 1, 1])
|
||||
x = jnp.concatenate([cls, x], axis=1)
|
||||
|
||||
x = self.encoder(name="Transformer", **self.transformer, dtype=self.dtype)(x, train=train)
|
||||
|
||||
if self.classifier == "token":
|
||||
x = x[:, 0]
|
||||
elif self.classifier == "gap":
|
||||
x = jnp.mean(x, axis=list(range(1, x.ndim - 1))) # (1,) or (1,2)
|
||||
elif self.classifier in ["unpooled", "token_unpooled"]:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid classifier={self.classifier}")
|
||||
|
||||
if self.representation_size is not None:
|
||||
x = nn.Dense(features=self.representation_size, name="pre_logits")(x)
|
||||
x = nn.tanh(x)
|
||||
else:
|
||||
x = IdentityLayer(name="pre_logits")(x)
|
||||
|
||||
if self.num_classes:
|
||||
x = nn.Dense(
|
||||
features=self.num_classes,
|
||||
name="head",
|
||||
kernel_init=nn.initializers.zeros,
|
||||
bias_init=nn.initializers.constant(self.head_bias_init),
|
||||
)(x)
|
||||
return x
|
||||
206
src/openpi/policies/aloha_policy.py
Normal file
@@ -0,0 +1,206 @@
|
||||
import dataclasses
|
||||
from typing import ClassVar
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
|
||||
|
||||
def make_aloha_example() -> dict:
|
||||
"""Creates a random input example for the Aloha policy."""
|
||||
return {
|
||||
"state": np.ones((14,)),
|
||||
"images": {
|
||||
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
},
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AlohaInputs(transforms.DataTransformFn):
|
||||
"""Inputs for the Aloha policy.
|
||||
|
||||
Expected inputs:
|
||||
- images: dict[name, img] where img is [channel, height, width]. name must be in EXPECTED_CAMERAS.
|
||||
- state: [14]
|
||||
- actions: [action_horizon, 14]
|
||||
"""
|
||||
|
||||
# The action dimension of the model. Will be used to pad state and actions.
|
||||
action_dim: int
|
||||
|
||||
# If true, this will convert the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi: bool = True
|
||||
|
||||
# The expected cameras names. All input cameras must be in this set. Missing cameras will be
|
||||
# replaced with black images and the corresponding `image_mask` will be set to False.
|
||||
EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
data = _decode_aloha(data, adapt_to_pi=self.adapt_to_pi)
|
||||
|
||||
# Get the state. We are padding from 14 to the model action dim.
|
||||
state = transforms.pad_to_dim(data["state"], self.action_dim)
|
||||
|
||||
in_images = data["images"]
|
||||
if set(in_images) - set(self.EXPECTED_CAMERAS):
|
||||
raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")
|
||||
|
||||
# Assume that base image always exists.
|
||||
base_image = in_images["cam_high"]
|
||||
|
||||
images = {
|
||||
"base_0_rgb": base_image,
|
||||
}
|
||||
image_masks = {
|
||||
"base_0_rgb": np.True_,
|
||||
}
|
||||
|
||||
# Add the extra images.
|
||||
extra_image_names = {
|
||||
"left_wrist_0_rgb": "cam_left_wrist",
|
||||
"right_wrist_0_rgb": "cam_right_wrist",
|
||||
}
|
||||
for dest, source in extra_image_names.items():
|
||||
if source in in_images:
|
||||
images[dest] = in_images[source]
|
||||
image_masks[dest] = np.True_
|
||||
else:
|
||||
images[dest] = np.zeros_like(base_image)
|
||||
image_masks[dest] = np.False_
|
||||
|
||||
inputs = {
|
||||
"image": images,
|
||||
"image_mask": image_masks,
|
||||
"state": state,
|
||||
}
|
||||
|
||||
# Actions are only available during training.
|
||||
if "actions" in data:
|
||||
actions = np.asarray(data["actions"])
|
||||
actions = _encode_actions_inv(actions, adapt_to_pi=self.adapt_to_pi)
|
||||
inputs["actions"] = transforms.pad_to_dim(actions, self.action_dim)
|
||||
|
||||
if "prompt" in data:
|
||||
inputs["prompt"] = data["prompt"]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AlohaOutputs(transforms.DataTransformFn):
|
||||
"""Outputs for the Aloha policy."""
|
||||
|
||||
# If true, this will convert the joint and gripper values from the standard Aloha space to
|
||||
# the space used by the pi internal runtime which was used to train the base model.
|
||||
adapt_to_pi: bool = True
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
# Only return the first 14 dims.
|
||||
actions = np.asarray(data["actions"][:, :14])
|
||||
return {"actions": _encode_actions(actions, adapt_to_pi=self.adapt_to_pi)}
|
||||
|
||||
|
||||
def _joint_flip_mask() -> np.ndarray:
|
||||
"""Used to convert between aloha and pi joint angles."""
|
||||
return np.array([1, -1, -1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1])
|
||||
|
||||
|
||||
def _normalize(x, min_val, max_val):
|
||||
return (x - min_val) / (max_val - min_val)
|
||||
|
||||
|
||||
def _unnormalize(x, min_val, max_val):
|
||||
return x * (max_val - min_val) + min_val
|
||||
|
||||
|
||||
def _gripper_to_angular(value):
|
||||
# Aloha transforms the gripper positions into a linear space. The following code
|
||||
# reverses this transformation to be consistent with pi0 which is pretrained in
|
||||
# angular space.
|
||||
#
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
|
||||
value = _unnormalize(value, min_val=0.01844, max_val=0.05800)
|
||||
|
||||
# This is the inverse of the angular to linear transformation inside the Interbotix code.
|
||||
def linear_to_radian(linear_position, arm_length, horn_radius):
|
||||
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
|
||||
return np.arcsin(np.clip(value, -1.0, 1.0))
|
||||
|
||||
# The constants are taken from the Interbotix code.
|
||||
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
|
||||
|
||||
# Normalize to [0, 1].
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
return _normalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
|
||||
def _gripper_from_angular(value):
|
||||
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
|
||||
# Note that the units are still angular but the range is different.
|
||||
|
||||
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
|
||||
value = _unnormalize(value, min_val=0.4, max_val=1.5)
|
||||
|
||||
# These values are coming from the Aloha code:
|
||||
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
|
||||
return _normalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
|
||||
|
||||
def _gripper_from_angular_inv(value):
|
||||
# Directly inverts the gripper_from_angular function.
|
||||
value = _unnormalize(value, min_val=-0.6213, max_val=1.4910)
|
||||
return _normalize(value, min_val=0.4, max_val=1.5)
|
||||
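# Editor's sketch: _gripper_from_angular_inv composed with _gripper_from_angular is the
# identity (both are affine maps between the same two ranges), so converting actions to
# Aloha gripper space and back introduces no drift.
def _gripper_roundtrip_example():
    value = np.linspace(0.0, 1.0, 5)
    roundtrip = _gripper_from_angular_inv(_gripper_from_angular(value))
    assert np.allclose(roundtrip, value)
    return roundtrip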
|
||||
|
||||
def _decode_aloha(data: dict, *, adapt_to_pi: bool = False) -> dict:
|
||||
# state is [left_arm_joint_angles, right_arm_joint_angles, left_arm_gripper, right_arm_gripper]
|
||||
# dim sizes: [6, 1, 6, 1]
|
||||
state = np.asarray(data["state"])
|
||||
state = _decode_state(state, adapt_to_pi=adapt_to_pi)
|
||||
|
||||
def convert_image(img):
|
||||
img = np.asarray(img)
|
||||
# Convert to uint8 if using float images.
|
||||
if np.issubdtype(img.dtype, np.floating):
|
||||
img = (255 * img).astype(np.uint8)
|
||||
# Convert from [channel, height, width] to [height, width, channel].
|
||||
return einops.rearrange(img, "c h w -> h w c")
|
||||
|
||||
images = data["images"]
|
||||
images_dict = {name: convert_image(img) for name, img in images.items()}
|
||||
|
||||
data["images"] = images_dict
|
||||
data["state"] = state
|
||||
return data
|
||||
|
||||
|
||||
def _decode_state(state: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
# Flip the joints.
|
||||
state = _joint_flip_mask() * state
|
||||
# Reverse the gripper transformation that is being applied by the Aloha runtime.
|
||||
state[[6, 13]] = _gripper_to_angular(state[[6, 13]])
|
||||
return state
|
||||
|
||||
|
||||
def _encode_actions(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
# Flip the joints.
|
||||
actions = _joint_flip_mask() * actions
|
||||
actions[:, [6, 13]] = _gripper_from_angular(actions[:, [6, 13]])
|
||||
return actions
|
||||
|
||||
|
||||
def _encode_actions_inv(actions: np.ndarray, *, adapt_to_pi: bool = False) -> np.ndarray:
|
||||
if adapt_to_pi:
|
||||
actions = _joint_flip_mask() * actions
|
||||
actions[:, [6, 13]] = _gripper_from_angular_inv(actions[:, [6, 13]])
|
||||
return actions
|
||||
79
src/openpi/policies/droid_policy.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import dataclasses
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import model as _model
|
||||
|
||||
|
||||
def make_droid_example() -> dict:
|
||||
"""Creates a random input example for the Droid policy."""
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _parse_image(image) -> np.ndarray:
|
||||
image = np.asarray(image)
|
||||
if np.issubdtype(image.dtype, np.floating):
|
||||
image = (255 * image).astype(np.uint8)
|
||||
if image.shape[0] == 3:
|
||||
image = einops.rearrange(image, "c h w -> h w c")
|
||||
return image
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DroidInputs(transforms.DataTransformFn):
|
||||
# The action dimension of the model. Will be used to pad state and actions.
|
||||
action_dim: int
|
||||
|
||||
# Determines which model will be used.
|
||||
model_type: _model.ModelType = _model.ModelType.PI0
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
state = np.concatenate([data["observation/joint_position"], data["observation/gripper_position"]])
|
||||
state = transforms.pad_to_dim(state, self.action_dim)
|
||||
|
||||
# Parse images to uint8 (H,W,C) if needed, since LeRobot stores them as float32 (C,H,W);
# during policy inference the inputs already match, so the conversion is skipped.
|
||||
base_image = _parse_image(data["observation/exterior_image_1_left"])
|
||||
wrist_image = _parse_image(data["observation/wrist_image_left"])
|
||||
|
||||
match self.model_type:
|
||||
case _model.ModelType.PI0:
|
||||
names = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")
|
||||
images = (base_image, wrist_image, np.zeros_like(base_image))
|
||||
image_masks = (np.True_, np.True_, np.False_)
|
||||
case _model.ModelType.PI0_FAST:
|
||||
names = ("base_0_rgb", "base_1_rgb", "wrist_0_rgb")
|
||||
# We don't mask out padding images for FAST models.
|
||||
images = (base_image, np.zeros_like(base_image), wrist_image)
|
||||
image_masks = (np.True_, np.True_, np.True_)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported model type: {self.model_type}")
|
||||
|
||||
inputs = {
|
||||
"state": state,
|
||||
"image": dict(zip(names, images, strict=True)),
|
||||
"image_mask": dict(zip(names, image_masks, strict=True)),
|
||||
}
|
||||
|
||||
if "actions" in data:
|
||||
inputs["actions"] = data["actions"]
|
||||
|
||||
if "prompt" in data:
|
||||
inputs["prompt"] = data["prompt"]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DroidOutputs(transforms.DataTransformFn):
|
||||
def __call__(self, data: dict) -> dict:
|
||||
# Only return the first 8 dims.
|
||||
return {"actions": np.asarray(data["actions"][:, :8])}
|
||||
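# Usage sketch (editor addition): running the Droid input transform on the random example
# above. action_dim=8 is illustrative; in practice it comes from the model config.
def _droid_inputs_example():
    inputs = DroidInputs(action_dim=8, model_type=_model.ModelType.PI0_FAST)(make_droid_example())
    assert inputs["state"].shape == (8,)
    assert set(inputs["image"]) == {"base_0_rgb", "base_1_rgb", "wrist_0_rgb"}
    return inputs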
80
src/openpi/policies/libero_policy.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import dataclasses
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
|
||||
from openpi import transforms
|
||||
from openpi.models import model as _model
|
||||
|
||||
|
||||
def make_libero_example() -> dict:
|
||||
"""Creates a random input example for the Libero policy."""
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _parse_image(image) -> np.ndarray:
|
||||
image = np.asarray(image)
|
||||
if np.issubdtype(image.dtype, np.floating):
|
||||
image = (255 * image).astype(np.uint8)
|
||||
if image.shape[0] == 3:
|
||||
image = einops.rearrange(image, "c h w -> h w c")
|
||||
return image
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class LiberoInputs(transforms.DataTransformFn):
|
||||
# The action dimension of the model. Will be used to pad state and actions for pi0 model (not pi0-FAST).
|
||||
action_dim: int
|
||||
|
||||
# Determines which model will be used.
|
||||
model_type: _model.ModelType = _model.ModelType.PI0
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
mask_padding = self.model_type == _model.ModelType.PI0 # We don't mask for pi0-FAST.
|
||||
|
||||
# Get the state. We are padding from 8 to the model action dim.
|
||||
# For pi0-FAST, we don't pad the state (action_dim = 7, which is < 8, so pad is skipped).
|
||||
state = transforms.pad_to_dim(data["observation/state"], self.action_dim)
|
||||
|
||||
# Parse images to uint8 (H,W,C) if needed, since LeRobot stores them as float32 (C,H,W);
# during policy inference the inputs already match, so the conversion is skipped.
|
||||
base_image = _parse_image(data["observation/image"])
|
||||
wrist_image = _parse_image(data["observation/wrist_image"])
|
||||
|
||||
inputs = {
|
||||
"state": state,
|
||||
"image": {
|
||||
"base_0_rgb": base_image,
|
||||
"left_wrist_0_rgb": wrist_image,
|
||||
"right_wrist_0_rgb": np.zeros_like(base_image),
|
||||
},
|
||||
"image_mask": {
|
||||
"base_0_rgb": np.True_,
|
||||
"left_wrist_0_rgb": np.True_,
|
||||
"right_wrist_0_rgb": np.False_ if mask_padding else np.True_,
|
||||
},
|
||||
}
|
||||
|
||||
# Actions are only available during training.
|
||||
if "actions" in data:
|
||||
# We are padding from 7 to the model action dim.
|
||||
# For pi0-FAST, this is a no-op (since action_dim = 7).
|
||||
actions = transforms.pad_to_dim(data["actions"], self.action_dim)
|
||||
inputs["actions"] = actions
|
||||
|
||||
if "prompt" in data:
|
||||
inputs["prompt"] = data["prompt"]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class LiberoOutputs(transforms.DataTransformFn):
|
||||
def __call__(self, data: dict) -> dict:
|
||||
# Only return the first 7 dims.
|
||||
return {"actions": np.asarray(data["actions"][:, :7])}
|
||||
85
src/openpi/policies/policy.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from collections.abc import Sequence
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import flax
|
||||
import flax.traverse_util
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi import transforms as _transforms
|
||||
from openpi.models import model as _model
|
||||
from openpi.shared import array_typing as at
|
||||
from openpi.shared import nnx_utils
|
||||
|
||||
BasePolicy: TypeAlias = _base_policy.BasePolicy
|
||||
|
||||
|
||||
class Policy(BasePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
model: _model.BaseModel,
|
||||
*,
|
||||
rng: at.KeyArrayLike | None = None,
|
||||
transforms: Sequence[_transforms.DataTransformFn] = (),
|
||||
output_transforms: Sequence[_transforms.DataTransformFn] = (),
|
||||
sample_kwargs: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
):
|
||||
self._sample_actions = nnx_utils.module_jit(model.sample_actions)
|
||||
self._input_transform = _transforms.compose(transforms)
|
||||
self._output_transform = _transforms.compose(output_transforms)
|
||||
self._rng = rng or jax.random.key(0)
|
||||
self._sample_kwargs = sample_kwargs or {}
|
||||
self._metadata = metadata or {}
|
||||
|
||||
@override
|
||||
def infer(self, obs: dict) -> dict: # type: ignore[misc]
|
||||
# Make a copy since transformations may modify the inputs in place.
|
||||
inputs = jax.tree.map(lambda x: x, obs)
|
||||
inputs = self._input_transform(inputs)
|
||||
# Make a batch and convert to jax.Array.
|
||||
inputs = jax.tree.map(lambda x: jnp.asarray(x)[np.newaxis, ...], inputs)
|
||||
|
||||
self._rng, sample_rng = jax.random.split(self._rng)
|
||||
outputs = {
|
||||
"state": inputs["state"],
|
||||
"actions": self._sample_actions(sample_rng, _model.Observation.from_dict(inputs), **self._sample_kwargs),
|
||||
}
|
||||
|
||||
# Unbatch and convert to np.ndarray.
|
||||
outputs = jax.tree.map(lambda x: np.asarray(x[0, ...]), outputs)
|
||||
return self._output_transform(outputs)
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
return self._metadata
|
||||
|
||||
|
||||
class PolicyRecorder(_base_policy.BasePolicy):
|
||||
"""Records the policy's behavior to disk."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, record_dir: str):
|
||||
self._policy = policy
|
||||
|
||||
logging.info(f"Dumping policy records to: {record_dir}")
|
||||
self._record_dir = pathlib.Path(record_dir)
|
||||
self._record_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._record_step = 0
|
||||
|
||||
@override
|
||||
def infer(self, obs: dict) -> dict: # type: ignore[misc]
|
||||
results = self._policy.infer(obs)
|
||||
|
||||
data = {"inputs": obs, "outputs": results}
|
||||
data = flax.traverse_util.flatten_dict(data, sep="/")
|
||||
|
||||
output_path = self._record_dir / f"step_{self._record_step}"
|
||||
self._record_step += 1
|
||||
|
||||
np.save(output_path, np.asarray(data))
|
||||
return results
|
||||
83
src/openpi/policies/policy_config.py
Normal file
@@ -0,0 +1,83 @@
from collections.abc import Sequence
import dataclasses
import logging
import pathlib
from typing import Any

import jax.numpy as jnp

import openpi.models.model as _model
import openpi.policies.policy as _policy
import openpi.shared.download as download
from openpi.training import checkpoints as _checkpoints
from openpi.training import config as _config
import openpi.transforms as transforms


@dataclasses.dataclass
class PolicyConfig:
    model: _model.BaseModel
    norm_stats: dict[str, transforms.NormStats]

    input_layers: Sequence[transforms.DataTransformFn]
    output_layers: Sequence[transforms.DataTransformFn]

    model_type: _model.ModelType = _model.ModelType.PI0
    default_prompt: str | None = None
    sample_kwargs: dict[str, Any] | None = None


def create_trained_policy(
    train_config: _config.TrainConfig,
    checkpoint_dir: pathlib.Path | str,
    *,
    repack_transforms: transforms.Group | None = None,
    sample_kwargs: dict[str, Any] | None = None,
    default_prompt: str | None = None,
    norm_stats: dict[str, transforms.NormStats] | None = None,
) -> _policy.Policy:
    """Create a policy from a trained checkpoint.

    Args:
        train_config: The training config to use to create the model.
        checkpoint_dir: The directory to load the model from.
        repack_transforms: Optional transforms that will be applied before any other transforms.
        sample_kwargs: The kwargs to pass to the `sample_actions` method. If not provided, the default
            kwargs will be used.
        default_prompt: The default prompt to use for the policy. Will inject the prompt into the input
            data if it doesn't already exist.
        norm_stats: The norm stats to use for the policy. If not provided, the norm stats will be loaded
            from the checkpoint directory.
    """
    repack_transforms = repack_transforms or transforms.Group()
    checkpoint_dir = download.maybe_download(str(checkpoint_dir))

    logging.info("Loading model...")
    model = train_config.model.load(_model.restore_params(checkpoint_dir / "params", dtype=jnp.bfloat16))

    data_config = train_config.data.create(train_config.assets_dirs, train_config.model)
    if norm_stats is None:
        # We are loading the norm stats from the checkpoint instead of the config assets dir to make sure
        # that the policy is using the same normalization stats as the original training process.
        if data_config.asset_id is None:
            raise ValueError("Asset id is required to load norm stats.")
        norm_stats = _checkpoints.load_norm_stats(checkpoint_dir / "assets", data_config.asset_id)

    return _policy.Policy(
        model,
        transforms=[
            *repack_transforms.inputs,
            transforms.InjectDefaultPrompt(default_prompt),
            *data_config.data_transforms.inputs,
            transforms.Normalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.model_transforms.inputs,
        ],
        output_transforms=[
            *data_config.model_transforms.outputs,
            transforms.Unnormalize(norm_stats, use_quantiles=data_config.use_quantile_norm),
            *data_config.data_transforms.outputs,
            *repack_transforms.outputs,
        ],
        sample_kwargs=sample_kwargs,
        metadata=train_config.policy_metadata,
    )
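A minimal usage sketch for `create_trained_policy` (illustrative only, not part of this commit). The config name and S3 checkpoint come from the tests below; the prompt string is hypothetical, and the `_config` / `aloha_policy` imports are assumed to be in scope.

    config = _config.get_config("pi0_aloha_sim")
    policy = create_trained_policy(
        config,
        "s3://openpi-assets/checkpoints/pi0_aloha_sim",  # or a local checkpoint directory
        default_prompt="Transfer the cube",  # hypothetical prompt, injected when the input has none
    )
    example = aloha_policy.make_aloha_example()
    actions = policy.infer(example)["actions"]  # shape: (config.model.action_horizon, action_dim)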
34
src/openpi/policies/policy_test.py
Normal file
@@ -0,0 +1,34 @@
from openpi_client import action_chunk_broker
import pytest

from openpi.policies import aloha_policy
from openpi.policies import policy_config as _policy_config
from openpi.training import config as _config


@pytest.mark.manual
def test_infer():
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")

    example = aloha_policy.make_aloha_example()
    result = policy.infer(example)

    assert result["actions"].shape == (config.model.action_horizon, 14)


@pytest.mark.manual
def test_broker():
    config = _config.get_config("pi0_aloha_sim")
    policy = _policy_config.create_trained_policy(config, "s3://openpi-assets/checkpoints/pi0_aloha_sim")

    broker = action_chunk_broker.ActionChunkBroker(
        policy,
        # Only execute the first half of the chunk.
        action_horizon=config.model.action_horizon // 2,
    )

    example = aloha_policy.make_aloha_example()
    for _ in range(config.model.action_horizon):
        outputs = broker.infer(example)
        assert outputs["actions"].shape == (14,)
0
src/openpi/py.typed
Normal file
63
src/openpi/serving/websocket_policy_server.py
Normal file
@@ -0,0 +1,63 @@
import asyncio
import logging
import traceback

from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
import websockets.asyncio.server
import websockets.frames


class WebsocketPolicyServer:
    """Serves a policy using the websocket protocol. See websocket_client_policy.py for a client implementation.

    Currently only implements the `load` and `infer` methods.
    """

    def __init__(
        self,
        policy: _base_policy.BasePolicy,
        host: str = "0.0.0.0",
        port: int = 8000,
        metadata: dict | None = None,
    ) -> None:
        self._policy = policy
        self._host = host
        self._port = port
        self._metadata = metadata or {}
        logging.getLogger("websockets.server").setLevel(logging.INFO)

    def serve_forever(self) -> None:
        asyncio.run(self.run())

    async def run(self):
        async with websockets.asyncio.server.serve(
            self._handler,
            self._host,
            self._port,
            compression=None,
            max_size=None,
        ) as server:
            await server.serve_forever()

    async def _handler(self, websocket: websockets.asyncio.server.ServerConnection):
        logging.info(f"Connection from {websocket.remote_address} opened")
        packer = msgpack_numpy.Packer()

        await websocket.send(packer.pack(self._metadata))

        while True:
            try:
                obs = msgpack_numpy.unpackb(await websocket.recv())
                action = self._policy.infer(obs)
                await websocket.send(packer.pack(action))
            except websockets.ConnectionClosed:
                logging.info(f"Connection from {websocket.remote_address} closed")
                break
            except Exception:
                await websocket.send(traceback.format_exc())
                await websocket.close(
                    code=websockets.frames.CloseCode.INTERNAL_ERROR,
                    reason="Internal server error. Traceback included in previous frame.",
                )
                raise
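A minimal serving sketch (illustrative only, not part of this commit) that combines this server with `policy_config.create_trained_policy`; the imports and the chosen host/port are assumptions.

    from openpi.policies import policy_config as _policy_config
    from openpi.training import config as _config

    policy = _policy_config.create_trained_policy(
        _config.get_config("pi0_aloha_sim"),
        "s3://openpi-assets/checkpoints/pi0_aloha_sim",
    )
    server = WebsocketPolicyServer(policy, host="0.0.0.0", port=8000, metadata=policy.metadata)
    server.serve_forever()  # blocks; clients connect with the openpi_client websocket client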
0
src/openpi/shared/__init__.py
Normal file
Some files were not shown because too many files have changed in this diff.