multi-node openpi commit
This commit is contained in:
Submodule policy/openpi-InternData-A1 deleted from 10b4b8fd13
3
policy/openpi-InternData-A1/.dockerignore
Normal file
3
policy/openpi-InternData-A1/.dockerignore
Normal file
@@ -0,0 +1,3 @@
|
||||
.venv
|
||||
checkpoints
|
||||
data
|
||||
169
policy/openpi-InternData-A1/.gitignore
vendored
Normal file
169
policy/openpi-InternData-A1/.gitignore
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
# Data directories.
|
||||
assets/
|
||||
checkpoints/
|
||||
data/
|
||||
wandb/
|
||||
third_party/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
16
policy/openpi-InternData-A1/.pre-commit-config.yaml
Normal file
16
policy/openpi-InternData-A1/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,16 @@
|
||||
exclude: third_party/
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/uv-pre-commit
|
||||
# uv version.
|
||||
rev: 0.5.14
|
||||
hooks:
|
||||
- id: uv-lock
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.8.6
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
1
policy/openpi-InternData-A1/.python-version
Normal file
1
policy/openpi-InternData-A1/.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.11
|
||||
33
policy/openpi-InternData-A1/CONTRIBUTING.md
Normal file
33
policy/openpi-InternData-A1/CONTRIBUTING.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Contributing to openpi
|
||||
|
||||
We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance to the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
|
||||
|
||||
## Issues and feature requests
|
||||
|
||||
You are welcome to use the Github [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
|
||||
|
||||
If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on Github under Issues). If the issue has not yet been reported, please include this information when filing a Github issue:
|
||||
|
||||
- Your OS type and version and the version of Python you are using
|
||||
- Code that allows us to reproduce your bug, including all dependencies
|
||||
- Traceback of any exception
|
||||
- Any other information that would help us, such as a screenshot
|
||||
|
||||
In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
|
||||
|
||||
If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
|
||||
|
||||
- The motivation for the feature
|
||||
- A description of the problem you are trying to solve or your use case
|
||||
- Enough information for us to understand the nature of the request
|
||||
- Some information for how you intend to use it (this might help us in understanding the motivation!)
|
||||
|
||||
We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
|
||||
|
||||
## Submitting a pull request
|
||||
|
||||
If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to get a sense for whether your PR is likely to get approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend every contribution to consider the following:
|
||||
|
||||
- Make sure that your PR has a clear title and description
|
||||
- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
|
||||
- Make sure your PR passes all tests
|
||||
201
policy/openpi-InternData-A1/LICENSE
Normal file
201
policy/openpi-InternData-A1/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
9
policy/openpi-InternData-A1/README.md
Normal file
9
policy/openpi-InternData-A1/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# openpi-InternData-A1
|
||||
|
||||
## Training
|
||||
|
||||
For detailed instructions on pretraining with InterData-A1, finetuning on real-world tasks and sim2real transfer experiments, please refer to [`docs/training.md`](docs/training.md).
|
||||
|
||||
|
||||
## Pretrained Checkpoints
|
||||
We pretrained Pi0 model in on InternData-A1 for 680k iterations, initialized from PaliGemma checkpoint. The resulting pretrained ckpt is available [here](https://huggingface.co/yuyinyang3y/interndata-a1).
|
||||
25
policy/openpi-InternData-A1/docs/docker.md
Normal file
25
policy/openpi-InternData-A1/docs/docker.md
Normal file
@@ -0,0 +1,25 @@
|
||||
### Docker Setup
|
||||
|
||||
All of the examples in this repo provide instructions for being run normally, and also using Docker. Although not required, the Docker option is recommended as this will simplify software installation, produce a more stable environment, and also allow you to avoid installing ROS and cluttering your machine, for examples which depend on ROS.
|
||||
|
||||
- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
|
||||
- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
|
||||
- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
|
||||
- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
|
||||
- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
|
||||
|
||||
|
||||
If starting from scratch and your host machine is Ubuntu 22.04, you can use accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
|
||||
|
||||
Build the Docker image and start the container with the following command:
|
||||
```bash
|
||||
docker compose -f scripts/docker/compose.yml up --build
|
||||
```
|
||||
|
||||
To build and run the Docker image for a specific example, use the following command:
|
||||
```bash
|
||||
docker compose -f examples/<example_name>/compose.yml up --build
|
||||
```
|
||||
where `<example_name>` is the name of the example you want to run.
|
||||
|
||||
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.
|
||||
179
policy/openpi-InternData-A1/docs/norm_stats.md
Normal file
179
policy/openpi-InternData-A1/docs/norm_stats.md
Normal file
@@ -0,0 +1,179 @@
|
||||
# Normalization Statistics
|
||||
|
||||
Here we provide instructions for computing **normalization statistics** for both **real-world**, **simulation (InternData-A1)** and **sim2real** tasks. The computed statistics are saved in JSON format and are intended to be reused during training and evaluation in the OpenPI pipeline.
|
||||
|
||||
Normalization is computed over:
|
||||
- `state`
|
||||
- `actions`
|
||||
|
||||
and follows the exact data preprocessing and repacking logic used during training.
|
||||
|
||||
---
|
||||
|
||||
## 1. Simulation Tasks (InternData-A1)
|
||||
This script `scripts/compute_norm_stats_sim.py` computes normalization statistics for simulation tasks in the InternData-A1 benchmark.
|
||||
|
||||
### Supported Robots
|
||||
- `split_aloha`
|
||||
- `lift2`
|
||||
- `genie1`
|
||||
- `franka`
|
||||
|
||||
### Dataset Structure
|
||||
Download the InternData-A1 datasets from [here](https://huggingface.co/datasets/InternRobotics/InternData-A1).
|
||||
The structure of the dataset is as follows:
|
||||
|
||||
```
|
||||
InternData-A1/sim/
|
||||
└── <task_category>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/ # no subtask
|
||||
├── data/
|
||||
├── meta/
|
||||
└── videos/
|
||||
```
|
||||
|
||||
Some tasks may have subtasks / collections:
|
||||
|
||||
```
|
||||
InternData-A1/sim/
|
||||
└── <task_category>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── <collect_name>/
|
||||
├── data/
|
||||
├── meta/
|
||||
└── videos/
|
||||
```
|
||||
|
||||
### Usage
|
||||
```
|
||||
python scripts/compute_norm_stats_sim.py \
|
||||
--root_data_dir InternData-A1/sim \
|
||||
--task_category pick_and_place_tasks \
|
||||
--save_path stats/sim \
|
||||
--start_ratio 0.0 \
|
||||
--end_ratio 1.0
|
||||
```
|
||||
|
||||
Arguments
|
||||
- `root_data_dir`: Root directory of simulation datasets.
|
||||
- `task_category`: Task category to process (e.g. pick_and_place_tasks).
|
||||
- `save_path`: Root directory where normalization statistics will be saved.
|
||||
- `start_ratio`, `end_ratio`: Fraction of tasks to process (useful for sharding large datasets).
|
||||
|
||||
### Output Structure
|
||||
```
|
||||
<save_path>/
|
||||
└── <task_category>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── <collect_name>/ # empty if no subtask
|
||||
└── norm_stats.json
|
||||
```
|
||||
During pretraining, set the `stats_dir` argument in `DataConfig` to the `save_path` here.
|
||||
|
||||
## 2. Real-World Tasks
|
||||
This script `scripts/compute_norm_stats_real.py` computes normalization statistics for real-world tasks.
|
||||
|
||||
### Supported Robots
|
||||
- `lift2`
|
||||
- `split_aloha`
|
||||
- `acone`
|
||||
- `genie1`
|
||||
|
||||
### Dataset Structure
|
||||
Real-world datasets are expected to follow the LeRobot repository structure:
|
||||
```
|
||||
InternData-A1/real/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── <collect_name>/ # empty if no subtask
|
||||
├── data/
|
||||
├── meta/
|
||||
└── videos/
|
||||
```
|
||||
|
||||
Example task path:
|
||||
```
|
||||
InternData-A1/real/genie1/
|
||||
└── Pick_a_bag_of_bread_with_the_left_arm__then_handover/set_0
|
||||
```
|
||||
|
||||
### Usage
|
||||
```
|
||||
python scripts/compute_norm_stats_real.py \
|
||||
--task_path InternData-A1/real/genie1/Pick_a_bag_of_bread_with_the_left_arm__then_handover/* \
|
||||
--robot_name genie1 \
|
||||
--save_path stats/real
|
||||
```
|
||||
|
||||
Arguments
|
||||
- `task_path`: Path (or glob pattern) to a real-world task dataset(e.g. `InternData-A1/real/genie1/Pick_a_bag_of_bread_with_the_left_arm__then_handover/*`)
|
||||
- `robot_name`: Robot platform name (must be supported).
|
||||
- `save_path`: Root directory where normalization statistics will be saved.
|
||||
|
||||
### Output Structure
|
||||
```
|
||||
<save_path>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── norm_stats.json
|
||||
```
|
||||
During finetuning, set the `fixed_stats_dir` argument in `DataConfig` to `<save_path>/<robot_name>/<task_name>` here.
|
||||
|
||||
## 3. Sim2Real Experiments
|
||||
This script `scripts/compute_norm_stats_sim2real.py` computes normalization statistics for sim2real experiments.
|
||||
|
||||
### Supported Robots
|
||||
- `lift2`
|
||||
|
||||
### Dataset Structure
|
||||
Dataset from InternData-A1 are expected to follow the LeRobot repository structure:
|
||||
```
|
||||
InternData-A1/sim/
|
||||
└── <task_category>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── <collect_name>/
|
||||
├── data/
|
||||
├── meta/
|
||||
└── videos/
|
||||
```
|
||||
|
||||
Example task path:
|
||||
```
|
||||
InternData-A1/sim/long_horizon_tasks/lift2/
|
||||
└── sort_the_rubbish
|
||||
└── Sort_rubbish_1l2r
|
||||
└── Sort_rubbish_2l1r
|
||||
└── Sort_rubbish_2l2r
|
||||
```
|
||||
|
||||
### Usage
|
||||
```
|
||||
python scripts/compute_norm_stats_sim2real.py \
|
||||
--task_path InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/* \
|
||||
--robot_name lift2 \
|
||||
--save_path stats/sim2real
|
||||
```
|
||||
|
||||
Arguments
|
||||
- `task_path`: Path (or glob pattern) to a task dataset(e.g. `InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/*` means training on all the collections in the task)
|
||||
- `robot_name`: Robot platform name (we only support `lift2` for now, but you can try other robots).
|
||||
- `save_path`: Root directory where normalization statistics will be saved.
|
||||
|
||||
### Output Structure
|
||||
```
|
||||
<save_path>/
|
||||
└── <robot_name>/
|
||||
└── <task_name>/
|
||||
└── norm_stats.json
|
||||
```
|
||||
During finetuning, set the `fixed_stats_dir` argument in `DataConfig` to `<save_path>/<robot_name>/<task_name>` here.
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
For simulation tasks and sim2real experiments, computation may stop early (e.g. after 10k steps) to limit runtime.
|
||||
|
||||
For sim2real transfer, we set the gripper dimension in the state vector to zero because the state of the gripper in the real world during inference is not aligned with the state in the simulation. See `src/openpi/policies/sim2real_split_aloha_policy.py` for more details.
|
||||
71
policy/openpi-InternData-A1/docs/remote_inference.md
Normal file
71
policy/openpi-InternData-A1/docs/remote_inference.md
Normal file
@@ -0,0 +1,71 @@
|
||||
|
||||
# Running openpi models remotely
|
||||
|
||||
We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
|
||||
|
||||
## Starting a remote policy server
|
||||
|
||||
To start a remote policy server, you can simply run the following command:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
|
||||
```
|
||||
|
||||
The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
|
||||
```
|
||||
|
||||
This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
|
||||
|
||||
## Querying the remote policy server from your robot code
|
||||
|
||||
We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
|
||||
|
||||
First, install the `openpi-client` package in your robot environment:
|
||||
|
||||
```bash
|
||||
cd $OPENPI_ROOT/packages/openpi-client
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
|
||||
|
||||
```python
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy
|
||||
|
||||
# Outside of episode loop, initialize the policy client.
|
||||
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
|
||||
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)
|
||||
|
||||
for step in range(num_steps):
|
||||
# Inside the episode loop, construct the observation.
|
||||
# Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
|
||||
# We provide utilities for resizing images + uint8 conversion so you match the training routines.
|
||||
# The typical resize_size for pre-trained pi0 models is 224.
|
||||
# Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
|
||||
observation = {
|
||||
"observation/image": image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(img, 224, 224)
|
||||
),
|
||||
"observation/wrist_image": image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(wrist_img, 224, 224)
|
||||
),
|
||||
"observation/state": state,
|
||||
"prompt": task_instruction,
|
||||
}
|
||||
|
||||
# Call the policy server with the current observation.
|
||||
# This returns an action chunk of shape (action_horizon, action_dim).
|
||||
# Note that you typically only need to call the policy every N steps and execute steps
|
||||
# from the predicted action chunk open-loop in the remaining steps.
|
||||
action_chunk = client.infer(observation)["actions"]
|
||||
|
||||
# Execute the actions in the environment.
|
||||
...
|
||||
|
||||
```
|
||||
|
||||
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
|
||||
102
policy/openpi-InternData-A1/docs/training.md
Normal file
102
policy/openpi-InternData-A1/docs/training.md
Normal file
@@ -0,0 +1,102 @@
|
||||
# Training Instructions
|
||||
|
||||
Here we provide instructions for pretraining on InternData-A1, finetuning on real-world tasks and finetuning on InternData-A1 tasks for sim2real transfer.
|
||||
|
||||
Before training, you need to compute the normalization statistics for the tasks you want to train on. Please refer to [norm_stats.md](norm_stats.md) for more details.
|
||||
|
||||
---
|
||||
|
||||
## 1. Pretraining on InternData-A1
|
||||
|
||||
|
||||
### Write a training config
|
||||
We provide a `TrainConfig` example named `pretrain-interndata-a1` in `src/openpi/training/config.py`.
|
||||
InternData-A1 contains four robot embodiments:
|
||||
- `split_aloha`
|
||||
- `lift2`
|
||||
- `genie1`
|
||||
- `franka`
|
||||
|
||||
Accordingly, we define three `MultiDataConfigFactory` classes:
|
||||
- `MultiSimSplitAlohaDataConfig` for `split_aloha` and `lift2`
|
||||
- `MultiSimGenieDataConfig` for `genie1`
|
||||
- `MultiSimFrankaDataConfig` for `franka`
|
||||
|
||||
Please either:
|
||||
- create a soft link from the InternData-A1 dataset to `data/InternData-A1`, or
|
||||
- modify the `repo_dir` field in all relevant `MultiDataConfig` entries to point to your local InternData-A1 path.
|
||||
|
||||
Set `stats_dir` to your local normalization statistics directory. If you use the default setting, ensure that the normalization statistics for simulation tasks are saved under `stats/sim`.
|
||||
|
||||
We initialize the model from PaliGemma-3B using:
|
||||
```
|
||||
weight_loader=weight_loaders.PaliGemmaWeightLoader("checkpoints/jax/paligemma/pt_224.npz")
|
||||
```
|
||||
Please download the PaliGemma-3b checkpoint by running
|
||||
```
|
||||
python scripts/download_paligemma.py
|
||||
```
|
||||
|
||||
You may adjust other training parameters based on your available GPUs and training budget:
|
||||
- `num_train_steps`: Total number of training steps
|
||||
- `num_workers`: Number of data loading workers
|
||||
- `fsdp_devices`: Number of GPUs per node
|
||||
- `batch_size`: Batch size per GPU
|
||||
- `save_interval`: Checkpoint saving interval (in steps)
|
||||
|
||||
### Run training
|
||||
For multi node training, run
|
||||
```
|
||||
bash scripts/training_scripts/multi_node.sh
|
||||
```
|
||||
|
||||
For single node multi-GPU training, run
|
||||
```
|
||||
config_name=pretrain-interndata-a1
|
||||
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
|
||||
```
|
||||
|
||||
The ckpts will be saved to `checkpoints/${config_name}`.
|
||||
|
||||
## 2. Finetuning on Real-World Tasks
|
||||
### Write a training config
|
||||
We provide a `TrainConfig` example named `finetune-a2d-pen` in `src/openpi/training/config.py`.
|
||||
|
||||
Key arguments you may need to modify include:
|
||||
- `MultiDataConfigFactory` class:
|
||||
- `MultiLeRobotReala2dDataConfig` for `genie1`
|
||||
- `MultiLeRobotRealArxLift2DataConfig` for `lift2` and `acone`
|
||||
- `repo_dir`: Path to the real-world task dataset.
|
||||
- `robot_name`: the robot name in `repo_dir`, e.g. "genie1".
|
||||
- `fixed_stats_dir`: Path to the normalization statistics for the real-world task. When this is set, statistics from `stats_dir` will not be used.
|
||||
- `weight_loader`: Pretrained checkpoint used for initialization.
|
||||
You may download our pretrained checkpoints from [here]().
|
||||
|
||||
### Run training
|
||||
For training, run
|
||||
For single node multi-GPU training, run
|
||||
```
|
||||
config_name=finetune-a2d-pen
|
||||
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
|
||||
```
|
||||
|
||||
The ckpts will be saved under `checkpoints/${config_name}`.
|
||||
|
||||
## 3. Finetuning on InternData-A1 Tasks for Sim2Real Transfer
|
||||
### Write a training config
|
||||
We provide a `TrainConfig` example named `finetune-sim2real-lift2-sort-rubbish` in `src/openpi/training/config.py`.
|
||||
|
||||
Key arguments you may need to modify include:
|
||||
- `MultiDataConfigFactory` class: Currently, sim-to-real transfer is evaluated only on `lift2` tasks:
|
||||
- `MultiSim2RealSplitAlohaDataConfig` for `lift2`
|
||||
- `repo_dir`: Path to the corresponding InternData-A1 task.
|
||||
- `fixed_stats_dir`: Path to the normalization statistics for the sim-to-real task. When specified, statistics from `stats_dir` will not be used.
|
||||
- `weight_loader`: Pretrained checkpoint used for initialization.
|
||||
|
||||
### Run training
|
||||
For training, run
|
||||
For single node multi-GPU training, run
|
||||
```
|
||||
config_name=finetune-sim2real-lift2-sort-rubbish
|
||||
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
|
||||
```
|
||||
70
policy/openpi-InternData-A1/examples/aloha_real/Dockerfile
Normal file
70
policy/openpi-InternData-A1/examples/aloha_real/Dockerfile
Normal file
@@ -0,0 +1,70 @@
|
||||
# Dockerfile for the Aloha real environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
|
||||
|
||||
FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
cmake \
|
||||
curl \
|
||||
libffi-dev \
|
||||
python3-rosdep \
|
||||
python3-rosinstall \
|
||||
python3-rosinstall-generator \
|
||||
whiptail \
|
||||
git \
|
||||
wget \
|
||||
openssh-client \
|
||||
ros-noetic-cv-bridge \
|
||||
ros-noetic-usb-cam \
|
||||
ros-noetic-realsense2-camera \
|
||||
keyboard-configuration
|
||||
|
||||
WORKDIR /root
|
||||
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
|
||||
RUN chmod +x xsarm_amd64_install.sh
|
||||
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
|
||||
|
||||
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
|
||||
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
|
||||
|
||||
# Install python 3.10 because this ROS image comes with 3.8
|
||||
RUN mkdir /python && \
|
||||
cd /python && \
|
||||
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
|
||||
tar -zxvf Python-3.10.14.tgz && \
|
||||
cd Python-3.10.14 && \
|
||||
ls -lhR && \
|
||||
./configure --enable-optimizations && \
|
||||
make install && \
|
||||
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
|
||||
cd ~ && rm -rf /python && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
|
||||
ENV UV_HTTP_TIMEOUT=120
|
||||
ENV UV_LINK_MODE=copy
|
||||
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
|
||||
WORKDIR /app
|
||||
|
||||
# Create an entrypoint script to run the setup commands, followed by the command passed in.
|
||||
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
|
||||
#!/bin/bash
|
||||
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
|
||||
EOF
|
||||
RUN chmod +x /usr/local/bin/entrypoint.sh
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
|
||||
CMD ["python3", "/app/examples/aloha_real/main.py"]
|
||||
126
policy/openpi-InternData-A1/examples/aloha_real/README.md
Normal file
126
policy/openpi-InternData-A1/examples/aloha_real/README.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Run Aloha (Real Robot)
|
||||
|
||||
This example demonstrates how to run with a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
This repo uses a fork of the ALOHA repo, with very minor modifications to use Realsense cameras.
|
||||
|
||||
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
|
||||
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
|
||||
docker compose -f examples/aloha_real/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_real/.venv
|
||||
source examples/aloha_real/.venv/bin/activate
|
||||
uv pip sync examples/aloha_real/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the robot
|
||||
python -m examples.aloha_real.main
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
roslaunch aloha ros_nodes.launch
|
||||
```
|
||||
|
||||
Terminal window 3:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
|
||||
```
|
||||
|
||||
## **ALOHA Checkpoint Guide**
|
||||
|
||||
|
||||
The `pi0_base` model can be used in zero shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, “fold the towel” and “open the tupperware and put the food on the plate,” which can perform more advanced tasks on the ALOHA.
|
||||
|
||||
While we’ve found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we’ve seen it work well on, and well-represented initial state distributions. Running these policies in zero shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by finetuning with data from the target robot.
|
||||
|
||||
|
||||
---
|
||||
|
||||
### **Toast Task**
|
||||
|
||||
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
|
||||
|
||||
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
|
||||
- **Prompt**: "take the toast out of the toaster"
|
||||
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
|
||||
- **Object Distribution**:
|
||||
- Works on both real toast and rubber fake toast
|
||||
- Compatible with standard 2-slice toasters
|
||||
- Works with plates of varying colors
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
|
||||
|
||||
- The toaster should be positioned in the top-left quadrant of the workspace.
|
||||
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
|
||||
- The plate should be placed roughly in the lower-center of the workspace.
|
||||
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
|
||||
|
||||
|
||||
### **Towel Task**
|
||||
|
||||
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
|
||||
|
||||
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
|
||||
- **Prompt**: "fold the towel"
|
||||
- **Object Distribution**:
|
||||
- Works on towels of varying solid colors
|
||||
- Performance is worse on heavily textured or striped towels
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
|
||||
|
||||
- The towel should be flattened and roughly centered on the table.
|
||||
- Choose a towel that does not blend in with the table surface.
|
||||
|
||||
|
||||
### **Tupperware Task**
|
||||
|
||||
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
|
||||
|
||||
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
|
||||
- **Prompt**: "open the tupperware and put the food on the plate"
|
||||
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
|
||||
- **Object Distribution**:
|
||||
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
|
||||
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
|
||||
- The policy has seen plates of varying solid colors.
|
||||
|
||||
### **Scene Setup Guidelines**
|
||||
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
|
||||
|
||||
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
|
||||
- Positioning:
|
||||
- Tupperware should be on the left.
|
||||
- Plate should be on the right or bottom.
|
||||
- The tupperware flap should point toward the plate.
|
||||
|
||||
## Training on your own Aloha dataset
|
||||
|
||||
1. Convert the dataset to the LeRobot dataset v2.0 format.
|
||||
|
||||
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that converts the dataset to the LeRobot dataset v2.0 format. As an example we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
|
||||
|
||||
|
||||
2. Define a training config that uses the custom dataset.
|
||||
|
||||
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
|
||||
|
||||
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the trossen asset_id and a path to the pretrained checkpoint’s asset directory within the AssetsConfig.
|
||||
66
policy/openpi-InternData-A1/examples/aloha_real/compose.yml
Normal file
66
policy/openpi-InternData-A1/examples/aloha_real/compose.yml
Normal file
@@ -0,0 +1,66 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_real/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- aloha_ros_nodes
|
||||
- ros_master
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
aloha_ros_nodes:
|
||||
image: aloha_real
|
||||
depends_on:
|
||||
- ros_master
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_real/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- /dev:/dev
|
||||
command: roslaunch --wait aloha ros_nodes.launch
|
||||
|
||||
ros_master:
|
||||
image: ros:noetic-robot
|
||||
network_mode: host
|
||||
privileged: true
|
||||
command:
|
||||
- roscore
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
71
policy/openpi-InternData-A1/examples/aloha_real/constants.py
Normal file
71
policy/openpi-InternData-A1/examples/aloha_real/constants.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
|
||||
### Task parameters
|
||||
|
||||
### ALOHA fixed constants
|
||||
DT = 0.001
|
||||
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
|
||||
|
||||
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
|
||||
MASTER_GRIPPER_POSITION_OPEN = 0.02417
|
||||
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
|
||||
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
|
||||
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
|
||||
|
||||
# Gripper joint limits (qpos[6])
|
||||
MASTER_GRIPPER_JOINT_OPEN = 0.3083
|
||||
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
|
||||
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
|
||||
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
|
||||
|
||||
############################ Helper functions ############################
|
||||
|
||||
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
|
||||
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
|
||||
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
|
||||
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
|
||||
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
|
||||
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
|
||||
|
||||
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
|
||||
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
|
||||
|
||||
MASTER_POS2JOINT = (
|
||||
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
+ MASTER_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
PUPPET_POS2JOINT = (
|
||||
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
+ PUPPET_GRIPPER_JOINT_CLOSE
|
||||
)
|
||||
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
|
||||
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
|
||||
)
|
||||
|
||||
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
|
||||
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
||||
|
||||
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Literal
|
||||
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class DatasetConfig:
|
||||
use_videos: bool = True
|
||||
tolerance_s: float = 0.0001
|
||||
image_writer_processes: int = 10
|
||||
image_writer_threads: int = 5
|
||||
video_backend: str | None = None
|
||||
|
||||
|
||||
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
||||
|
||||
|
||||
def create_empty_dataset(
|
||||
repo_id: str,
|
||||
robot_type: str,
|
||||
mode: Literal["video", "image"] = "video",
|
||||
*,
|
||||
has_velocity: bool = False,
|
||||
has_effort: bool = False,
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
) -> LeRobotDataset:
|
||||
motors = [
|
||||
"right_waist",
|
||||
"right_shoulder",
|
||||
"right_elbow",
|
||||
"right_forearm_roll",
|
||||
"right_wrist_angle",
|
||||
"right_wrist_rotate",
|
||||
"right_gripper",
|
||||
"left_waist",
|
||||
"left_shoulder",
|
||||
"left_elbow",
|
||||
"left_forearm_roll",
|
||||
"left_wrist_angle",
|
||||
"left_wrist_rotate",
|
||||
"left_gripper",
|
||||
]
|
||||
cameras = [
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
]
|
||||
|
||||
features = {
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
if has_velocity:
|
||||
features["observation.velocity"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
if has_effort:
|
||||
features["observation.effort"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (len(motors),),
|
||||
"names": [
|
||||
motors,
|
||||
],
|
||||
}
|
||||
|
||||
for cam in cameras:
|
||||
features[f"observation.images.{cam}"] = {
|
||||
"dtype": mode,
|
||||
"shape": (3, 480, 640),
|
||||
"names": [
|
||||
"channels",
|
||||
"height",
|
||||
"width",
|
||||
],
|
||||
}
|
||||
|
||||
if Path(LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
return LeRobotDataset.create(
|
||||
repo_id=repo_id,
|
||||
fps=50,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
use_videos=dataset_config.use_videos,
|
||||
tolerance_s=dataset_config.tolerance_s,
|
||||
image_writer_processes=dataset_config.image_writer_processes,
|
||||
image_writer_threads=dataset_config.image_writer_threads,
|
||||
video_backend=dataset_config.video_backend,
|
||||
)
|
||||
|
||||
|
||||
def get_cameras(hdf5_files: list[Path]) -> list[str]:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
# ignore depth channel, not currently handled
|
||||
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
|
||||
|
||||
|
||||
def has_velocity(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/qvel" in ep
|
||||
|
||||
|
||||
def has_effort(hdf5_files: list[Path]) -> bool:
|
||||
with h5py.File(hdf5_files[0], "r") as ep:
|
||||
return "/observations/effort" in ep
|
||||
|
||||
|
||||
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
|
||||
imgs_per_cam = {}
|
||||
for camera in cameras:
|
||||
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
|
||||
|
||||
if uncompressed:
|
||||
# load all images in RAM
|
||||
imgs_array = ep[f"/observations/images/{camera}"][:]
|
||||
else:
|
||||
import cv2
|
||||
|
||||
# load one compressed image after the other in RAM and uncompress
|
||||
imgs_array = []
|
||||
for data in ep[f"/observations/images/{camera}"]:
|
||||
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
|
||||
imgs_array = np.array(imgs_array)
|
||||
|
||||
imgs_per_cam[camera] = imgs_array
|
||||
return imgs_per_cam
|
||||
|
||||
|
||||
def load_raw_episode_data(
|
||||
ep_path: Path,
|
||||
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
with h5py.File(ep_path, "r") as ep:
|
||||
state = torch.from_numpy(ep["/observations/qpos"][:])
|
||||
action = torch.from_numpy(ep["/action"][:])
|
||||
|
||||
velocity = None
|
||||
if "/observations/qvel" in ep:
|
||||
velocity = torch.from_numpy(ep["/observations/qvel"][:])
|
||||
|
||||
effort = None
|
||||
if "/observations/effort" in ep:
|
||||
effort = torch.from_numpy(ep["/observations/effort"][:])
|
||||
|
||||
imgs_per_cam = load_raw_images_per_camera(
|
||||
ep,
|
||||
[
|
||||
"cam_high",
|
||||
"cam_low",
|
||||
"cam_left_wrist",
|
||||
"cam_right_wrist",
|
||||
],
|
||||
)
|
||||
|
||||
return imgs_per_cam, state, action, velocity, effort
|
||||
|
||||
|
||||
def populate_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
hdf5_files: list[Path],
|
||||
task: str,
|
||||
episodes: list[int] | None = None,
|
||||
) -> LeRobotDataset:
|
||||
if episodes is None:
|
||||
episodes = range(len(hdf5_files))
|
||||
|
||||
for ep_idx in tqdm.tqdm(episodes):
|
||||
ep_path = hdf5_files[ep_idx]
|
||||
|
||||
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
|
||||
num_frames = state.shape[0]
|
||||
|
||||
for i in range(num_frames):
|
||||
frame = {
|
||||
"observation.state": state[i],
|
||||
"action": action[i],
|
||||
}
|
||||
|
||||
for camera, img_array in imgs_per_cam.items():
|
||||
frame[f"observation.images.{camera}"] = img_array[i]
|
||||
|
||||
if velocity is not None:
|
||||
frame["observation.velocity"] = velocity[i]
|
||||
if effort is not None:
|
||||
frame["observation.effort"] = effort[i]
|
||||
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode(task=task)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def port_aloha(
|
||||
raw_dir: Path,
|
||||
repo_id: str,
|
||||
raw_repo_id: str | None = None,
|
||||
task: str = "DEBUG",
|
||||
*,
|
||||
episodes: list[int] | None = None,
|
||||
push_to_hub: bool = True,
|
||||
is_mobile: bool = False,
|
||||
mode: Literal["video", "image"] = "image",
|
||||
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
|
||||
):
|
||||
if (LEROBOT_HOME / repo_id).exists():
|
||||
shutil.rmtree(LEROBOT_HOME / repo_id)
|
||||
|
||||
if not raw_dir.exists():
|
||||
if raw_repo_id is None:
|
||||
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
|
||||
download_raw(raw_dir, repo_id=raw_repo_id)
|
||||
|
||||
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
|
||||
|
||||
dataset = create_empty_dataset(
|
||||
repo_id,
|
||||
robot_type="mobile_aloha" if is_mobile else "aloha",
|
||||
mode=mode,
|
||||
has_effort=has_effort(hdf5_files),
|
||||
has_velocity=has_velocity(hdf5_files),
|
||||
dataset_config=dataset_config,
|
||||
)
|
||||
dataset = populate_dataset(
|
||||
dataset,
|
||||
hdf5_files,
|
||||
task=task,
|
||||
episodes=episodes,
|
||||
)
|
||||
dataset.consolidate()
|
||||
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(port_aloha)
|
||||
57
policy/openpi-InternData-A1/examples/aloha_real/env.py
Normal file
57
policy/openpi-InternData-A1/examples/aloha_real/env.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from typing import List, Optional # noqa: UP035
|
||||
|
||||
import einops
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
from examples.aloha_real import real_env as _real_env
|
||||
|
||||
|
||||
class AlohaRealEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot on real hardware."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
|
||||
render_height: int = 224,
|
||||
render_width: int = 224,
|
||||
) -> None:
|
||||
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
|
||||
self._render_height = render_height
|
||||
self._render_width = render_width
|
||||
|
||||
self._ts = None
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._ts = self._env.reset()
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return False
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._ts is None:
|
||||
raise RuntimeError("Timestep is not set. Call reset() first.")
|
||||
|
||||
obs = self._ts.observation
|
||||
for k in list(obs["images"].keys()):
|
||||
if "_depth" in k:
|
||||
del obs["images"][k]
|
||||
|
||||
for cam_name in obs["images"]:
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
|
||||
)
|
||||
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
|
||||
|
||||
return {
|
||||
"state": obs["qpos"],
|
||||
"images": obs["images"],
|
||||
}
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
self._ts = self._env.step(action["actions"])
|
||||
51
policy/openpi-InternData-A1/examples/aloha_real/main.py
Normal file
51
policy/openpi-InternData-A1/examples/aloha_real/main.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import tyro
|
||||
|
||||
from examples.aloha_real import env as _env
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
action_horizon: int = 25
|
||||
|
||||
num_episodes: int = 1
|
||||
max_episode_steps: int = 1000
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
)
|
||||
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
|
||||
|
||||
metadata = ws_client_policy.get_server_metadata()
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=ws_client_policy,
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[],
|
||||
max_hz=50,
|
||||
num_episodes=args.num_episodes,
|
||||
max_episode_steps=args.max_episode_steps,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
176
policy/openpi-InternData-A1/examples/aloha_real/real_env.py
Normal file
176
policy/openpi-InternData-A1/examples/aloha_real/real_env.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
import collections
|
||||
import time
|
||||
from typing import Optional, List
|
||||
import dm_env
|
||||
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
|
||||
from examples.aloha_real import constants
|
||||
from examples.aloha_real import robot_utils
|
||||
|
||||
# This is the reset position that is used by the standard Aloha runtime.
|
||||
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
|
||||
|
||||
|
||||
class RealEnv:
|
||||
"""
|
||||
Environment for real robot bi-manual manipulation
|
||||
Action space: [left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||
|
||||
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||
right_arm_qpos (6), # absolute joint position
|
||||
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
|
||||
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
|
||||
"""
|
||||
|
||||
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
|
||||
# reset_position = START_ARM_POSE[:6]
|
||||
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
|
||||
|
||||
self.puppet_bot_left = InterbotixManipulatorXS(
|
||||
robot_model="vx300s",
|
||||
group_name="arm",
|
||||
gripper_name="gripper",
|
||||
robot_name="puppet_left",
|
||||
init_node=init_node,
|
||||
)
|
||||
self.puppet_bot_right = InterbotixManipulatorXS(
|
||||
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
|
||||
)
|
||||
if setup_robots:
|
||||
self.setup_robots()
|
||||
|
||||
self.recorder_left = robot_utils.Recorder("left", init_node=False)
|
||||
self.recorder_right = robot_utils.Recorder("right", init_node=False)
|
||||
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
|
||||
self.gripper_command = JointSingleCommand(name="gripper")
|
||||
|
||||
def setup_robots(self):
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_left)
|
||||
robot_utils.setup_puppet_bot(self.puppet_bot_right)
|
||||
|
||||
def get_qpos(self):
|
||||
left_qpos_raw = self.recorder_left.qpos
|
||||
right_qpos_raw = self.recorder_right.qpos
|
||||
left_arm_qpos = left_qpos_raw[:6]
|
||||
right_arm_qpos = right_qpos_raw[:6]
|
||||
left_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
right_gripper_qpos = [
|
||||
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
|
||||
] # this is position not joint
|
||||
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||
|
||||
def get_qvel(self):
|
||||
left_qvel_raw = self.recorder_left.qvel
|
||||
right_qvel_raw = self.recorder_right.qvel
|
||||
left_arm_qvel = left_qvel_raw[:6]
|
||||
right_arm_qvel = right_qvel_raw[:6]
|
||||
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
|
||||
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
|
||||
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||
|
||||
def get_effort(self):
|
||||
left_effort_raw = self.recorder_left.effort
|
||||
right_effort_raw = self.recorder_right.effort
|
||||
left_robot_effort = left_effort_raw[:7]
|
||||
right_robot_effort = right_effort_raw[:7]
|
||||
return np.concatenate([left_robot_effort, right_robot_effort])
|
||||
|
||||
def get_images(self):
|
||||
return self.image_recorder.get_images()
|
||||
|
||||
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
|
||||
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
|
||||
self.gripper_command.cmd = left_gripper_desired_joint
|
||||
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
|
||||
right_gripper_desired_pos_normalized
|
||||
)
|
||||
self.gripper_command.cmd = right_gripper_desired_joint
|
||||
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
|
||||
|
||||
def _reset_joints(self):
|
||||
robot_utils.move_arms(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
|
||||
)
|
||||
|
||||
def _reset_gripper(self):
|
||||
"""Set to position mode and do position resets: first close then open. Then change back to PWM mode
|
||||
|
||||
NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
|
||||
was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
|
||||
increase the frequency of motor faults.
|
||||
"""
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
|
||||
)
|
||||
robot_utils.move_grippers(
|
||||
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
|
||||
)
|
||||
|
||||
def get_observation(self):
|
||||
obs = collections.OrderedDict()
|
||||
obs["qpos"] = self.get_qpos()
|
||||
obs["qvel"] = self.get_qvel()
|
||||
obs["effort"] = self.get_effort()
|
||||
obs["images"] = self.get_images()
|
||||
return obs
|
||||
|
||||
def get_reward(self):
|
||||
return 0
|
||||
|
||||
def reset(self, *, fake=False):
|
||||
if not fake:
|
||||
# Reboot puppet robot gripper motors
|
||||
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
self._reset_joints()
|
||||
self._reset_gripper()
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
def step(self, action):
|
||||
state_len = int(len(action) / 2)
|
||||
left_action = action[:state_len]
|
||||
right_action = action[state_len:]
|
||||
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
|
||||
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
|
||||
self.set_gripper_pose(left_action[-1], right_action[-1])
|
||||
time.sleep(constants.DT)
|
||||
return dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
|
||||
)
|
||||
|
||||
|
||||
def get_action(master_bot_left, master_bot_right):
|
||||
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
|
||||
# Arm actions
|
||||
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
|
||||
# Gripper actions
|
||||
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
|
||||
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
|
||||
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
|
||||
@@ -0,0 +1,18 @@
|
||||
Pillow
|
||||
dm_control
|
||||
einops
|
||||
h5py
|
||||
matplotlib
|
||||
modern_robotics
|
||||
msgpack
|
||||
numpy>=1.22.4,<2.0.0
|
||||
opencv-python
|
||||
packaging
|
||||
pexpect
|
||||
pyquaternion
|
||||
pyrealsense2
|
||||
pyyaml
|
||||
requests
|
||||
rospkg
|
||||
tyro
|
||||
websockets
|
||||
156
policy/openpi-InternData-A1/examples/aloha_real/requirements.txt
Normal file
156
policy/openpi-InternData-A1/examples/aloha_real/requirements.txt
Normal file
@@ -0,0 +1,156 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
catkin-pkg==1.0.0
|
||||
# via rospkg
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
contourpy==1.1.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
distro==1.9.0
|
||||
# via rospkg
|
||||
dm-control==1.0.23
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
docutils==0.20.1
|
||||
# via catkin-pkg
|
||||
einops==0.8.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
h5py==3.11.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
idna==3.10
|
||||
# via requests
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.7.5
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
modern-robotics==1.1.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
mujoco==3.2.3
|
||||
# via dm-control
|
||||
numpy==1.24.4
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# h5py
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# modern-robotics
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# pyquaternion
|
||||
# scipy
|
||||
opencv-python==4.10.0.84
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
packaging==24.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
pexpect==4.9.0
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.1.4
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyquaternion==0.9.9
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
pyrealsense2==2.55.1.6486
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# matplotlib
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# rospkg
|
||||
requests==2.32.3
|
||||
# via
|
||||
# -r examples/aloha_real/requirements.in
|
||||
# dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
rospkg==1.5.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
scipy==1.10.1
|
||||
# via dm-control
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# catkin-pkg
|
||||
# dm-control
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_real/requirements.in
|
||||
zipp==3.20.2
|
||||
# via etils
|
||||
275
policy/openpi-InternData-A1/examples/aloha_real/robot_utils.py
Normal file
275
policy/openpi-InternData-A1/examples/aloha_real/robot_utils.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
|
||||
# ruff: noqa
|
||||
from collections import deque
|
||||
import datetime
|
||||
import json
|
||||
import time
|
||||
|
||||
from aloha.msg import RGBGrayscaleImage
|
||||
from cv_bridge import CvBridge
|
||||
from interbotix_xs_msgs.msg import JointGroupCommand
|
||||
from interbotix_xs_msgs.msg import JointSingleCommand
|
||||
import numpy as np
|
||||
import rospy
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
from examples.aloha_real import constants
|
||||
|
||||
|
||||
class ImageRecorder:
|
||||
def __init__(self, init_node=True, is_debug=False):
|
||||
self.is_debug = is_debug
|
||||
self.bridge = CvBridge()
|
||||
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("image_recorder", anonymous=True)
|
||||
for cam_name in self.camera_names:
|
||||
setattr(self, f"{cam_name}_rgb_image", None)
|
||||
setattr(self, f"{cam_name}_depth_image", None)
|
||||
setattr(self, f"{cam_name}_timestamp", 0.0)
|
||||
if cam_name == "cam_high":
|
||||
callback_func = self.image_cb_cam_high
|
||||
elif cam_name == "cam_low":
|
||||
callback_func = self.image_cb_cam_low
|
||||
elif cam_name == "cam_left_wrist":
|
||||
callback_func = self.image_cb_cam_left_wrist
|
||||
elif cam_name == "cam_right_wrist":
|
||||
callback_func = self.image_cb_cam_right_wrist
|
||||
else:
|
||||
raise NotImplementedError
|
||||
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
|
||||
if self.is_debug:
|
||||
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
|
||||
|
||||
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
|
||||
time.sleep(0.5)
|
||||
|
||||
def image_cb(self, cam_name, data):
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_rgb_image",
|
||||
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
|
||||
)
|
||||
# setattr(
|
||||
# self,
|
||||
# f"{cam_name}_depth_image",
|
||||
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
|
||||
# )
|
||||
setattr(
|
||||
self,
|
||||
f"{cam_name}_timestamp",
|
||||
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
|
||||
)
|
||||
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
|
||||
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
|
||||
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
|
||||
if self.is_debug:
|
||||
getattr(self, f"{cam_name}_timestamps").append(
|
||||
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
|
||||
)
|
||||
|
||||
def image_cb_cam_high(self, data):
|
||||
cam_name = "cam_high"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_low(self, data):
|
||||
cam_name = "cam_low"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_left_wrist(self, data):
|
||||
cam_name = "cam_left_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def image_cb_cam_right_wrist(self, data):
|
||||
cam_name = "cam_right_wrist"
|
||||
return self.image_cb(cam_name, data)
|
||||
|
||||
def get_images(self):
|
||||
image_dict = {}
|
||||
for cam_name in self.camera_names:
|
||||
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
|
||||
time.sleep(0.00001)
|
||||
rgb_image = getattr(self, f"{cam_name}_rgb_image")
|
||||
depth_image = getattr(self, f"{cam_name}_depth_image")
|
||||
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
|
||||
image_dict[cam_name] = rgb_image
|
||||
image_dict[f"{cam_name}_depth"] = depth_image
|
||||
return image_dict
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
|
||||
print(f"{cam_name} {image_freq=:.2f}")
|
||||
print()
|
||||
|
||||
|
||||
class Recorder:
|
||||
def __init__(self, side, init_node=True, is_debug=False):
|
||||
self.secs = None
|
||||
self.nsecs = None
|
||||
self.qpos = None
|
||||
self.effort = None
|
||||
self.arm_command = None
|
||||
self.gripper_command = None
|
||||
self.is_debug = is_debug
|
||||
|
||||
if init_node:
|
||||
rospy.init_node("recorder", anonymous=True)
|
||||
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_group",
|
||||
JointGroupCommand,
|
||||
self.puppet_arm_commands_cb,
|
||||
)
|
||||
rospy.Subscriber(
|
||||
f"/puppet_{side}/commands/joint_single",
|
||||
JointSingleCommand,
|
||||
self.puppet_gripper_commands_cb,
|
||||
)
|
||||
if self.is_debug:
|
||||
self.joint_timestamps = deque(maxlen=50)
|
||||
self.arm_command_timestamps = deque(maxlen=50)
|
||||
self.gripper_command_timestamps = deque(maxlen=50)
|
||||
time.sleep(0.1)
|
||||
|
||||
def puppet_state_cb(self, data):
|
||||
self.qpos = data.position
|
||||
self.qvel = data.velocity
|
||||
self.effort = data.effort
|
||||
self.data = data
|
||||
if self.is_debug:
|
||||
self.joint_timestamps.append(time.time())
|
||||
|
||||
def puppet_arm_commands_cb(self, data):
|
||||
self.arm_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.arm_command_timestamps.append(time.time())
|
||||
|
||||
def puppet_gripper_commands_cb(self, data):
|
||||
self.gripper_command = data.cmd
|
||||
if self.is_debug:
|
||||
self.gripper_command_timestamps.append(time.time())
|
||||
|
||||
def print_diagnostics(self):
|
||||
def dt_helper(l):
|
||||
l = np.array(l)
|
||||
diff = l[1:] - l[:-1]
|
||||
return np.mean(diff)
|
||||
|
||||
joint_freq = 1 / dt_helper(self.joint_timestamps)
|
||||
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
|
||||
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
|
||||
|
||||
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
|
||||
|
||||
|
||||
def get_arm_joint_positions(bot):
|
||||
return bot.arm.core.joint_states.position[:6]
|
||||
|
||||
|
||||
def get_arm_gripper_positions(bot):
|
||||
return bot.gripper.core.joint_states.position[6]
|
||||
|
||||
|
||||
def move_arms(bot_list, target_pose_list, move_time=1):
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
for t in range(num_steps):
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def move_grippers(bot_list, target_pose_list, move_time):
|
||||
print(f"Moving grippers to {target_pose_list=}")
|
||||
gripper_command = JointSingleCommand(name="gripper")
|
||||
num_steps = int(move_time / constants.DT)
|
||||
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
|
||||
traj_list = [
|
||||
np.linspace(curr_pose, target_pose, num_steps)
|
||||
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
|
||||
]
|
||||
|
||||
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
|
||||
for t in range(num_steps):
|
||||
d = {}
|
||||
for bot_id, bot in enumerate(bot_list):
|
||||
gripper_command.cmd = traj_list[bot_id][t]
|
||||
bot.gripper.core.pub_single.publish(gripper_command)
|
||||
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
|
||||
f.write(json.dumps(d) + "\n")
|
||||
time.sleep(constants.DT)
|
||||
|
||||
|
||||
def setup_puppet_bot(bot):
|
||||
bot.dxl.robot_reboot_motors("single", "gripper", True)
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "position")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_on(bot)
|
||||
|
||||
|
||||
def setup_master_bot(bot):
|
||||
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
|
||||
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
|
||||
torque_off(bot)
|
||||
|
||||
|
||||
def set_standard_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def set_low_pid_gains(bot):
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
|
||||
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
|
||||
|
||||
|
||||
def torque_off(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", False)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", False)
|
||||
|
||||
|
||||
def torque_on(bot):
|
||||
bot.dxl.robot_torque_enable("group", "arm", True)
|
||||
bot.dxl.robot_torque_enable("single", "gripper", True)
|
||||
|
||||
|
||||
# for DAgger
|
||||
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
|
||||
print("\nSyncing!")
|
||||
|
||||
# activate master arms
|
||||
torque_on(master_bot_left)
|
||||
torque_on(master_bot_right)
|
||||
|
||||
# get puppet arm positions
|
||||
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
|
||||
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
|
||||
|
||||
# get puppet gripper positions
|
||||
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
|
||||
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
|
||||
|
||||
# move master arms to puppet positions
|
||||
move_arms(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_qpos, puppet_right_qpos],
|
||||
move_time=1,
|
||||
)
|
||||
|
||||
# move master grippers to puppet positions
|
||||
move_grippers(
|
||||
[master_bot_left, master_bot_right],
|
||||
[puppet_left_gripper, puppet_right_gripper],
|
||||
move_time=1,
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoDisplay(_subscriber.Subscriber):
|
||||
"""Displays video frames."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._ax: plt.Axes | None = None
|
||||
self._plt_img: plt.Image | None = None
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
plt.ion()
|
||||
self._ax = plt.subplot()
|
||||
self._plt_img = None
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
assert self._ax is not None
|
||||
|
||||
im = observation["image"][0] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
|
||||
if self._plt_img is None:
|
||||
self._plt_img = self._ax.imshow(im)
|
||||
else:
|
||||
self._plt_img.set_data(im)
|
||||
plt.pause(0.001)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
plt.ioff()
|
||||
plt.close()
|
||||
41
policy/openpi-InternData-A1/examples/aloha_sim/Dockerfile
Normal file
41
policy/openpi-InternData-A1/examples/aloha_sim/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# Dockerfile for the Aloha simulation environment.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
|
||||
|
||||
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev
|
||||
ENV MUJOCO_GL=egl
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]
|
||||
36
policy/openpi-InternData-A1/examples/aloha_sim/README.md
Normal file
36
policy/openpi-InternData-A1/examples/aloha_sim/README.md
Normal file
@@ -0,0 +1,36 @@
|
||||
# Run Aloha Sim
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.10 examples/aloha_sim/.venv
|
||||
source examples/aloha_sim/.venv/bin/activate
|
||||
uv pip sync examples/aloha_sim/requirements.txt
|
||||
uv pip install -e packages/openpi-client
|
||||
|
||||
# Run the simulation
|
||||
MUJOCO_GL=egl python examples/aloha_sim/main.py
|
||||
```
|
||||
|
||||
Note: If you are seeing EGL errors, you may need to install the following dependencies:
|
||||
|
||||
```bash
|
||||
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env ALOHA_SIM
|
||||
```
|
||||
42
policy/openpi-InternData-A1/examples/aloha_sim/compose.yml
Normal file
42
policy/openpi-InternData-A1/examples/aloha_sim/compose.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/aloha_sim/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: aloha_sim
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/aloha_sim/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
56
policy/openpi-InternData-A1/examples/aloha_sim/env.py
Normal file
56
policy/openpi-InternData-A1/examples/aloha_sim/env.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import gym_aloha # noqa: F401
|
||||
import gymnasium
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class AlohaSimEnvironment(_environment.Environment):
|
||||
"""An environment for an Aloha robot in simulation."""
|
||||
|
||||
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
|
||||
np.random.seed(seed)
|
||||
self._rng = np.random.default_rng(seed)
|
||||
|
||||
self._gym = gymnasium.make(task, obs_type=obs_type)
|
||||
|
||||
self._last_obs = None
|
||||
self._done = True
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = False
|
||||
self._episode_reward = 0.0
|
||||
|
||||
@override
|
||||
def is_episode_complete(self) -> bool:
|
||||
return self._done
|
||||
|
||||
@override
|
||||
def get_observation(self) -> dict:
|
||||
if self._last_obs is None:
|
||||
raise RuntimeError("Observation is not set. Call reset() first.")
|
||||
|
||||
return self._last_obs # type: ignore
|
||||
|
||||
@override
|
||||
def apply_action(self, action: dict) -> None:
|
||||
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
|
||||
self._last_obs = self._convert_observation(gym_obs) # type: ignore
|
||||
self._done = terminated or truncated
|
||||
self._episode_reward = max(self._episode_reward, reward)
|
||||
|
||||
def _convert_observation(self, gym_obs: dict) -> dict:
|
||||
img = gym_obs["pixels"]["top"]
|
||||
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
|
||||
# Convert axis order from [H, W, C] --> [C, H, W]
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
|
||||
return {
|
||||
"state": gym_obs["agent_pos"],
|
||||
"images": {"cam_high": img},
|
||||
}
|
||||
55
policy/openpi-InternData-A1/examples/aloha_sim/main.py
Normal file
55
policy/openpi-InternData-A1/examples/aloha_sim/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import env as _env
|
||||
from openpi_client import action_chunk_broker
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
from openpi_client.runtime import runtime as _runtime
|
||||
from openpi_client.runtime.agents import policy_agent as _policy_agent
|
||||
import saver as _saver
|
||||
import tyro
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
|
||||
|
||||
task: str = "gym_aloha/AlohaTransferCube-v0"
|
||||
seed: int = 0
|
||||
|
||||
action_horizon: int = 10
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
|
||||
display: bool = False
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
runtime = _runtime.Runtime(
|
||||
environment=_env.AlohaSimEnvironment(
|
||||
task=args.task,
|
||||
seed=args.seed,
|
||||
),
|
||||
agent=_policy_agent.PolicyAgent(
|
||||
policy=action_chunk_broker.ActionChunkBroker(
|
||||
policy=_websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
),
|
||||
action_horizon=args.action_horizon,
|
||||
)
|
||||
),
|
||||
subscribers=[
|
||||
_saver.VideoSaver(args.out_dir),
|
||||
],
|
||||
max_hz=50,
|
||||
)
|
||||
|
||||
runtime.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
tyro.cli(main)
|
||||
@@ -0,0 +1,8 @@
|
||||
gym-aloha
|
||||
imageio
|
||||
matplotlib
|
||||
msgpack
|
||||
numpy>=1.22.4,<2.0.0
|
||||
typing-extensions
|
||||
tyro
|
||||
websockets
|
||||
132
policy/openpi-InternData-A1/examples/aloha_sim/requirements.txt
Normal file
132
policy/openpi-InternData-A1/examples/aloha_sim/requirements.txt
Normal file
@@ -0,0 +1,132 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
|
||||
absl-py==2.1.0
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# labmaze
|
||||
# mujoco
|
||||
certifi==2024.8.30
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cloudpickle==3.1.0
|
||||
# via gymnasium
|
||||
contourpy==1.3.1
|
||||
# via matplotlib
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.8
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
fonttools==4.55.2
|
||||
# via matplotlib
|
||||
glfw==2.8.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
gym-aloha==0.1.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
gymnasium==1.0.0
|
||||
# via gym-aloha
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.36.1
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gym-aloha
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lxml==5.3.0
|
||||
# via dm-control
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.9.3
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
msgpack==1.1.0
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
numpy==1.26.4
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# contourpy
|
||||
# dm-control
|
||||
# dm-env
|
||||
# gymnasium
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# scipy
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==11.0.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
protobuf==5.29.1
|
||||
# via dm-control
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pyopengl==3.1.7
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.0
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
requests==2.32.3
|
||||
# via dm-control
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
scipy==1.14.1
|
||||
# via dm-control
|
||||
setuptools==75.6.0
|
||||
# via
|
||||
# dm-control
|
||||
# imageio-ffmpeg
|
||||
# labmaze
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via python-dateutil
|
||||
tqdm==4.67.1
|
||||
# via dm-control
|
||||
typeguard==4.4.1
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# -r examples/aloha_sim/requirements.in
|
||||
# gymnasium
|
||||
# rich
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
websockets==14.1
|
||||
# via -r examples/aloha_sim/requirements.in
|
||||
40
policy/openpi-InternData-A1/examples/aloha_sim/saver.py
Normal file
40
policy/openpi-InternData-A1/examples/aloha_sim/saver.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class VideoSaver(_subscriber.Subscriber):
|
||||
"""Saves episode data."""
|
||||
|
||||
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._out_dir = out_dir
|
||||
self._images: list[np.ndarray] = []
|
||||
self._subsample = subsample
|
||||
|
||||
@override
|
||||
def on_episode_start(self) -> None:
|
||||
self._images = []
|
||||
|
||||
@override
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
im = observation["images"]["cam_high"] # [C, H, W]
|
||||
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
|
||||
self._images.append(im)
|
||||
|
||||
@override
|
||||
def on_episode_end(self) -> None:
|
||||
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
|
||||
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
|
||||
out_path = self._out_dir / f"out_{next_idx}.mp4"
|
||||
|
||||
logging.info(f"Saving video to {out_path}")
|
||||
imageio.mimwrite(
|
||||
out_path,
|
||||
[np.asarray(x) for x in self._images[:: self._subsample]],
|
||||
fps=50 // max(1, self._subsample),
|
||||
)
|
||||
212
policy/openpi-InternData-A1/examples/arx/action_stats.py
Normal file
212
policy/openpi-InternData-A1/examples/arx/action_stats.py
Normal file
@@ -0,0 +1,212 @@
|
||||
from collections import deque
|
||||
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
|
||||
import datasets
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def check_final(
|
||||
last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
|
||||
*,
|
||||
# 索引与初始状态
|
||||
arm_dofs: int = 6, # 左臂关节数(这里按你给的 6)
|
||||
gripper_index: int = -1, # 夹爪在向量中的索引(默认最后一维)
|
||||
mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
|
||||
mean_initial_gripper_state: float = 4.8438, # 目前不参与判定,保留以便后续扩展
|
||||
|
||||
# 判定阈值(角度阈值用“度”直观易调,内部会转换为弧度)
|
||||
stability_window: int = 5, # 最近多少帧用于判“没有太大变化”
|
||||
per_joint_range_deg: float = 2.0, # 窗口内每个关节的最大幅度(max-min)阈值(度)
|
||||
mean_speed_deg: float = 0.5, # 邻帧关节差的平均 L2(每步)阈值(度/步)
|
||||
min_change_from_initial_deg: float = 15.0, # 末帧相对初始的“至少变化量”(L2,度)
|
||||
gripper_closed_thresh: float = 0.8, # 夹爪关闭阈值(数值越小说明越闭合)
|
||||
) -> bool:
|
||||
"""
|
||||
返回 True 表示“到位”:(1) 最近窗口内姿态变化不大 & (2) 夹爪关闭 & (3) 末帧与初始相差足够大。
|
||||
所有角度的阈值以“度”给出,这里会自动转弧度再比较。
|
||||
"""
|
||||
# --- 数据整理为 (N, D) tensor ---
|
||||
if isinstance(last_states, torch.Tensor):
|
||||
states = last_states
|
||||
else:
|
||||
states = torch.as_tensor(list(last_states), dtype=torch.float32)
|
||||
|
||||
if states.ndim != 2:
|
||||
raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
|
||||
N, D = states.shape
|
||||
if D < arm_dofs:
|
||||
raise ValueError(f"Expected at least {arm_dofs} dims for arm + gripper, got {D}")
|
||||
if N < 2:
|
||||
return False # 样本太少,无法判定稳定
|
||||
|
||||
# 取最近窗口
|
||||
w = min(N, stability_window)
|
||||
window = states[-w:] # (w, D)
|
||||
arm = window[:, :arm_dofs] # (w, 6)
|
||||
last_arm = arm[-1] # (6,)
|
||||
last_gripper = float(window[-1, gripper_index])
|
||||
|
||||
# --- 1) 最近 w 帧“没有太大变化” ---
|
||||
# 两个指标:每关节range(max-min)要小、相邻帧的平均“速度”要小
|
||||
deg2rad = torch.pi / 180.0
|
||||
range_tol = per_joint_range_deg * deg2rad
|
||||
speed_tol = mean_speed_deg * deg2rad
|
||||
|
||||
ranges = arm.max(dim=0).values - arm.min(dim=0).values # (6,)
|
||||
max_range = float(ranges.abs().max()) # 标量
|
||||
diffs = arm[1:] - arm[:-1] # (w-1, 6)
|
||||
mean_speed = float(torch.linalg.norm(diffs, dim=1).mean()) # 每步的平均 L2
|
||||
|
||||
stable = (max_range <= range_tol) and (mean_speed <= speed_tol)
|
||||
|
||||
# --- 2) 夹爪关闭 ---
|
||||
gripper_closed = (last_gripper < gripper_closed_thresh)
|
||||
|
||||
# --- 3) 末帧与“初始”差距要大 ---
|
||||
init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
|
||||
if init.numel() != arm_dofs:
|
||||
raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
|
||||
dist_from_init = float(torch.linalg.norm(last_arm - init))
|
||||
far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))
|
||||
|
||||
# 组合判定
|
||||
return bool(stable and gripper_closed and far_from_init)
|
||||
# return bool(gripper_closed and far_from_init)
|
||||
|
||||
|
||||
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
|
||||
"""
|
||||
Quickly fetch the last frame of each episode in a LeRobotDataset.
|
||||
- include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
|
||||
- include_images=True : Additionally decode the corresponding image/video frame for the last frame.
|
||||
- keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
|
||||
Returns: list[dict], where each element contains the last frame info of one episode.
|
||||
"""
|
||||
# 1) Compute the global index of the last row for each episode.
|
||||
# ds.episode_data_index['to'] is the exclusive end index, so last frame = to - 1.
|
||||
end_idxs = (ds.episode_data_index["to"] - 1).tolist()
|
||||
|
||||
# 2) Determine which columns to load.
|
||||
# By default, exclude video/image columns to avoid triggering slow video decoding.
|
||||
if keys is None:
|
||||
non_media_keys = [k for k, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
|
||||
keys = list(set(non_media_keys + ["timestamp", "episode_index", "task_index"]))
|
||||
|
||||
# 3) Select all last-frame rows at once (does not call __getitem__, so no video decoding is triggered).
|
||||
last_rows = ds.hf_dataset.select(end_idxs)
|
||||
|
||||
# 4) Build a dictionary of tensors for each requested key.
|
||||
out = []
|
||||
col = {k: last_rows[k] for k in keys}
|
||||
|
||||
# Convert lists of tensors into stacked tensors for easier indexing.
|
||||
for k, v in col.items():
|
||||
# datasets.arrow_dataset.Column is the HuggingFace internal type for columns.
|
||||
if isinstance(v, datasets.arrow_dataset.Column) and len(v) > 0 and hasattr(v[0], "shape"):
|
||||
col[k] = torch.stack(v[:])
|
||||
|
||||
# Iterate through each episode’s last frame and build a dict with its values.
|
||||
for i, ep_end in enumerate(end_idxs):
|
||||
item = {}
|
||||
for k in keys:
|
||||
val = col[k][i]
|
||||
# Unpack 0-dimensional tensors into Python scalars.
|
||||
if torch.is_tensor(val) and val.ndim == 0:
|
||||
val = val.item()
|
||||
item[k] = val
|
||||
|
||||
# Map task_index back to the human-readable task string.
|
||||
if "task_index" in item:
|
||||
item["task"] = ds.meta.tasks[int(item["task_index"])]
|
||||
out.append(item)
|
||||
|
||||
# 5) Optionally decode the actual image/video frame for each last timestamp.
|
||||
if include_images and len(ds.meta.video_keys) > 0:
|
||||
for i, ep_end in enumerate(end_idxs):
|
||||
ep_idx = int(out[i]["episode_index"])
|
||||
ts = float(out[i]["timestamp"])
|
||||
# Prepare a query dictionary: one timestamp per camera key.
|
||||
query_ts = {k: [ts] for k in ds.meta.video_keys}
|
||||
# Decode video frames at the specified timestamps for this episode.
|
||||
frames = ds._query_videos(query_ts, ep_idx)
|
||||
# Attach the decoded frame tensors to the output dictionary.
|
||||
for k, v in frames.items():
|
||||
out[i][k] = v
|
||||
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize your dataset (replace with your repo ID or local path).
|
||||
ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")
|
||||
|
||||
# Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
|
||||
last_infos = get_last_frames(ds, include_images=False)
|
||||
|
||||
# Stack all 'observation.state' vectors into a single tensor for further processing.
|
||||
states = torch.stack([info['observation.state'] for info in last_infos])
|
||||
# Extract the left-arm joint states (first 7 values of each state vector).
|
||||
left_arm_states = states[:, 0:7]
|
||||
mean_state = torch.mean(left_arm_states, dim=0)
|
||||
std_state = torch.std(left_arm_states, dim=0)
|
||||
|
||||
# Print the collected metadata for verification.
|
||||
print(last_infos)
|
||||
|
||||
# --- Run check_final per episode using the last <=50 states ---
|
||||
|
||||
EP_ARM_DOFS = 6 # number of left-arm joints we use in check_final
|
||||
GRIPPER_COL_FULL = -1 # gripper is the last element in the full state vector
|
||||
STABILITY_WINDOW = 120 # must be consistent with check_final's default
|
||||
|
||||
# Determine which episodes to iterate
|
||||
episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())
|
||||
|
||||
episode_flags = {}
|
||||
num_true, num_false = 0, 0
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
# Global index range [from_idx, to_idx) for this episode
|
||||
from_idx = int(ds.episode_data_index["from"][ep_idx])
|
||||
to_idx = int(ds.episode_data_index["to"][ep_idx])
|
||||
|
||||
if to_idx - from_idx <= 0:
|
||||
episode_flags[ep_idx] = False
|
||||
num_false += 1
|
||||
continue
|
||||
|
||||
# Take the last <= STABILITY_WINDOW frames from this episode
|
||||
idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
|
||||
rows = ds.hf_dataset.select(idxs)
|
||||
|
||||
# Collect full "observation.state" (shape ~ [W, S])
|
||||
s_col = rows["observation.state"]
|
||||
if isinstance(s_col, datasets.arrow_dataset.Column):
|
||||
S = torch.stack(s_col[:]) # Column -> list[tensor] -> stack
|
||||
else:
|
||||
S = torch.stack(s_col) # already a list[tensor]
|
||||
|
||||
# Build the 7D small state per frame: first 6 joints + gripper
|
||||
# (Assumes the gripper signal is at the last position of the full state vector)
|
||||
small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)
|
||||
|
||||
# Run your stopping logic
|
||||
ok = check_final(
|
||||
small_states,
|
||||
arm_dofs=EP_ARM_DOFS,
|
||||
gripper_index=-1,
|
||||
stability_window=STABILITY_WINDOW,
|
||||
)
|
||||
episode_flags[ep_idx] = bool(ok)
|
||||
num_true += int(ok)
|
||||
num_false += int(not ok)
|
||||
|
||||
# Summary
|
||||
total_eps = len(episode_indices)
|
||||
print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")
|
||||
|
||||
# List some failed episodes for quick inspection
|
||||
failed_eps = [e for e, passed in episode_flags.items() if not passed]
|
||||
print("Failed episode indices (first 20):", failed_eps[:20])
|
||||
|
||||
88
policy/openpi-InternData-A1/examples/arx/extract_frame.py
Normal file
88
policy/openpi-InternData-A1/examples/arx/extract_frame.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
def extract_last_frame_from_videos(root_dir, output_dir, xx_last_frame=1):
|
||||
"""
|
||||
遍历目录,找到所有images.rgb.hand_right视频文件,提取最后一帧并保存
|
||||
"""
|
||||
# 查找所有mp4视频文件
|
||||
video_files = []
|
||||
for root, dirs, files in os.walk(root_dir):
|
||||
for file in files:
|
||||
|
||||
if file.endswith('.mp4') and 'observation/head' in root:
|
||||
video_files.append(os.path.join(root, file))
|
||||
|
||||
print(f"找到 {len(video_files)} 个视频文件")
|
||||
|
||||
# 处理每个视频文件
|
||||
for video_path in tqdm(video_files):
|
||||
try:
|
||||
# 提取set名称和episode名称
|
||||
path_parts = Path(video_path).parts
|
||||
set_name = None
|
||||
episode_name = None
|
||||
for part in path_parts:
|
||||
if part.startswith('set'):
|
||||
set_name = part
|
||||
if part.startswith('000'):
|
||||
episode_name = part.replace('.mp4', '')
|
||||
|
||||
if not set_name or not episode_name:
|
||||
print(f"无法从路径中提取set和episode信息: {video_path}")
|
||||
continue
|
||||
|
||||
# 生成输出文件名
|
||||
output_filename = f"{set_name}_{episode_name}.jpg"
|
||||
output_path = os.path.join(output_dir, output_filename)
|
||||
|
||||
# 打开视频文件
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
|
||||
if not cap.isOpened():
|
||||
print(f"无法打开视频: {video_path}")
|
||||
continue
|
||||
|
||||
# 获取总帧数
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
|
||||
if total_frames == 0:
|
||||
print(f"视频没有帧: {video_path}")
|
||||
cap.release()
|
||||
continue
|
||||
|
||||
# 跳转到最后一帧
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - xx_last_frame)
|
||||
ret, frame = cap.read()
|
||||
|
||||
if ret:
|
||||
# 保存最后一帧
|
||||
cv2.imwrite(output_path, frame)
|
||||
print(f"已保存:\n {output_path}")
|
||||
else:
|
||||
print(f"无法读取最后一帧: {video_path}")
|
||||
|
||||
# 释放资源
|
||||
cap.release()
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理视频时出错 {video_path}: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 指定要遍历的根目录
|
||||
root_directory = "/home/caijunhao/h-ceph/InternData-A1-raw/arx_lift2/Pick_the_industrial_components_from_the_conveyor" # 当前目录,您可以修改为您的目录路径
|
||||
output_path = 'data/Pick_the_industrial_components_from_the_conveyor/'
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
sub_list = os.listdir(root_directory)
|
||||
exclude_list = []
|
||||
# exclude_list = [f"{i}" for i in range(16)] + [f"{i}" for i in range(26, 29)]
|
||||
xx_last_frame = 1
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
for sub in tqdm(sub_list):
|
||||
if sub.split('-')[1].split('_')[0] in exclude_list:
|
||||
continue
|
||||
# print("os.path.join([root_directory, sub])\n", os.path.join(root_directory, sub))
|
||||
extract_last_frame_from_videos(os.path.join(root_directory, sub), output_path, xx_last_frame=xx_last_frame)
|
||||
print("处理完成!")
|
||||
670
policy/openpi-InternData-A1/examples/arx/lmdb2lerobot_arx.py
Normal file
670
policy/openpi-InternData-A1/examples/arx/lmdb2lerobot_arx.py
Normal file
@@ -0,0 +1,670 @@
|
||||
# source /fs-computility/efm/liyang/miniconda3/etc/profile.d/conda.sh
|
||||
# conda activate act
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import gc
|
||||
import shutil
|
||||
from concurrent.futures import ALL_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
import torchvision
|
||||
import cv2
|
||||
import h5py
|
||||
import lmdb
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
from PIL import Image
|
||||
from scipy.spatial.transform import Rotation
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
import pdb
|
||||
import os
|
||||
import imageio # imageio-ffmpeg
|
||||
from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.datasets.utils import check_timestamps_sync, get_episode_data_index, validate_episode_buffer
|
||||
import time
|
||||
# import ray
|
||||
# from ray.runtime_env import RuntimeEnv
|
||||
|
||||
"""
|
||||
Store both camera image and robot state as a combined observation.
|
||||
Args:
|
||||
observation: images(camera), states (robot state)
|
||||
actions: joint, gripper, ee_pose
|
||||
"""
|
||||
FEATURES = {
|
||||
"images.rgb.head": {
|
||||
"dtype": "video",
|
||||
"shape": (368, 640, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"images.rgb.hand_left": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"images.rgb.hand_right": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
# "states.left_joint.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (6,),
|
||||
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
|
||||
# },
|
||||
# "states.left_gripper.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (1,),
|
||||
# "names": ["left_gripper_0",],
|
||||
# },
|
||||
# "states.right_joint.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (6,),
|
||||
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
|
||||
# },
|
||||
# "states.right_gripper.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (1,),
|
||||
# "names": ["right_gripper_0",],
|
||||
# },
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (14,),
|
||||
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
|
||||
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (14,),
|
||||
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
|
||||
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
|
||||
},
|
||||
# "actions.left_joint.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (6,),
|
||||
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
|
||||
# },
|
||||
# "actions.left_gripper.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (1,),
|
||||
# "names": ["left_gripper_0",],
|
||||
# },
|
||||
# "actions.right_joint.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (6,),
|
||||
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
|
||||
# },
|
||||
# "actions.right_gripper.position": {
|
||||
# "dtype": "float32",
|
||||
# "shape": (1,),
|
||||
# "names": ["right_gripper_0", ],
|
||||
# },
|
||||
|
||||
}
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
def filter_forbidden_frames(state_dict, position_threshold=0.001, velocity_threshold=0.005):
|
||||
"""
|
||||
过滤禁止的帧,基于位置和速度阈值
|
||||
|
||||
参数:
|
||||
- state_dict: 形状为 (n, 14) 的状态数组
|
||||
- position_threshold: 位置变化的阈值
|
||||
- velocity_threshold: 速度变化的阈值
|
||||
|
||||
返回:
|
||||
- valid_mask: 布尔数组,True表示有效帧
|
||||
"""
|
||||
# 排除夹爪列(第6和第13列,索引从0开始)
|
||||
qpos_columns = [i for i in range(14)]
|
||||
qpos_data = state_dict[:, qpos_columns]
|
||||
|
||||
n_frames = len(state_dict)
|
||||
valid_mask = np.ones(n_frames, dtype=bool)
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
# 计算帧间差异(速度)
|
||||
if n_frames > 1:
|
||||
|
||||
diff_sum = np.sum(np.abs(np.diff(qpos_data, axis=0)), axis=1)
|
||||
# sorted_indices = np.argsort(diff_sum)[::-1]
|
||||
# sorted_abs_sums = diff_sum[sorted_indices]
|
||||
|
||||
# velocities = np.diff(qpos_data, axis=0)
|
||||
# 检查速度是否超过阈值
|
||||
for i in range(n_frames - 1):
|
||||
if np.any(np.abs(diff_sum[i]) > position_threshold):
|
||||
valid_mask[i] = True # 有运动,有效帧
|
||||
else:
|
||||
valid_mask[i] = False # 静止,可能是禁止帧
|
||||
valid_mask[i] = True
|
||||
return valid_mask
|
||||
|
||||
def statistical_filter(state_dict, std_multiplier=2.0):
|
||||
"""
|
||||
使用统计方法检测异常(禁止)帧
|
||||
"""
|
||||
# 排除夹爪列
|
||||
qpos_columns = [i for i in range(14) if i not in [6, 13]]
|
||||
qpos_data = state_dict[:, qpos_columns]
|
||||
|
||||
# 计算每列的均值和标准差
|
||||
means = np.mean(qpos_data, axis=0)
|
||||
stds = np.std(qpos_data, axis=0)
|
||||
|
||||
# 创建有效掩码
|
||||
valid_mask = np.ones(len(state_dict), dtype=bool)
|
||||
|
||||
for i in range(len(state_dict)):
|
||||
# 检查每个关节位置是否在合理范围内
|
||||
deviations = np.abs(qpos_data[i] - means)
|
||||
if np.any(deviations > std_multiplier * stds):
|
||||
valid_mask[i] = False # 异常帧
|
||||
|
||||
return valid_mask
|
||||
|
||||
|
||||
class ARXDataset(LeRobotDataset):
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | Path | None = None,
|
||||
episodes: list[int] | None = None,
|
||||
image_transforms: Callable | None = None,
|
||||
delta_timestamps: dict[list[float]] | None = None,
|
||||
tolerance_s: float = 1e-4,
|
||||
download_videos: bool = True,
|
||||
local_files_only: bool = False,
|
||||
video_backend: str | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
episodes=episodes,
|
||||
image_transforms=image_transforms,
|
||||
delta_timestamps=delta_timestamps,
|
||||
tolerance_s=tolerance_s,
|
||||
download_videos=download_videos,
|
||||
local_files_only=local_files_only,
|
||||
video_backend=video_backend,
|
||||
)
|
||||
|
||||
def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
|
||||
if not episode_data:
|
||||
episode_buffer = self.episode_buffer
|
||||
|
||||
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
|
||||
episode_length = episode_buffer.pop("size")
|
||||
tasks = episode_buffer.pop("task")
|
||||
episode_tasks = list(set(tasks))
|
||||
episode_index = episode_buffer["episode_index"]
|
||||
|
||||
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
|
||||
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
|
||||
|
||||
for task in episode_tasks:
|
||||
task_index = self.meta.get_task_index(task)
|
||||
if task_index is None:
|
||||
self.meta.add_task(task)
|
||||
|
||||
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
|
||||
for key, ft in self.features.items():
|
||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
|
||||
continue
|
||||
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
|
||||
for key in self.meta.video_keys:
|
||||
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
|
||||
episode_buffer[key] = str(video_path) # PosixPath -> str
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copyfile(videos[key], video_path)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
self._save_episode_table(episode_buffer, episode_index)
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
|
||||
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
|
||||
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
|
||||
check_timestamps_sync(
|
||||
episode_buffer["timestamp"],
|
||||
episode_buffer["episode_index"],
|
||||
ep_data_index_np,
|
||||
self.fps,
|
||||
self.tolerance_s,
|
||||
)
|
||||
if not episode_data:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
for name in frame:
|
||||
if isinstance(frame[name], torch.Tensor):
|
||||
frame[name] = frame[name].numpy()
|
||||
features = {key: value for key, value in self.features.items() if key in self.hf_features}
|
||||
if self.episode_buffer is None:
|
||||
self.episode_buffer = self.create_episode_buffer()
|
||||
frame_index = self.episode_buffer["size"]
|
||||
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
|
||||
self.episode_buffer["frame_index"].append(frame_index)
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
|
||||
for key in frame:
|
||||
if key == "task":
|
||||
self.episode_buffer["task"].append(frame["task"])
|
||||
continue
|
||||
if key not in self.features:
|
||||
print("key ", key)
|
||||
raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
# def crop_resize_no_padding(image, target_size=(480, 640)):
|
||||
# """
|
||||
# Crop and scale to target size (no padding)
|
||||
# :param image: input image (NumPy array)
|
||||
# :param target_size: target size (height, width)
|
||||
# :return: processed image
|
||||
# """
|
||||
# h, w = image.shape[:2]
|
||||
# target_h, target_w = target_size
|
||||
# target_ratio = target_w / target_h # Target aspect ratio (e.g. 640/480=1.333)
|
||||
|
||||
# # the original image aspect ratio and cropping direction
|
||||
# if w / h > target_ratio: # Original image is wider → crop width
|
||||
# crop_w = int(h * target_ratio) # Calculate crop width based on target aspect ratio
|
||||
# crop_h = h
|
||||
# start_x = (w - crop_w) // 2 # Horizontal center starting point
|
||||
# start_y = 0
|
||||
# else: # Original image is higher → crop height
|
||||
# crop_h = int(w / target_ratio) # Calculate clipping height according to target aspect ratio
|
||||
# crop_w = w
|
||||
# start_x = 0
|
||||
# start_y = (h - crop_h) // 2 # Vertical center starting point
|
||||
|
||||
# # Perform centered cropping (to prevent out-of-bounds)
|
||||
# start_x, start_y = max(0, start_x), max(0, start_y)
|
||||
# end_x, end_y = min(w, start_x + crop_w), min(h, start_y + crop_h)
|
||||
# cropped = image[start_y:end_y, start_x:end_x]
|
||||
|
||||
# # Resize to target size (bilinear interpolation)
|
||||
# resized = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
|
||||
# return resized
|
||||
|
||||
|
||||
def load_lmdb_data(episode_path: Path, sava_path: Path, fps_factor: int, target_fps: int) -> Optional[Dict]:
|
||||
def load_image(txn, key):
|
||||
raw = txn.get(key)
|
||||
data = pickle.loads(raw)
|
||||
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||
# Convert to RGB if necessary
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
# image = crop_resize_no_padding(image, target_size=(480, 640))
|
||||
return image
|
||||
try:
|
||||
env = lmdb.open(
|
||||
str(episode_path / "lmdb"),
|
||||
readonly=True,
|
||||
lock=False,
|
||||
max_readers=128,
|
||||
readahead=False
|
||||
)
|
||||
with env.begin(write=False) as txn:
|
||||
keys = [k for k, _ in txn.cursor()]
|
||||
|
||||
image_keys = sorted([k for k in keys if b'head' in k])
|
||||
if not image_keys:
|
||||
return None
|
||||
|
||||
all_qpos = pickle.loads(txn.get(b'/observations/qpos'))
|
||||
|
||||
if np.isscalar(all_qpos):
|
||||
total_steps = len(image_keys)
|
||||
all_qpos = [all_qpos] * total_steps
|
||||
else:
|
||||
total_steps = len(all_qpos)
|
||||
all_qpos = np.stack(all_qpos)
|
||||
state_action_dict = {}
|
||||
state_action_dict["states.left_joint.position"] = all_qpos[:, :6]
|
||||
state_action_dict["states.left_gripper.position"] = all_qpos[:, 6][:, None] # np.expand_dims(all_qpos[:, 6], axis=1)
|
||||
state_action_dict["states.right_joint.position"] = all_qpos[:, 7:13]
|
||||
state_action_dict["states.right_gripper.position"] = all_qpos[:, 13][:, None] # np.expand_dims(all_qpos[:, 13], axis=1)
|
||||
# state_keys = list(state_action_dict.keys())
|
||||
# for k in state_keys:
|
||||
# state_action_dict[k.replace("states", "actions")] = np.concatenate([state_action_dict[k][1:, :], state_action_dict[k][-1, :][None,:]], axis=0)
|
||||
|
||||
|
||||
# action_dict = {}
|
||||
# action_dict["actions.left_joint.position"] = np.concatenate([state_dict["states.left_joint.position"][1:, :], state_dict["states.left_joint.position"][-1, :][None,:]], axis=0)
|
||||
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
|
||||
# action_dict["actions.right_joint.position"] = state_dict["states.right_joint.position"][1:, :]
|
||||
# action_dict["actions.right_gripper.position"] = state_dict["states.right_gripper.position"][1:, :]
|
||||
|
||||
action_dict = {}
|
||||
|
||||
action_dict["action"] = np.concatenate([all_qpos[1:,], all_qpos[-1,].reshape(-1, 14)], axis=0)
|
||||
state_dict = {}
|
||||
state_dict["observation.state"] = all_qpos
|
||||
mask1 = filter_forbidden_frames(state_dict["observation.state"])
|
||||
# state_dict["observation.state"] = state_dict["observation.state"][mask1]
|
||||
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
|
||||
# action_dict["actions.right_arm.position"] = np.concatenate([state_action_dict["states.right_joint.position"][1:, :], state_action_dict["states.right_joint.position"][-1, :][None,:]], axis=0)
|
||||
# action_dict["actions.left_arm.position"] = state_dict["states.right_gripper.position"][1:, :]
|
||||
|
||||
assert total_steps == len(image_keys), "qpos length mismatch"
|
||||
selected_steps = [step for step in range(total_steps) if step % fps_factor == 0 and mask1[step]]
|
||||
frames = []
|
||||
image_observations = {
|
||||
"images.rgb.head": [],
|
||||
"images.rgb.hand_left": [],
|
||||
"images.rgb.hand_right": []
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
for step_index, step in enumerate(selected_steps):
|
||||
step_str = f"{step:04d}"
|
||||
head_key = f"observation/head/color_image/{step_str}".encode()
|
||||
left_key = f"observation/left_wrist/color_image/{step_str}".encode()
|
||||
right_key = f"observation/right_wrist/color_image/{step_str}".encode()
|
||||
if not (head_key in keys and left_key in keys and right_key in keys):
|
||||
continue
|
||||
# state = all_qpos[step]
|
||||
# if step_index < len(selected_steps) - 1:
|
||||
# action = all_qpos[selected_steps[step_index + 1]]
|
||||
# else:
|
||||
# action = state
|
||||
data_dict = {}
|
||||
# for key, value in state_action_dict.items():
|
||||
# data_dict[key] = value[step]
|
||||
data_dict['action'] = action_dict["action"][step]
|
||||
data_dict["task"] = " ".join(episode_path.parent.parent.name.split("_"))
|
||||
data_dict['observation.state'] = state_dict["observation.state"][step]
|
||||
# frames.append({
|
||||
# "observation.states.joint.position": state,
|
||||
# "actions.joint.position": action,
|
||||
# "task": task_name,
|
||||
# })
|
||||
frames.append(data_dict)
|
||||
image_observations["images.rgb.head"].append(load_image(txn, head_key))
|
||||
image_observations["images.rgb.hand_left"].append(load_image(txn, left_key))
|
||||
image_observations["images.rgb.hand_right"].append(load_image(txn, right_key))
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"load image_observations of {episode_path}")
|
||||
env.close()
|
||||
if not frames:
|
||||
return None
|
||||
os.makedirs(sava_path, exist_ok=True)
|
||||
os.makedirs(sava_path/episode_path.name, exist_ok=True)
|
||||
imageio.mimsave(sava_path/episode_path.name/'head.mp4', image_observations["images.rgb.head"], fps=target_fps)
|
||||
imageio.mimsave(sava_path/episode_path.name/'hand_left.mp4', image_observations["images.rgb.hand_left"], fps=target_fps)
|
||||
imageio.mimsave(sava_path/episode_path.name/'hand_right.mp4', image_observations["images.rgb.hand_right"], fps=target_fps)
|
||||
print(f"imageio.mimsave time taken of {episode_path}")
|
||||
|
||||
return {
|
||||
"frames": frames,
|
||||
"videos": {
|
||||
"images.rgb.head": sava_path/episode_path.name/"head.mp4",
|
||||
"images.rgb.hand_left": sava_path/episode_path.name/"hand_left.mp4",
|
||||
"images.rgb.hand_right": sava_path/episode_path.name/"hand_right.mp4",
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load LMDB data: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_all_tasks(src_path: Path, output_path: Path) -> Tuple[Path, Path]:
|
||||
src_dirs = sorted(list(src_path.glob("*"))) # "set*-*_collector*_datatime" as the conversion unit
|
||||
|
||||
save_dirs = [output_path/_dir.parent.name/_dir.name for _dir in src_dirs]
|
||||
tasks_tuples = zip(src_dirs, save_dirs)
|
||||
for task in tasks_tuples:
|
||||
yield task
|
||||
|
||||
def compute_episode_stats(episode_data: Dict[str, List[str] | np.ndarray], features: Dict) -> Dict:
|
||||
ep_stats = {}
|
||||
for key, data in episode_data.items():
|
||||
if features[key]["dtype"] == "string":
|
||||
continue
|
||||
elif features[key]["dtype"] in ["image", "video"]:
|
||||
ep_ft_array = sample_images(data)
|
||||
axes_to_reduce = (0, 2, 3) # keep channel dim
|
||||
keepdims = True
|
||||
else:
|
||||
ep_ft_array = data # data is already a np.ndarray
|
||||
axes_to_reduce = 0 # compute stats over the first axis
|
||||
keepdims = data.ndim == 1 # keep as np.array
|
||||
|
||||
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
|
||||
if features[key]["dtype"] in ["image", "video"]:
|
||||
ep_stats[key] = {
|
||||
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
|
||||
}
|
||||
return ep_stats
|
||||
|
||||
def sample_images(input):
|
||||
if type(input) is str:
|
||||
video_path = input
|
||||
reader = torchvision.io.VideoReader(video_path, stream="video")
|
||||
frames = [frame["data"] for frame in reader]
|
||||
frames_array = torch.stack(frames).numpy() # Shape: [T, C, H, W]
|
||||
sampled_indices = sample_indices(len(frames_array))
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
img = frames_array[idx]
|
||||
img = auto_downsample_height_width(img)
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
images[i] = img
|
||||
elif type(input) is np.ndarray:
|
||||
frames_array = input[:, None, :, :] # Shape: [T, C, H, W]
|
||||
sampled_indices = sample_indices(len(frames_array))
|
||||
images = None
|
||||
for i, idx in enumerate(sampled_indices):
|
||||
img = frames_array[idx]
|
||||
img = auto_downsample_height_width(img)
|
||||
if images is None:
|
||||
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
|
||||
images[i] = img
|
||||
return images
|
||||
|
||||
|
||||
def load_local_dataset(episode_path: str, save_path:str, origin_fps=30, target_fps=30):
|
||||
fps_factor = origin_fps // target_fps
|
||||
# print(f"fps downsample factor: {fps_factor}")
|
||||
# logging.info(f"fps downsample factor: {fps_factor}")
|
||||
# for format_str in [f"{episode_id:07d}", f"{episode_id:06d}", str(episode_id)]:
|
||||
# episode_path = Path(src_path) / format_str
|
||||
# save_path = Path(save_path) / format_str
|
||||
# if episode_path.exists():
|
||||
# break
|
||||
# else:
|
||||
# logging.warning(f"Episode directory not found for ID {episode_id}")
|
||||
# return None, None
|
||||
episode_path = Path(episode_path)
|
||||
if not episode_path.exists():
|
||||
logging.warning(f"{episode_path} does not exist")
|
||||
return None, None
|
||||
|
||||
if not (episode_path / "lmdb/data.mdb").exists():
|
||||
logging.warning(f"LMDB data not found for episode {episode_path}")
|
||||
return None, None
|
||||
|
||||
raw_dataset = load_lmdb_data(episode_path, save_path, fps_factor, target_fps)
|
||||
if raw_dataset is None:
|
||||
return None, None
|
||||
frames = raw_dataset["frames"] # states, actions, task
|
||||
|
||||
videos = raw_dataset["videos"] # image paths
|
||||
## check the frames
|
||||
for camera_name, video_path in videos.items():
|
||||
if not os.path.exists(video_path):
|
||||
logging.error(f"Video file {video_path} does not exist.")
|
||||
print(f"Camera {camera_name} Video file {video_path} does not exist.")
|
||||
return None, None
|
||||
return frames, videos
|
||||
|
||||
|
||||
def save_as_lerobot_dataset(task: tuple[Path, Path], repo_id, num_threads, debug, origin_fps=30, target_fps=30, robot_type="piper", delete_downsampled_videos=True):
|
||||
src_path, save_path = task
|
||||
print(f"**Processing collected** {src_path}")
|
||||
print(f"**saving to** {save_path}")
|
||||
if save_path.exists():
|
||||
# print(f"Output directory {save_path} already exists. Deleting it.")
|
||||
# logging.warning(f"Output directory {save_path} already exists. Deleting it.")
|
||||
# shutil.rmtree(save_path)
|
||||
print(f"Output directory {save_path} already exists.")
|
||||
return
|
||||
|
||||
dataset = ARXDataset.create(
|
||||
repo_id=f"{repo_id}",
|
||||
root=save_path,
|
||||
fps=target_fps,
|
||||
robot_type=robot_type,
|
||||
features=FEATURES,
|
||||
)
|
||||
all_episode_paths = sorted([f.as_posix() for f in src_path.glob(f"*") if f.is_dir()])
|
||||
# all_subdir_eids = [int(Path(path).name) for path in all_subdir]
|
||||
if debug:
|
||||
for i in range(1):
|
||||
# pdb.set_trace()
|
||||
frames, videos = load_local_dataset(episode_path=all_episode_paths[i], save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
|
||||
for frame_data in frames:
|
||||
dataset.add_frame(frame_data)
|
||||
dataset.save_episode(videos=videos)
|
||||
if delete_downsampled_videos:
|
||||
for _, video_path in videos.items():
|
||||
parent_dir = os.path.dirname(video_path)
|
||||
try:
|
||||
shutil.rmtree(parent_dir)
|
||||
# os.remove(video_path)
|
||||
# print(f"Successfully deleted: {parent_dir}")
|
||||
print(f"Successfully deleted: {video_path}")
|
||||
except Exception as e:
|
||||
pass # Handle the case where the directory might not exist or is already deleted
|
||||
else:
|
||||
for batch_index in range(len(all_episode_paths)//num_threads+1):
|
||||
batch_episode_paths = all_episode_paths[batch_index*num_threads:(batch_index+1)*num_threads]
|
||||
if len(batch_episode_paths) == 0:
|
||||
continue
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = []
|
||||
for episode_path in batch_episode_paths:
|
||||
print("starting to process episode: ", episode_path)
|
||||
futures.append(
|
||||
executor.submit(load_local_dataset, episode_path=episode_path, save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
|
||||
)
|
||||
for raw_dataset in as_completed(futures):
|
||||
frames, videos = raw_dataset.result()
|
||||
if frames is None or videos is None:
|
||||
print(f"Skipping episode {episode_path} due to missing data.")
|
||||
continue
|
||||
for frame_data in frames:
|
||||
dataset.add_frame(frame_data)
|
||||
dataset.save_episode(videos=videos)
|
||||
gc.collect()
|
||||
print(f"finishing processed {videos}")
|
||||
if delete_downsampled_videos:
|
||||
for _, video_path in videos.items():
|
||||
# Get the parent directory of the video
|
||||
parent_dir = os.path.dirname(video_path)
|
||||
try:
|
||||
shutil.rmtree(parent_dir)
|
||||
print(f"Successfully deleted: {parent_dir}")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
def main(src_path, save_path, repo_id, num_threads=60, debug=False, origin_fps=30, target_fps=30):
|
||||
logging.info("Scanning for episodes...")
|
||||
tasks = get_all_tasks(src_path, save_path)
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
if debug:
|
||||
task = next(tasks)
|
||||
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
|
||||
else:
|
||||
for task in tasks:
|
||||
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert collected data from Piper to Lerobot format.")
|
||||
parser.add_argument(
|
||||
"--src_path",
|
||||
type=str,
|
||||
# required=False,
|
||||
default="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/",
|
||||
help="Path to the input file containing collected data in Piper format.",
|
||||
#help="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Make_a_beef_sandwich",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_path",
|
||||
type=str,
|
||||
# required=False,
|
||||
default="/fs-computility/efm/shared/datasets/myData-A1/real/lerobot_v2_1/agilex_split_aloha/",
|
||||
help="Path to the output file where the converted Lerobot format will be saved.",
|
||||
#help="Path to the output file where the converted Lerobot format will be saved.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Run in debug mode with limited episodes",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of threads per process",
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--task_name",
|
||||
# type=str,
|
||||
# required=True,
|
||||
# default="Pick_up_the_marker_and_put_it_into_the_pen_holder",
|
||||
# help="Name of the task to be processed. Default is 'Pick_up_the_marker_and_put_it_into_the_pen_holder'.",
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--repo_id",
|
||||
type=str,
|
||||
required=True,
|
||||
# default="SplitAloha_20250714",
|
||||
help="identifier for the dataset repository.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--origin_fps",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Frames per second for the obervation video. Default is 30.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--target_fps",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Frames per second for the downsample video. Default is 30.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
assert int(args.origin_fps) % int(args.target_fps) == 0, "origin_fps must be an integer multiple of target_fps"
|
||||
start_time = time.time()
|
||||
main(
|
||||
src_path=Path(args.src_path),
|
||||
save_path=Path(args.save_path),
|
||||
repo_id=args.repo_id,
|
||||
num_threads=args.num_threads,
|
||||
debug=args.debug,
|
||||
origin_fps=args.origin_fps,
|
||||
target_fps=args.target_fps
|
||||
)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
print(f"Total time taken: {elapsed_time:.2f} seconds")
|
||||
# --target_fps 10
|
||||
# --src_path /fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Put_the_bananas_in_the_basket
|
||||
# --save_path /mnt/shared-storage-user/internvla/Users/liyang/data/processed_data/arx_lift2
|
||||
1693
policy/openpi-InternData-A1/examples/arx/merge_lerobot_data.py
Normal file
1693
policy/openpi-InternData-A1/examples/arx/merge_lerobot_data.py
Normal file
File diff suppressed because it is too large
Load Diff
1509
policy/openpi-InternData-A1/examples/arx/merge_lerobot_data_v2.py
Normal file
1509
policy/openpi-InternData-A1/examples/arx/merge_lerobot_data_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,587 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
|
||||
|
||||
This script loads a JAX model checkpoint using orbax and can either:
|
||||
1. Print out all the parameter keys in a hierarchical structure for inspection
|
||||
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
|
||||
|
||||
Usage:
|
||||
# Just inspect keys:
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
|
||||
|
||||
# Convert to PyTorch:
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
|
||||
|
||||
Example:
|
||||
# pi0_droid
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
|
||||
|
||||
# pi0_aloha_sim
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
|
||||
|
||||
# pi05_droid
|
||||
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
from typing import Literal
|
||||
|
||||
from flax.nnx import traversals
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import safetensors
|
||||
import torch
|
||||
import tyro
|
||||
|
||||
import openpi.models.gemma
|
||||
import openpi.models.model
|
||||
import openpi.models.pi0_config
|
||||
import openpi.models_pytorch.pi0_pytorch
|
||||
from openpi.training import utils
|
||||
import openpi.training.config as _config
|
||||
|
||||
|
||||
def slice_paligemma_state_dict(state_dict, config):
|
||||
"""Convert PaliGemma JAX parameters to PyTorch format."""
|
||||
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
|
||||
|
||||
# patch embeddings
|
||||
jax_key = f"img/embedding/kernel{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
|
||||
|
||||
jax_key = f"img/embedding/bias{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
||||
|
||||
# positional embeddings
|
||||
jax_key = f"img/pos_embedding{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
|
||||
|
||||
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
|
||||
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
|
||||
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
|
||||
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
|
||||
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
|
||||
|
||||
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
|
||||
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
|
||||
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
|
||||
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
|
||||
|
||||
encoderblock_attention_0_key_kernel = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_key_bias = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_value_kernel = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_value_bias = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_query_kernel = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_query_bias = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_out_kernel = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
|
||||
)
|
||||
encoderblock_attention_0_out_bias = state_dict.pop(
|
||||
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
|
||||
)
|
||||
|
||||
for i in range(config.vision_config.num_hidden_layers):
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
|
||||
] = encoderblock_layernorm0_scale[i].transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
|
||||
] = encoderblock_layernorm0_bias[i]
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
|
||||
] = encoderblock_layernorm1_scale[i].transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
|
||||
] = encoderblock_layernorm1_bias[i]
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
|
||||
] = encoderblock_mlp_dense0_kernel[i].transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
|
||||
] = encoderblock_mlp_dense0_bias[i]
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
|
||||
] = encoderblock_mlp_dense1_kernel[i].transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
|
||||
] = encoderblock_mlp_dense1_bias[i]
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
|
||||
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
|
||||
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
|
||||
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
|
||||
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
|
||||
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
|
||||
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
|
||||
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
|
||||
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
|
||||
|
||||
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
||||
|
||||
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
||||
|
||||
# multimodal projector
|
||||
jax_key = f"img/head/kernel{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
|
||||
|
||||
jax_key = f"img/head/bias{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
||||
|
||||
# text decoder (gemma)
|
||||
jax_key = f"llm/embedder/input_embedding{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
||||
|
||||
# pop the einsum attention + mlp representations
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
|
||||
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
|
||||
|
||||
for i in range(config.text_config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = (
|
||||
llm_attention_q_einsum[i]
|
||||
.transpose(0, 2, 1)
|
||||
.reshape(
|
||||
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
||||
)
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
|
||||
q_proj_weight_reshaped
|
||||
)
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
|
||||
k_proj_weight_reshaped
|
||||
)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
|
||||
v_proj_weight_reshaped
|
||||
)
|
||||
|
||||
o_proj_weight_reshaped = (
|
||||
llm_attention_attn_vec_einsum[i]
|
||||
.transpose(2, 0, 1)
|
||||
.reshape(
|
||||
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
|
||||
)
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
|
||||
o_proj_weight_reshaped
|
||||
)
|
||||
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
|
||||
gate_proj_weight.transpose()
|
||||
)
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
|
||||
up_proj_weight.transpose()
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
|
||||
llm_mlp_linear[i].transpose()
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
|
||||
llm_input_layernorm[i]
|
||||
)
|
||||
state_dict[
|
||||
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
|
||||
] = llm_post_attention_layernorm[i]
|
||||
|
||||
jax_key = f"llm/final_norm/scale{suffix}"
|
||||
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
|
||||
state_dict[pytorch_key] = state_dict.pop(jax_key)
|
||||
|
||||
expert_dict = {}
|
||||
final_state_dict = {}
|
||||
|
||||
# Expert-related keys to extract (including pi05 Dense layer parameters)
|
||||
expert_keys = [
|
||||
f"llm/final_norm_1/scale{suffix}",
|
||||
f"llm/final_norm_1/Dense_0/bias{suffix}",
|
||||
f"llm/final_norm_1/Dense_0/kernel{suffix}",
|
||||
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/kv_einsum_1/w{suffix}",
|
||||
f"llm/layers/attn/q_einsum_1/w{suffix}",
|
||||
f"llm/layers/mlp_1/gating_einsum{suffix}",
|
||||
f"llm/layers/mlp_1/linear{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
|
||||
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
|
||||
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
if key not in expert_keys:
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
expert_dict[key] = value
|
||||
|
||||
return final_state_dict, expert_dict
|
||||
|
||||
|
||||
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
|
||||
"""Convert Gemma JAX parameters to PyTorch format."""
|
||||
# Add missing attributes to config if they don't exist
|
||||
if not hasattr(config, "vocab_size"):
|
||||
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
|
||||
if not hasattr(config, "hidden_size"):
|
||||
config.hidden_size = config.width
|
||||
if not hasattr(config, "num_hidden_layers"):
|
||||
config.num_hidden_layers = config.depth
|
||||
if not hasattr(config, "num_attention_heads"):
|
||||
config.num_attention_heads = config.num_heads
|
||||
|
||||
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
|
||||
|
||||
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
|
||||
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
|
||||
|
||||
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
|
||||
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
|
||||
|
||||
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
|
||||
if "pi05" in checkpoint_dir:
|
||||
# Pi05 with adaptive normalization
|
||||
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
|
||||
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
|
||||
llm_input_layernorm_kernel = state_dict.pop(
|
||||
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
|
||||
)
|
||||
llm_post_attention_layernorm_kernel = state_dict.pop(
|
||||
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
|
||||
)
|
||||
else:
|
||||
# Regular pi0 with standard RMSNorm
|
||||
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
|
||||
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
q_proj_weight_reshaped = (
|
||||
llm_attention_q_einsum[i]
|
||||
.transpose(0, 2, 1)
|
||||
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
|
||||
q_proj_weight_reshaped
|
||||
)
|
||||
|
||||
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
|
||||
k_proj_weight_reshaped
|
||||
)
|
||||
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
|
||||
v_proj_weight_reshaped
|
||||
)
|
||||
|
||||
o_proj_weight_reshaped = (
|
||||
llm_attention_attn_vec_einsum[i]
|
||||
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
|
||||
.transpose(1, 0)
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
|
||||
o_proj_weight_reshaped
|
||||
)
|
||||
|
||||
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
|
||||
gate_proj_weight.transpose()
|
||||
)
|
||||
up_proj_weight = llm_mlp_gating_einsum[i, 1]
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
|
||||
up_proj_weight.transpose()
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
|
||||
i
|
||||
].transpose()
|
||||
|
||||
if "pi05" in checkpoint_dir:
|
||||
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
|
||||
llm_input_layernorm_bias[i]
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
|
||||
llm_post_attention_layernorm_bias[i]
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
|
||||
llm_input_layernorm_kernel[i].transpose()
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
|
||||
llm_post_attention_layernorm_kernel[i].transpose()
|
||||
)
|
||||
else:
|
||||
# Regular pi0 with standard RMSNorm
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
|
||||
llm_input_layernorm[i]
|
||||
)
|
||||
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
|
||||
llm_post_attention_layernorm[i]
|
||||
)
|
||||
|
||||
# Handle final norm layer
|
||||
if "pi05" in checkpoint_dir:
|
||||
# Pi05 with adaptive normalization - use Dense layer parameters directly
|
||||
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
|
||||
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
|
||||
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
|
||||
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
|
||||
else:
|
||||
# Regular pi0 with standard RMSNorm
|
||||
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
|
||||
f"llm/final_norm_{num_expert}/scale{suffix}"
|
||||
)
|
||||
|
||||
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
|
||||
|
||||
final_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if not isinstance(value, torch.Tensor):
|
||||
final_state_dict[key] = torch.from_numpy(value)
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
|
||||
"""Load and process params by restoring via JAX model loader first.
|
||||
This respects dtype conversions that occur during model restore.
|
||||
"""
|
||||
# Use repository restore utility to load a pure dict of params (value suffix removed)
|
||||
params = openpi.models.model.restore_params(
|
||||
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
|
||||
)
|
||||
|
||||
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
|
||||
|
||||
|
||||
def load_jax_model_and_print_keys(checkpoint_dir: str):
|
||||
"""
|
||||
Load JAX model from checkpoint and print all parameter keys.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to the checkpoint directory
|
||||
"""
|
||||
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
|
||||
# Initialize checkpointer
|
||||
checkpointer = ocp.PyTreeCheckpointer()
|
||||
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
|
||||
print(utils.array_tree_to_info(metadata))
|
||||
|
||||
|
||||
def convert_pi0_checkpoint(
|
||||
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
|
||||
):
|
||||
"""
|
||||
Convert PI0 JAX checkpoint to PyTorch format.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to the JAX checkpoint
|
||||
precision: Model precision (float32, bfloat16, float16)
|
||||
output_path: Path to save the converted PyTorch model
|
||||
model_config: Model config
|
||||
"""
|
||||
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
|
||||
print(f"Model config: {model_config}")
|
||||
|
||||
# Break down orbax ckpts by restoring via JAX to respect dtype
|
||||
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
|
||||
|
||||
# Process projection params
|
||||
if model_config.pi05:
|
||||
keys = [
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"time_mlp_in",
|
||||
"time_mlp_out",
|
||||
]
|
||||
else:
|
||||
keys = [
|
||||
"state_proj",
|
||||
"action_in_proj",
|
||||
"action_out_proj",
|
||||
"action_time_mlp_in",
|
||||
"action_time_mlp_out",
|
||||
]
|
||||
|
||||
projection_params = {}
|
||||
for key in keys:
|
||||
kernel_params = initial_params["projection_params"][key]["kernel"]
|
||||
bias_params = initial_params["projection_params"][key]["bias"]
|
||||
if isinstance(kernel_params, dict):
|
||||
weight = kernel_params["value"]
|
||||
bias = bias_params["value"]
|
||||
else:
|
||||
weight = kernel_params
|
||||
bias = bias_params
|
||||
|
||||
pytorch_weight_key = f"{key}.weight"
|
||||
pytorch_bias_key = f"{key}.bias"
|
||||
|
||||
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
|
||||
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
|
||||
|
||||
# Create configs based on checkpoint path
|
||||
# All models use the same PaliGemma config structure
|
||||
class PaliGemmaConfig:
|
||||
def __init__(self):
|
||||
self.vision_config = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"hidden_size": 1152,
|
||||
"num_hidden_layers": 27,
|
||||
"num_attention_heads": 16,
|
||||
"intermediate_size": 4304,
|
||||
"patch_size": 14,
|
||||
"projection_dim": 2048,
|
||||
},
|
||||
)()
|
||||
self.text_config = type(
|
||||
"obj",
|
||||
(object,),
|
||||
{
|
||||
"hidden_size": 2048,
|
||||
"num_hidden_layers": 18,
|
||||
"num_attention_heads": 8,
|
||||
"head_dim": 256,
|
||||
"intermediate_size": 16384,
|
||||
},
|
||||
)()
|
||||
|
||||
paligemma_config = PaliGemmaConfig()
|
||||
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
|
||||
|
||||
# Process PaliGemma weights
|
||||
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
|
||||
|
||||
# Process Gemma weights from expert_params
|
||||
gemma_params = slice_gemma_state_dict(
|
||||
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
|
||||
)
|
||||
|
||||
# Instantiate model
|
||||
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
|
||||
|
||||
# Combine all parameters (no prefix needed for our model structure)
|
||||
all_params = {**paligemma_params, **gemma_params, **projection_params}
|
||||
|
||||
# Load state dict
|
||||
pi0_model.load_state_dict(all_params, strict=False)
|
||||
|
||||
if precision == "float32":
|
||||
pi0_model = pi0_model.to(torch.float32)
|
||||
elif precision == "bfloat16":
|
||||
pi0_model = pi0_model.to(torch.bfloat16)
|
||||
else:
|
||||
raise ValueError(f"Invalid precision: {precision}")
|
||||
|
||||
# Save the converted model using safetensors
|
||||
os.makedirs(output_path, exist_ok=True)
|
||||
|
||||
# Save model weights as SafeTensors using save_model to handle tied weights
|
||||
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
|
||||
|
||||
# Copy assets folder if it exists
|
||||
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
|
||||
if assets_source.exists():
|
||||
assets_dest = pathlib.Path(output_path) / "assets"
|
||||
if assets_dest.exists():
|
||||
shutil.rmtree(assets_dest)
|
||||
shutil.copytree(assets_source, assets_dest)
|
||||
|
||||
# Save config as JSON for reference
|
||||
config_dict = {
|
||||
"action_dim": model_config.action_dim,
|
||||
"action_horizon": model_config.action_horizon,
|
||||
"paligemma_variant": model_config.paligemma_variant,
|
||||
"action_expert_variant": model_config.action_expert_variant,
|
||||
"precision": precision,
|
||||
}
|
||||
with open(os.path.join(output_path, "config.json"), "w") as f:
|
||||
json.dump(config_dict, f, indent=2)
|
||||
|
||||
print("Model conversion completed successfully!")
|
||||
print(f"Model saved to {output_path}")
|
||||
|
||||
|
||||
def main(
|
||||
checkpoint_dir: str,
|
||||
config_name: str,
|
||||
output_path: str | None = None,
|
||||
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
|
||||
*,
|
||||
inspect_only: bool = False,
|
||||
):
|
||||
"""Load JAX model and optionally convert to PyTorch.
|
||||
|
||||
Args:
|
||||
checkpoint_dir: Path to the JAX checkpoint directory
|
||||
output_path: Path to save converted PyTorch model (required for conversion)
|
||||
precision: Precision for model conversion
|
||||
inspect_only: Only inspect parameter keys, don't convert
|
||||
"""
|
||||
model_config = _config.get_config(config_name).model
|
||||
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
|
||||
raise ValueError(f"Config {config_name} is not a Pi0Config")
|
||||
if inspect_only:
|
||||
load_jax_model_and_print_keys(checkpoint_dir)
|
||||
else:
|
||||
if not output_path:
|
||||
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
|
||||
return
|
||||
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
84
policy/openpi-InternData-A1/examples/droid/README.md
Normal file
84
policy/openpi-InternData-A1/examples/droid/README.md
Normal file
@@ -0,0 +1,84 @@
|
||||
# DROID Policies in openpi
|
||||
|
||||
We offer instructions for:
|
||||
- [Running inference for our best $pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
|
||||
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
|
||||
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
|
||||
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
|
||||
|
||||
## Running DROID Inference
|
||||
|
||||
This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
|
||||
|
||||
|
||||
### Step 1: Start a policy server
|
||||
|
||||
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
|
||||
|
||||
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
|
||||
2. Start the OpenPI server via the following command:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
|
||||
```
|
||||
|
||||
You can also run the equivalent command below:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env=DROID
|
||||
```
|
||||
|
||||
### Step 2: Run the DROID robot
|
||||
|
||||
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
|
||||
2. On the control laptop, activate your DROID conda environment.
|
||||
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
|
||||
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
|
||||
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
|
||||
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
|
||||
7. Run the `main.py` file. Make sure to point the IP and host address to the policy server. (To make sure the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the DROID laptop.) Also make sure to specify the external camera to use for the policy (we only input one external camera), choose from ["left", "right"].
|
||||
|
||||
```bash
|
||||
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
|
||||
```
|
||||
|
||||
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
|
||||
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explore` in the command line. |
|
||||
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
|
||||
| Policy does not perform the task well | In our experiments, the policy could perform simple table top manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera + wrist camera, make sure you are feeding the desired camera to the policy). Use `ZED_Explore` to check that the camera view you are passing to the policy can see all relevant objects in the scene. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
|
||||
|
||||
|
||||
## Running Other Policies
|
||||
|
||||
We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
|
||||
|
||||
```
|
||||
# Train from pi0-FAST, using FAST tokenizer
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
|
||||
|
||||
# Train from pi0, using flow matching
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
|
||||
|
||||
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
|
||||
|
||||
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
|
||||
|
||||
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
|
||||
|
||||
# Trained from PaliGemma, using FSQ tokenizer.
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
|
||||
|
||||
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
|
||||
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
|
||||
```
|
||||
|
||||
You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).
|
||||
106
policy/openpi-InternData-A1/examples/droid/README_train.md
Normal file
106
policy/openpi-InternData-A1/examples/droid/README_train.md
Normal file
@@ -0,0 +1,106 @@
|
||||
# Training on DROID
|
||||
|
||||
Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline.
|
||||
(small differences in data loading and the used action space) -- For a tutorial on how to fine-tune your model with a smaller, custom dataset collected on the DROID platform, see below.
|
||||
|
||||
In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training (since at the moment LeRobot isn't scalable enough
|
||||
for larger datasets like DROID -- they are working on improving it though). Below, we provide instructions for updating your openpi environment for RLDS data loading and where to download the DROID dataset.
|
||||
|
||||
## Install
|
||||
|
||||
We need a few additional dependencies for RLDS data loading. Run:
|
||||
```bash
|
||||
uv sync --group rlds
|
||||
```
|
||||
|
||||
## Download DROID dataset
|
||||
|
||||
You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
|
||||
```
|
||||
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
|
||||
```
|
||||
|
||||
Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
|
||||
|
||||
You will need 1.8TB of disk storage to download the DROID RLDS dataset.
|
||||
|
||||
## Run
|
||||
|
||||
First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
|
||||
|
||||
Then, compute normalization statistics (this will take ~10 minutes):
|
||||
```bash
|
||||
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
|
||||
```
|
||||
|
||||
Run training:
|
||||
```bash
|
||||
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
|
||||
```
|
||||
|
||||
**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
|
||||
Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
|
||||
Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
|
||||
|
||||
|
||||
## Compute Requirements
|
||||
|
||||
Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, bs256, approx. 1 epoch).
|
||||
If you start from PaliGemma instead of pi0 initialization, plan with ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
|
||||
|
||||
We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
|
||||
|
||||
|
||||
## Data Filtering
|
||||
|
||||
Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean" and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface that was used during data collection, we will not go into too much detail here). Appropriate filtering of these idle transitions can improve policy performance.
|
||||
|
||||
By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
|
||||
|
||||
**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
|
||||
|
||||
## RoboArena
|
||||
|
||||
Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
|
||||
|
||||
If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
|
||||
|
||||
|
||||
# Fine-Tuning on Custom DROID Datasets
|
||||
|
||||
Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. Like for other datasets, we will first convert the custom DROID dataset to LeRobot and then fine-tune a model (pi05-droid) on it.
|
||||
|
||||
Note: We use LeRobot here, since we assume the custom DROID fine-tuning dataset to be relatively small (<10s of hours). For larger datasets (like the full DROID dataset) we recommend using RLDS for it's better efficiency (see the example above).
|
||||
|
||||
|
||||
## Step 1: Converting your custom DROID dataset to LeRobot
|
||||
|
||||
We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
|
||||
```
|
||||
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
|
||||
```
|
||||
|
||||
We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
|
||||
```
|
||||
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
|
||||
```
|
||||
|
||||
For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
|
||||
|
||||
Now, we will use the `convert_droid_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5min for the 30 demonstrations):
|
||||
```
|
||||
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
|
||||
```
|
||||
|
||||
## Step 2: Run fine-tuning with your custom dataset
|
||||
|
||||
Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
|
||||
You can modify the config easily to work with other base models, or use your custom DROID dataset in `config.py` (seach for `pi05_droid_finetune`).
|
||||
|
||||
To launch training:
|
||||
```
|
||||
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
|
||||
```
|
||||
|
||||
Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
|
||||
that should be sampled during training (all others are filtered out).
|
||||
|
||||
Filtering logic:
|
||||
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
|
||||
(default to 7 -- as most DROID action-chunking policies run the first 8 actions generated in each chunk, filtering
|
||||
this way means the policy will not get stuck outputting stationary actions). Additionally, we also only keep non-idle
|
||||
ranges of length at least min_non_idle_len (default to 16 frames = ~1 second), while also removing the last
|
||||
filter_last_n_in_ranges frames from the end of each range (as those all correspond to action chunks with many idle actions).
|
||||
|
||||
This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
|
||||
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
from tqdm import tqdm
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
|
||||
|
||||
builder = tfds.builder_from_directory(
|
||||
# path to the `droid` directory (not its parent)
|
||||
builder_dir="<path_to_droid_dataset_tfds_files>",
|
||||
)
|
||||
ds = builder.as_dataset(split="train", shuffle_files=False)
|
||||
tf.data.experimental.ignore_errors(ds)
|
||||
|
||||
keep_ranges_path = "<path_to_where_to_save_the_json>"
|
||||
|
||||
min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
|
||||
min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
|
||||
filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
|
||||
|
||||
keep_ranges_map = {}
|
||||
if Path(keep_ranges_path).exists():
|
||||
with Path(keep_ranges_path).open("r") as f:
|
||||
keep_ranges_map = json.load(f)
|
||||
print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
|
||||
|
||||
for ep_idx, ep in enumerate(tqdm(ds)):
|
||||
recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
|
||||
file_path = ep["episode_metadata"]["file_path"].numpy().decode()
|
||||
|
||||
key = f"{recording_folderpath}--{file_path}"
|
||||
if key in keep_ranges_map:
|
||||
continue
|
||||
|
||||
joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
|
||||
joint_velocities = np.array(joint_velocities)
|
||||
|
||||
is_idle_array = np.hstack(
|
||||
[np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
|
||||
)
|
||||
|
||||
# Find what steps go from idle to non-idle and vice-versa
|
||||
is_idle_padded = np.concatenate(
|
||||
[[False], is_idle_array, [False]]
|
||||
) # Start and end with False, so idle at first step is a start of motion
|
||||
|
||||
is_idle_diff = np.diff(is_idle_padded.astype(int))
|
||||
is_idle_true_starts = np.where(is_idle_diff == 1)[0] # +1 transitions --> going from idle to non-idle
|
||||
is_idle_true_ends = np.where(is_idle_diff == -1)[0] # -1 transitions --> going from non-idle to idle
|
||||
|
||||
# Find which steps correspond to idle segments of length at least min_idle_len
|
||||
true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
|
||||
is_idle_true_starts = is_idle_true_starts[true_segment_masks]
|
||||
is_idle_true_ends = is_idle_true_ends[true_segment_masks]
|
||||
|
||||
keep_mask = np.ones(len(joint_velocities), dtype=bool)
|
||||
for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
|
||||
keep_mask[start:end] = False
|
||||
|
||||
# Get all non-idle ranges of at least 16
|
||||
# Same logic as above, but for keep_mask, allowing us to filter out contiguous ranges of length < min_non_idle_len
|
||||
keep_padded = np.concatenate([[False], keep_mask, [False]])
|
||||
|
||||
keep_diff = np.diff(keep_padded.astype(int))
|
||||
keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
|
||||
keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
|
||||
|
||||
# Find which steps correspond to non-idle segments of length at least min_non_idle_len
|
||||
true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
|
||||
keep_true_starts = keep_true_starts[true_segment_masks]
|
||||
keep_true_ends = keep_true_ends[true_segment_masks]
|
||||
|
||||
# Add mapping from episode unique ID key to list of non-idle ranges to keep
|
||||
keep_ranges_map[key] = []
|
||||
for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
|
||||
keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
|
||||
|
||||
if ep_idx % 1000 == 0:
|
||||
with Path(keep_ranges_path).open("w") as f:
|
||||
json.dump(keep_ranges_map, f)
|
||||
|
||||
print("Done!")
|
||||
with Path(keep_ranges_path).open("w") as f:
|
||||
json.dump(keep_ranges_map, f)
|
||||
@@ -0,0 +1,477 @@
|
||||
"""
|
||||
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
|
||||
|
||||
Usage:
|
||||
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
|
||||
|
||||
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
||||
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
||||
|
||||
The resulting dataset will get saved to the $LEROBOT_HOME directory.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
import copy
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
import h5py
|
||||
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
import tyro
|
||||
|
||||
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
|
||||
|
||||
|
||||
def resize_image(image, size):
|
||||
image = Image.fromarray(image)
|
||||
return np.array(image.resize(size, resample=Image.BICUBIC))
|
||||
|
||||
|
||||
def main(data_dir: str, *, push_to_hub: bool = False):
|
||||
# Clean up any existing dataset in the output directory
|
||||
output_path = HF_LEROBOT_HOME / REPO_NAME
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
data_dir = Path(data_dir)
|
||||
|
||||
# Create LeRobot dataset, define features to store
|
||||
# We will follow the DROID data naming conventions here.
|
||||
# LeRobot assumes that dtype of image data is `image`
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_NAME,
|
||||
robot_type="panda",
|
||||
fps=15, # DROID data is typically recorded at 15fps
|
||||
features={
|
||||
# We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
|
||||
"exterior_image_1_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"exterior_image_2_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"wrist_image_left": {
|
||||
"dtype": "image",
|
||||
"shape": (180, 320, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"joint_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": ["joint_position"],
|
||||
},
|
||||
"gripper_position": {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": ["gripper_position"],
|
||||
},
|
||||
"actions": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
|
||||
"names": ["actions"],
|
||||
},
|
||||
},
|
||||
image_writer_threads=10,
|
||||
image_writer_processes=5,
|
||||
)
|
||||
|
||||
# Load language annotations
|
||||
# Note: we load the DROID language annotations for this example, but you can manually define them for your own data
|
||||
with (data_dir / "aggregated-annotations-030724.json").open() as f:
|
||||
language_annotations = json.load(f)
|
||||
|
||||
# Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
|
||||
# We assume the following directory structure:
|
||||
# RAW_DROID_PATH/
|
||||
# - <...>/
|
||||
# - recordings/
|
||||
# - MP4/
|
||||
# - <camera_id>.mp4 # single-view video of left stereo pair camera
|
||||
# - trajectory.hdf5
|
||||
# - <...>/
|
||||
episode_paths = list(data_dir.glob("**/trajectory.h5"))
|
||||
print(f"Found {len(episode_paths)} episodes for conversion")
|
||||
|
||||
# We will loop over each dataset_name and write episodes to the LeRobot dataset
|
||||
for episode_path in tqdm(episode_paths, desc="Converting episodes"):
|
||||
# Load raw data
|
||||
recording_folderpath = episode_path.parent / "recordings" / "MP4"
|
||||
trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
|
||||
|
||||
# To load the language instruction, we need to parse out the episode_id from the metadata file
|
||||
# Again, you can modify this step for your own data, to load your own language instructions
|
||||
metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
|
||||
episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
|
||||
language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
|
||||
"language_instruction1"
|
||||
]
|
||||
print(f"Converting episode with language instruction: {language_instruction}")
|
||||
|
||||
# Write to LeRobot dataset
|
||||
for step in trajectory:
|
||||
camera_type_dict = step["observation"]["camera_type"]
|
||||
wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
|
||||
exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
|
||||
dataset.add_frame(
|
||||
{
|
||||
# Note: need to flip BGR --> RGB for loaded images
|
||||
"exterior_image_1_left": resize_image(
|
||||
step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
|
||||
),
|
||||
"exterior_image_2_left": resize_image(
|
||||
step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
|
||||
),
|
||||
"wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
|
||||
"joint_position": np.asarray(
|
||||
step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
|
||||
),
|
||||
"gripper_position": np.asarray(
|
||||
step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
|
||||
),
|
||||
# Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
|
||||
"actions": np.concatenate(
|
||||
[step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
|
||||
),
|
||||
"task": language_instruction,
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Optionally push to the Hugging Face Hub
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(
|
||||
tags=["libero", "panda", "rlds"],
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
)
|
||||
|
||||
|
||||
##########################################################################################################
|
||||
################ The rest of this file are functions to parse the raw DROID data #########################
|
||||
################ You don't need to worry about understanding this part #########################
|
||||
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
|
||||
##########################################################################################################
|
||||
|
||||
|
||||
camera_type_dict = {
|
||||
"hand_camera_id": 0,
|
||||
"varied_camera_1_id": 1,
|
||||
"varied_camera_2_id": 1,
|
||||
}
|
||||
|
||||
camera_type_to_string_dict = {
|
||||
0: "hand_camera",
|
||||
1: "varied_camera",
|
||||
2: "fixed_camera",
|
||||
}
|
||||
|
||||
|
||||
def get_camera_type(cam_id):
|
||||
if cam_id not in camera_type_dict:
|
||||
return None
|
||||
type_int = camera_type_dict[cam_id]
|
||||
return camera_type_to_string_dict[type_int]
|
||||
|
||||
|
||||
class MP4Reader:
|
||||
def __init__(self, filepath, serial_number):
|
||||
# Save Parameters #
|
||||
self.serial_number = serial_number
|
||||
self._index = 0
|
||||
|
||||
# Open Video Reader #
|
||||
self._mp4_reader = cv2.VideoCapture(filepath)
|
||||
if not self._mp4_reader.isOpened():
|
||||
raise RuntimeError("Corrupted MP4 File")
|
||||
|
||||
def set_reading_parameters(
|
||||
self,
|
||||
image=True, # noqa: FBT002
|
||||
concatenate_images=False, # noqa: FBT002
|
||||
resolution=(0, 0),
|
||||
resize_func=None,
|
||||
):
|
||||
# Save Parameters #
|
||||
self.image = image
|
||||
self.concatenate_images = concatenate_images
|
||||
self.resolution = resolution
|
||||
self.resize_func = cv2.resize
|
||||
self.skip_reading = not image
|
||||
if self.skip_reading:
|
||||
return
|
||||
|
||||
def get_frame_resolution(self):
|
||||
width = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_WIDTH)
|
||||
height = self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT)
|
||||
return (width, height)
|
||||
|
||||
def get_frame_count(self):
|
||||
if self.skip_reading:
|
||||
return 0
|
||||
return int(self._mp4_reader.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT))
|
||||
|
||||
def set_frame_index(self, index):
|
||||
if self.skip_reading:
|
||||
return
|
||||
|
||||
if index < self._index:
|
||||
self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
|
||||
self._index = index
|
||||
|
||||
while self._index < index:
|
||||
self.read_camera(ignore_data=True)
|
||||
|
||||
def _process_frame(self, frame):
|
||||
frame = copy.deepcopy(frame)
|
||||
if self.resolution == (0, 0):
|
||||
return frame
|
||||
return self.resize_func(frame, self.resolution)
|
||||
|
||||
def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
|
||||
# Skip if Read Unnecesary #
|
||||
if self.skip_reading:
|
||||
return {}
|
||||
|
||||
# Read Camera #
|
||||
success, frame = self._mp4_reader.read()
|
||||
|
||||
self._index += 1
|
||||
if not success:
|
||||
return None
|
||||
if ignore_data:
|
||||
return None
|
||||
|
||||
# Return Data #
|
||||
data_dict = {}
|
||||
|
||||
if self.concatenate_images or "stereo" not in self.serial_number:
|
||||
data_dict["image"] = {self.serial_number: self._process_frame(frame)}
|
||||
else:
|
||||
single_width = frame.shape[1] // 2
|
||||
data_dict["image"] = {
|
||||
self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
|
||||
self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
|
||||
}
|
||||
|
||||
return data_dict
|
||||
|
||||
def disable_camera(self):
|
||||
if hasattr(self, "_mp4_reader"):
|
||||
self._mp4_reader.release()
|
||||
|
||||
|
||||
class RecordedMultiCameraWrapper:
|
||||
def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
|
||||
# Save Camera Info #
|
||||
self.camera_kwargs = camera_kwargs
|
||||
|
||||
# Open Camera Readers #
|
||||
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
|
||||
all_filepaths = mp4_filepaths
|
||||
|
||||
self.camera_dict = {}
|
||||
for f in all_filepaths:
|
||||
serial_number = f.split("/")[-1][:-4]
|
||||
cam_type = get_camera_type(serial_number)
|
||||
camera_kwargs.get(cam_type, {})
|
||||
|
||||
if f.endswith(".mp4"):
|
||||
Reader = MP4Reader # noqa: N806
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
self.camera_dict[serial_number] = Reader(f, serial_number)
|
||||
|
||||
def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
|
||||
full_obs_dict = defaultdict(dict)
|
||||
|
||||
# Read Cameras In Randomized Order #
|
||||
all_cam_ids = list(self.camera_dict.keys())
|
||||
# random.shuffle(all_cam_ids)
|
||||
|
||||
for cam_id in all_cam_ids:
|
||||
if "stereo" in cam_id:
|
||||
continue
|
||||
try:
|
||||
cam_type = camera_type_dict[cam_id]
|
||||
except KeyError:
|
||||
print(f"{self.camera_dict} -- {camera_type_dict}")
|
||||
raise ValueError(f"Camera type {cam_id} not found in camera_type_dict") # noqa: B904
|
||||
curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
|
||||
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
|
||||
|
||||
timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
|
||||
if index is not None:
|
||||
self.camera_dict[cam_id].set_frame_index(index)
|
||||
|
||||
data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
|
||||
|
||||
# Process Returned Data #
|
||||
if data_dict is None:
|
||||
return None
|
||||
for key in data_dict:
|
||||
full_obs_dict[key].update(data_dict[key])
|
||||
|
||||
return full_obs_dict
|
||||
|
||||
|
||||
def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
|
||||
length = None
|
||||
|
||||
for key in hdf5_file:
|
||||
if key in keys_to_ignore:
|
||||
continue
|
||||
|
||||
curr_data = hdf5_file[key]
|
||||
if isinstance(curr_data, h5py.Group):
|
||||
curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
|
||||
elif isinstance(curr_data, h5py.Dataset):
|
||||
curr_length = len(curr_data)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
if length is None:
|
||||
length = curr_length
|
||||
assert curr_length == length
|
||||
|
||||
return length
|
||||
|
||||
|
||||
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
|
||||
data_dict = {}
|
||||
|
||||
for key in hdf5_file:
|
||||
if key in keys_to_ignore:
|
||||
continue
|
||||
|
||||
curr_data = hdf5_file[key]
|
||||
if isinstance(curr_data, h5py.Group):
|
||||
data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
|
||||
elif isinstance(curr_data, h5py.Dataset):
|
||||
data_dict[key] = curr_data[index]
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
class TrajectoryReader:
|
||||
def __init__(self, filepath, read_images=True): # noqa: FBT002
|
||||
self._hdf5_file = h5py.File(filepath, "r")
|
||||
is_video_folder = "observations/videos" in self._hdf5_file
|
||||
self._read_images = read_images and is_video_folder
|
||||
self._length = get_hdf5_length(self._hdf5_file)
|
||||
self._video_readers = {}
|
||||
self._index = 0
|
||||
|
||||
def length(self):
|
||||
return self._length
|
||||
|
||||
def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
|
||||
# Make Sure We Read Within Range #
|
||||
if index is None:
|
||||
index = self._index
|
||||
else:
|
||||
assert not self._read_images
|
||||
self._index = index
|
||||
assert index < self._length
|
||||
|
||||
# Load Low Dimensional Data #
|
||||
keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
|
||||
timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
|
||||
|
||||
# Increment Read Index #
|
||||
self._index += 1
|
||||
|
||||
# Return Timestep #
|
||||
return timestep
|
||||
|
||||
def close(self):
|
||||
self._hdf5_file.close()
|
||||
|
||||
|
||||
def load_trajectory(
|
||||
filepath=None,
|
||||
read_cameras=True, # noqa: FBT002
|
||||
recording_folderpath=None,
|
||||
camera_kwargs={}, # noqa: B006
|
||||
remove_skipped_steps=False, # noqa: FBT002
|
||||
num_samples_per_traj=None,
|
||||
num_samples_per_traj_coeff=1.5,
|
||||
):
|
||||
read_recording_folderpath = read_cameras and (recording_folderpath is not None)
|
||||
|
||||
traj_reader = TrajectoryReader(filepath)
|
||||
if read_recording_folderpath:
|
||||
camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
|
||||
|
||||
horizon = traj_reader.length()
|
||||
timestep_list = []
|
||||
|
||||
# Choose Timesteps To Save #
|
||||
if num_samples_per_traj:
|
||||
num_to_save = num_samples_per_traj
|
||||
if remove_skipped_steps:
|
||||
num_to_save = int(num_to_save * num_samples_per_traj_coeff)
|
||||
max_size = min(num_to_save, horizon)
|
||||
indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
|
||||
else:
|
||||
indices_to_save = np.arange(horizon)
|
||||
|
||||
# Iterate Over Trajectory #
|
||||
for i in indices_to_save:
|
||||
# Get HDF5 Data #
|
||||
timestep = traj_reader.read_timestep(index=i)
|
||||
|
||||
# If Applicable, Get Recorded Data #
|
||||
if read_recording_folderpath:
|
||||
timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
|
||||
camera_type_dict = {
|
||||
k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
|
||||
}
|
||||
camera_obs = camera_reader.read_cameras(
|
||||
index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
|
||||
)
|
||||
camera_failed = camera_obs is None
|
||||
|
||||
# Add Data To Timestep If Successful #
|
||||
if camera_failed:
|
||||
break
|
||||
timestep["observation"].update(camera_obs)
|
||||
|
||||
# Filter Steps #
|
||||
step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
|
||||
delete_skipped_step = step_skipped and remove_skipped_steps
|
||||
|
||||
# Save Filtered Timesteps #
|
||||
if delete_skipped_step:
|
||||
del timestep
|
||||
else:
|
||||
timestep_list.append(timestep)
|
||||
|
||||
# Remove Extra Transitions #
|
||||
timestep_list = np.array(timestep_list)
|
||||
if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
|
||||
ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
|
||||
timestep_list = timestep_list[ind_to_keep]
|
||||
|
||||
# Close Readers #
|
||||
traj_reader.close()
|
||||
|
||||
# Return Data #
|
||||
return timestep_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
246
policy/openpi-InternData-A1/examples/droid/main.py
Normal file
246
policy/openpi-InternData-A1/examples/droid/main.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# ruff: noqa
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import datetime
|
||||
import faulthandler
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy
|
||||
import pandas as pd
|
||||
from PIL import Image
|
||||
from droid.robot_env import RobotEnv
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
faulthandler.enable()
|
||||
|
||||
# DROID data collection frequency -- we slow down execution to match this frequency
|
||||
DROID_CONTROL_FREQUENCY = 15
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
# Hardware parameters
|
||||
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
|
||||
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
|
||||
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
|
||||
|
||||
# Policy parameters
|
||||
external_camera: str | None = (
|
||||
None # which external camera should be fed to the policy, choose from ["left", "right"]
|
||||
)
|
||||
|
||||
# Rollout parameters
|
||||
max_timesteps: int = 600
|
||||
# How many actions to execute from a predicted action chunk before querying policy server again
|
||||
# 8 is usually a good default (equals 0.5 seconds of action execution).
|
||||
open_loop_horizon: int = 8
|
||||
|
||||
# Remote server parameters
|
||||
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
|
||||
remote_port: int = (
|
||||
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
|
||||
)
|
||||
|
||||
|
||||
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
|
||||
# waiting for a new action chunk, it will raise an exception and the server connection dies.
|
||||
# This context manager temporarily prevents Ctrl+C and delays it after the server call is complete.
|
||||
@contextlib.contextmanager
|
||||
def prevent_keyboard_interrupt():
|
||||
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
|
||||
interrupted = False
|
||||
original_handler = signal.getsignal(signal.SIGINT)
|
||||
|
||||
def handler(signum, frame):
|
||||
nonlocal interrupted
|
||||
interrupted = True
|
||||
|
||||
signal.signal(signal.SIGINT, handler)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_handler)
|
||||
if interrupted:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
def main(args: Args):
|
||||
# Make sure external camera is specified by user -- we only use one external camera for the policy
|
||||
assert (
|
||||
args.external_camera is not None and args.external_camera in ["left", "right"]
|
||||
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
|
||||
|
||||
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
|
||||
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
|
||||
print("Created the droid env!")
|
||||
|
||||
# Connect to the policy server
|
||||
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
|
||||
|
||||
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
|
||||
|
||||
while True:
|
||||
instruction = input("Enter instruction: ")
|
||||
|
||||
# Rollout parameters
|
||||
actions_from_chunk_completed = 0
|
||||
pred_action_chunk = None
|
||||
|
||||
# Prepare to save video of rollout
|
||||
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
|
||||
video = []
|
||||
bar = tqdm.tqdm(range(args.max_timesteps))
|
||||
print("Running rollout... press Ctrl+C to stop early.")
|
||||
for t_step in bar:
|
||||
start_time = time.time()
|
||||
try:
|
||||
# Get the current observation
|
||||
curr_obs = _extract_observation(
|
||||
args,
|
||||
env.get_observation(),
|
||||
# Save the first observation to disk
|
||||
save_to_disk=t_step == 0,
|
||||
)
|
||||
|
||||
video.append(curr_obs[f"{args.external_camera}_image"])
|
||||
|
||||
# Send websocket request to policy server if it's time to predict a new chunk
|
||||
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
|
||||
actions_from_chunk_completed = 0
|
||||
|
||||
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
|
||||
# and improve latency.
|
||||
request_data = {
|
||||
"observation/exterior_image_1_left": image_tools.resize_with_pad(
|
||||
curr_obs[f"{args.external_camera}_image"], 224, 224
|
||||
),
|
||||
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
|
||||
"observation/joint_position": curr_obs["joint_position"],
|
||||
"observation/gripper_position": curr_obs["gripper_position"],
|
||||
"prompt": instruction,
|
||||
}
|
||||
|
||||
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
|
||||
# Ctrl+C will be handled after the server call is complete
|
||||
with prevent_keyboard_interrupt():
|
||||
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
|
||||
pred_action_chunk = policy_client.infer(request_data)["actions"]
|
||||
assert pred_action_chunk.shape == (10, 8)
|
||||
|
||||
# Select current action to execute from chunk
|
||||
action = pred_action_chunk[actions_from_chunk_completed]
|
||||
actions_from_chunk_completed += 1
|
||||
|
||||
# Binarize gripper action
|
||||
if action[-1].item() > 0.5:
|
||||
# action[-1] = 1.0
|
||||
action = np.concatenate([action[:-1], np.ones((1,))])
|
||||
else:
|
||||
# action[-1] = 0.0
|
||||
action = np.concatenate([action[:-1], np.zeros((1,))])
|
||||
|
||||
# clip all dimensions of action to [-1, 1]
|
||||
action = np.clip(action, -1, 1)
|
||||
|
||||
env.step(action)
|
||||
|
||||
# Sleep to match DROID data collection frequency
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
|
||||
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
video = np.stack(video)
|
||||
save_filename = "video_" + timestamp
|
||||
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
|
||||
|
||||
success: str | float | None = None
|
||||
while not isinstance(success, float):
|
||||
success = input(
|
||||
"Did the rollout succeed? (enter y for 100%, n for 0%), or a numeric value 0-100 based on the evaluation spec"
|
||||
)
|
||||
if success == "y":
|
||||
success = 1.0
|
||||
elif success == "n":
|
||||
success = 0.0
|
||||
|
||||
success = float(success) / 100
|
||||
if not (0 <= success <= 1):
|
||||
print(f"Success must be a number in [0, 100] but got: {success * 100}")
|
||||
|
||||
df = df.append(
|
||||
{
|
||||
"success": success,
|
||||
"duration": t_step,
|
||||
"video_filename": save_filename,
|
||||
},
|
||||
ignore_index=True,
|
||||
)
|
||||
|
||||
if input("Do one more eval? (enter y or n) ").lower() != "y":
|
||||
break
|
||||
env.reset()
|
||||
|
||||
os.makedirs("results", exist_ok=True)
|
||||
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
|
||||
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
|
||||
df.to_csv(csv_filename)
|
||||
print(f"Results saved to {csv_filename}")
|
||||
|
||||
|
||||
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
|
||||
image_observations = obs_dict["image"]
|
||||
left_image, right_image, wrist_image = None, None, None
|
||||
for key in image_observations:
|
||||
# Note the "left" below refers to the left camera in the stereo pair.
|
||||
# The model is only trained on left stereo cams, so we only feed those.
|
||||
if args.left_camera_id in key and "left" in key:
|
||||
left_image = image_observations[key]
|
||||
elif args.right_camera_id in key and "left" in key:
|
||||
right_image = image_observations[key]
|
||||
elif args.wrist_camera_id in key and "left" in key:
|
||||
wrist_image = image_observations[key]
|
||||
|
||||
# Drop the alpha dimension
|
||||
left_image = left_image[..., :3]
|
||||
right_image = right_image[..., :3]
|
||||
wrist_image = wrist_image[..., :3]
|
||||
|
||||
# Convert to RGB
|
||||
left_image = left_image[..., ::-1]
|
||||
right_image = right_image[..., ::-1]
|
||||
wrist_image = wrist_image[..., ::-1]
|
||||
|
||||
# In addition to image observations, also capture the proprioceptive state
|
||||
robot_state = obs_dict["robot_state"]
|
||||
cartesian_position = np.array(robot_state["cartesian_position"])
|
||||
joint_position = np.array(robot_state["joint_positions"])
|
||||
gripper_position = np.array([robot_state["gripper_position"]])
|
||||
|
||||
# Save the images to disk so that they can be viewed live while the robot is running
|
||||
# Create one combined image to make live viewing easy
|
||||
if save_to_disk:
|
||||
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
|
||||
combined_image = Image.fromarray(combined_image)
|
||||
combined_image.save("robot_camera_views.png")
|
||||
|
||||
return {
|
||||
"left_image": left_image,
|
||||
"right_image": right_image,
|
||||
"wrist_image": wrist_image,
|
||||
"cartesian_position": cartesian_position,
|
||||
"joint_position": joint_position,
|
||||
"gripper_position": gripper_position,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args: Args = tyro.cli(Args)
|
||||
main(args)
|
||||
137
policy/openpi-InternData-A1/examples/inference.ipynb
Normal file
137
policy/openpi-InternData-A1/examples/inference.ipynb
Normal file
@@ -0,0 +1,137 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import dataclasses\n",
|
||||
"\n",
|
||||
"import jax\n",
|
||||
"\n",
|
||||
"from openpi.models import model as _model\n",
|
||||
"from openpi.policies import droid_policy\n",
|
||||
"from openpi.policies import policy_config as _policy_config\n",
|
||||
"from openpi.shared import download\n",
|
||||
"from openpi.training import config as _config\n",
|
||||
"from openpi.training import data_loader as _data_loader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Policy inference\n",
|
||||
"\n",
|
||||
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_fast_droid\")\n",
|
||||
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
|
||||
"\n",
|
||||
"# Create a trained policy.\n",
|
||||
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
|
||||
"\n",
|
||||
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
|
||||
"example = droid_policy.make_droid_example()\n",
|
||||
"result = policy.infer(example)\n",
|
||||
"\n",
|
||||
"# Delete the policy to free up memory.\n",
|
||||
"del policy\n",
|
||||
"\n",
|
||||
"print(\"Actions shape:\", result[\"actions\"].shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Working with a live model\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = _config.get_config(\"pi0_aloha_sim\")\n",
|
||||
"\n",
|
||||
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
|
||||
"key = jax.random.key(0)\n",
|
||||
"\n",
|
||||
"# Create a model from the checkpoint.\n",
|
||||
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
|
||||
"\n",
|
||||
"# We can create fake observations and actions to test the model.\n",
|
||||
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Reduce the batch size to reduce memory usage.\n",
|
||||
"config = dataclasses.replace(config, batch_size=2)\n",
|
||||
"\n",
|
||||
"# Load a single batch of data. This is the same data that will be used during training.\n",
|
||||
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
|
||||
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
|
||||
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
|
||||
"obs, act = next(iter(loader))\n",
|
||||
"\n",
|
||||
"# Sample actions from the model.\n",
|
||||
"loss = model.compute_loss(key, obs, act)\n",
|
||||
"\n",
|
||||
"# Delete the model to free up memory.\n",
|
||||
"del model\n",
|
||||
"\n",
|
||||
"print(\"Loss shape:\", loss.shape)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
59
policy/openpi-InternData-A1/examples/libero/Dockerfile
Normal file
59
policy/openpi-InternData-A1/examples/libero/Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# Dockerfile for the LIBERO benchmark.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t libero -f examples/libero/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y \
|
||||
make \
|
||||
g++ \
|
||||
clang \
|
||||
libosmesa6-dev \
|
||||
libgl1-mesa-glx \
|
||||
libglew-dev \
|
||||
libglfw3-dev \
|
||||
libgles2-mesa-dev \
|
||||
libglib2.0-0 \
|
||||
libsm6 \
|
||||
libxrender1 \
|
||||
libxext6
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
|
||||
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
|
||||
|
||||
# Create a default config file to avoid an input prompt from LIBERO's init script.
|
||||
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
|
||||
ENV LIBERO_CONFIG_PATH=/tmp/libero
|
||||
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
|
||||
benchmark_root: /app/third_party/libero/libero/libero
|
||||
bddl_files: /app/third_party/libero/libero/libero/bddl_files
|
||||
init_states: /app/third_party/libero/libero/libero/init_files
|
||||
datasets: /app/third_party/libero/libero/datasets
|
||||
assets: /app/third_party/libero/libero/libero/assets
|
||||
EOF
|
||||
|
||||
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]
|
||||
71
policy/openpi-InternData-A1/examples/libero/README.md
Normal file
71
policy/openpi-InternData-A1/examples/libero/README.md
Normal file
@@ -0,0 +1,71 @@
|
||||
# LIBERO Benchmark
|
||||
|
||||
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
|
||||
|
||||
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
|
||||
|
||||
This example requires git submodules to be initialized. Don't forget to run:
|
||||
|
||||
```bash
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
## With Docker (recommended)
|
||||
|
||||
```bash
|
||||
# Grant access to the X11 server:
|
||||
sudo xhost +local:docker
|
||||
|
||||
# To run with the default checkpoint and task suite:
|
||||
SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
|
||||
|
||||
# To run with glx for Mujoco instead (use this if you have egl errors):
|
||||
MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
|
||||
```
|
||||
|
||||
You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
|
||||
For example:
|
||||
|
||||
```bash
|
||||
# To load a custom checkpoint (located in the top-level openpi/ directory):
|
||||
export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
|
||||
|
||||
# To run the libero_10 task suite:
|
||||
export CLIENT_ARGS="--args.task-suite-name libero_10"
|
||||
```
|
||||
|
||||
## Without Docker (not recommended)
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
# Create virtual environment
|
||||
uv venv --python 3.8 examples/libero/.venv
|
||||
source examples/libero/.venv/bin/activate
|
||||
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
|
||||
uv pip install -e packages/openpi-client
|
||||
uv pip install -e third_party/libero
|
||||
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
|
||||
|
||||
# Run the simulation
|
||||
python examples/libero/main.py
|
||||
|
||||
# To run with glx for Mujoco instead (use this if you have egl errors):
|
||||
MUJOCO_GL=glx python examples/libero/main.py
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
# Run the server
|
||||
uv run scripts/serve_policy.py --env LIBERO
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
|
||||
checkpoint was trained in openpi with the `pi05_libero` config.
|
||||
|
||||
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|
||||
|-------|---------------|---------------|-------------|-----------|---------|
|
||||
| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85
|
||||
54
policy/openpi-InternData-A1/examples/libero/compose.yml
Normal file
54
policy/openpi-InternData-A1/examples/libero/compose.yml
Normal file
@@ -0,0 +1,54 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/libero/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: libero
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/libero/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
privileged: true
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ../../data:/data
|
||||
- /tmp/.X11-unix:/tmp/.X11-unix:ro
|
||||
environment:
|
||||
- CLIENT_ARGS
|
||||
- DISPLAY=$DISPLAY
|
||||
- MUJOCO_GL=${MUJOCO_GL:-egl}
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
Minimal example script for converting a dataset to LeRobot format.
|
||||
|
||||
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
|
||||
modified for any other data you have saved in a custom format.
|
||||
|
||||
Usage:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
|
||||
|
||||
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
|
||||
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
|
||||
|
||||
Note: to run the script, you need to install tensorflow_datasets:
|
||||
`uv pip install tensorflow tensorflow_datasets`
|
||||
|
||||
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
|
||||
The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
|
||||
Running this conversion script will take approximately 30 minutes.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
import tensorflow_datasets as tfds
|
||||
import tyro
|
||||
|
||||
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
|
||||
RAW_DATASET_NAMES = [
|
||||
"libero_10_no_noops",
|
||||
"libero_goal_no_noops",
|
||||
"libero_object_no_noops",
|
||||
"libero_spatial_no_noops",
|
||||
] # For simplicity we will combine multiple Libero datasets into one training dataset
|
||||
|
||||
|
||||
def main(data_dir: str, *, push_to_hub: bool = False):
|
||||
# Clean up any existing dataset in the output directory
|
||||
output_path = HF_LEROBOT_HOME / REPO_NAME
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
# Create LeRobot dataset, define features to store
|
||||
# OpenPi assumes that proprio is stored in `state` and actions in `action`
|
||||
# LeRobot assumes that dtype of image data is `image`
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=REPO_NAME,
|
||||
robot_type="panda",
|
||||
fps=10,
|
||||
features={
|
||||
"image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"wrist_image": {
|
||||
"dtype": "image",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "channel"],
|
||||
},
|
||||
"state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": ["state"],
|
||||
},
|
||||
"actions": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": ["actions"],
|
||||
},
|
||||
},
|
||||
image_writer_threads=10,
|
||||
image_writer_processes=5,
|
||||
)
|
||||
|
||||
# Loop over raw Libero datasets and write episodes to the LeRobot dataset
|
||||
# You can modify this for your own data format
|
||||
for raw_dataset_name in RAW_DATASET_NAMES:
|
||||
raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
|
||||
for episode in raw_dataset:
|
||||
for step in episode["steps"].as_numpy_iterator():
|
||||
dataset.add_frame(
|
||||
{
|
||||
"image": step["observation"]["image"],
|
||||
"wrist_image": step["observation"]["wrist_image"],
|
||||
"state": step["observation"]["state"],
|
||||
"actions": step["action"],
|
||||
"task": step["language_instruction"].decode(),
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
|
||||
# Optionally push to the Hugging Face Hub
|
||||
if push_to_hub:
|
||||
dataset.push_to_hub(
|
||||
tags=["libero", "panda", "rlds"],
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tyro.cli(main)
|
||||
219
policy/openpi-InternData-A1/examples/libero/main.py
Normal file
219
policy/openpi-InternData-A1/examples/libero/main.py
Normal file
@@ -0,0 +1,219 @@
|
||||
import collections
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import pathlib
|
||||
|
||||
import imageio
|
||||
from libero.libero import benchmark
|
||||
from libero.libero import get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
import numpy as np
|
||||
from openpi_client import image_tools
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
|
||||
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
#################################################################################################################
|
||||
# Model server parameters
|
||||
#################################################################################################################
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
resize_size: int = 224
|
||||
replan_steps: int = 5
|
||||
|
||||
#################################################################################################################
|
||||
# LIBERO environment-specific parameters
|
||||
#################################################################################################################
|
||||
task_suite_name: str = (
|
||||
"libero_spatial" # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
|
||||
)
|
||||
num_steps_wait: int = 10 # Number of steps to wait for objects to stabilize i n sim
|
||||
num_trials_per_task: int = 50 # Number of rollouts per task
|
||||
|
||||
#################################################################################################################
|
||||
# Utils
|
||||
#################################################################################################################
|
||||
video_out_path: str = "data/libero/videos" # Path to save videos
|
||||
|
||||
seed: int = 7 # Random Seed (for reproducibility)
|
||||
|
||||
|
||||
def eval_libero(args: Args) -> None:
|
||||
# Set random seed
|
||||
np.random.seed(args.seed)
|
||||
|
||||
# Initialize LIBERO task suite
|
||||
benchmark_dict = benchmark.get_benchmark_dict()
|
||||
task_suite = benchmark_dict[args.task_suite_name]()
|
||||
num_tasks_in_suite = task_suite.n_tasks
|
||||
logging.info(f"Task suite: {args.task_suite_name}")
|
||||
|
||||
pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if args.task_suite_name == "libero_spatial":
|
||||
max_steps = 220 # longest training demo has 193 steps
|
||||
elif args.task_suite_name == "libero_object":
|
||||
max_steps = 280 # longest training demo has 254 steps
|
||||
elif args.task_suite_name == "libero_goal":
|
||||
max_steps = 300 # longest training demo has 270 steps
|
||||
elif args.task_suite_name == "libero_10":
|
||||
max_steps = 520 # longest training demo has 505 steps
|
||||
elif args.task_suite_name == "libero_90":
|
||||
max_steps = 400 # longest training demo has 373 steps
|
||||
else:
|
||||
raise ValueError(f"Unknown task suite: {args.task_suite_name}")
|
||||
|
||||
client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)
|
||||
|
||||
# Start evaluation
|
||||
total_episodes, total_successes = 0, 0
|
||||
for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
|
||||
# Get task
|
||||
task = task_suite.get_task(task_id)
|
||||
|
||||
# Get default LIBERO initial states
|
||||
initial_states = task_suite.get_task_init_states(task_id)
|
||||
|
||||
# Initialize LIBERO environment and task description
|
||||
env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)
|
||||
|
||||
# Start episodes
|
||||
task_episodes, task_successes = 0, 0
|
||||
for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
|
||||
logging.info(f"\nTask: {task_description}")
|
||||
|
||||
# Reset environment
|
||||
env.reset()
|
||||
action_plan = collections.deque()
|
||||
|
||||
# Set initial states
|
||||
obs = env.set_init_state(initial_states[episode_idx])
|
||||
|
||||
# Setup
|
||||
t = 0
|
||||
replay_images = []
|
||||
|
||||
logging.info(f"Starting episode {task_episodes+1}...")
|
||||
while t < max_steps + args.num_steps_wait:
|
||||
try:
|
||||
# IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
|
||||
# and we need to wait for them to fall
|
||||
if t < args.num_steps_wait:
|
||||
obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
|
||||
t += 1
|
||||
continue
|
||||
|
||||
# Get preprocessed image
|
||||
# IMPORTANT: rotate 180 degrees to match train preprocessing
|
||||
img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
|
||||
wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
|
||||
img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
|
||||
)
|
||||
wrist_img = image_tools.convert_to_uint8(
|
||||
image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
|
||||
)
|
||||
|
||||
# Save preprocessed image for replay video
|
||||
replay_images.append(img)
|
||||
|
||||
if not action_plan:
|
||||
# Finished executing previous action chunk -- compute new chunk
|
||||
# Prepare observations dict
|
||||
element = {
|
||||
"observation/image": img,
|
||||
"observation/wrist_image": wrist_img,
|
||||
"observation/state": np.concatenate(
|
||||
(
|
||||
obs["robot0_eef_pos"],
|
||||
_quat2axisangle(obs["robot0_eef_quat"]),
|
||||
obs["robot0_gripper_qpos"],
|
||||
)
|
||||
),
|
||||
"prompt": str(task_description),
|
||||
}
|
||||
|
||||
# Query model to get action
|
||||
action_chunk = client.infer(element)["actions"]
|
||||
assert (
|
||||
len(action_chunk) >= args.replan_steps
|
||||
), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
|
||||
action_plan.extend(action_chunk[: args.replan_steps])
|
||||
|
||||
action = action_plan.popleft()
|
||||
|
||||
# Execute action in environment
|
||||
obs, reward, done, info = env.step(action.tolist())
|
||||
if done:
|
||||
task_successes += 1
|
||||
total_successes += 1
|
||||
break
|
||||
t += 1
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Caught exception: {e}")
|
||||
break
|
||||
|
||||
task_episodes += 1
|
||||
total_episodes += 1
|
||||
|
||||
# Save a replay video of the episode
|
||||
suffix = "success" if done else "failure"
|
||||
task_segment = task_description.replace(" ", "_")
|
||||
imageio.mimwrite(
|
||||
pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
|
||||
[np.asarray(x) for x in replay_images],
|
||||
fps=10,
|
||||
)
|
||||
|
||||
# Log current results
|
||||
logging.info(f"Success: {done}")
|
||||
logging.info(f"# episodes completed so far: {total_episodes}")
|
||||
logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")
|
||||
|
||||
# Log final results
|
||||
logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
|
||||
logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
|
||||
logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
|
||||
logging.info(f"Total episodes: {total_episodes}")
|
||||
|
||||
|
||||
def _get_libero_env(task, resolution, seed):
|
||||
"""Initializes and returns the LIBERO environment, along with the task description."""
|
||||
task_description = task.language
|
||||
task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
|
||||
env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
|
||||
env = OffScreenRenderEnv(**env_args)
|
||||
env.seed(seed) # IMPORTANT: seed seems to affect object positions even when using fixed initial state
|
||||
return env, task_description
|
||||
|
||||
|
||||
def _quat2axisangle(quat):
|
||||
"""
|
||||
Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
tyro.cli(eval_libero)
|
||||
11
policy/openpi-InternData-A1/examples/libero/requirements.in
Normal file
11
policy/openpi-InternData-A1/examples/libero/requirements.in
Normal file
@@ -0,0 +1,11 @@
|
||||
imageio[ffmpeg]
|
||||
numpy==1.22.4
|
||||
tqdm
|
||||
tyro
|
||||
PyYaml
|
||||
opencv-python==4.6.0.66
|
||||
torch==1.11.0+cu113
|
||||
torchvision==0.12.0+cu113
|
||||
torchaudio==0.11.0+cu113
|
||||
robosuite==1.4.1
|
||||
matplotlib==3.5.3
|
||||
136
policy/openpi-InternData-A1/examples/libero/requirements.txt
Normal file
136
policy/openpi-InternData-A1/examples/libero/requirements.txt
Normal file
@@ -0,0 +1,136 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
|
||||
absl-py==2.1.0
|
||||
# via mujoco
|
||||
certifi==2024.12.14
|
||||
# via requests
|
||||
charset-normalizer==3.4.0
|
||||
# via requests
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
etils==1.3.0
|
||||
# via mujoco
|
||||
eval-type-backport==0.2.0
|
||||
# via tyro
|
||||
evdev==1.7.1
|
||||
# via pynput
|
||||
fonttools==4.55.3
|
||||
# via matplotlib
|
||||
glfw==1.12.0
|
||||
# via mujoco
|
||||
idna==3.10
|
||||
# via requests
|
||||
imageio==2.35.1
|
||||
# via -r examples/libero/requirements.in
|
||||
imageio-ffmpeg==0.5.1
|
||||
# via imageio
|
||||
importlib-metadata==8.5.0
|
||||
# via typeguard
|
||||
importlib-resources==6.4.5
|
||||
# via etils
|
||||
kiwisolver==1.4.7
|
||||
# via matplotlib
|
||||
llvmlite==0.36.0
|
||||
# via numba
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
matplotlib==3.5.3
|
||||
# via -r examples/libero/requirements.in
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mujoco==3.2.3
|
||||
# via robosuite
|
||||
numba==0.53.1
|
||||
# via robosuite
|
||||
numpy==1.22.4
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# imageio
|
||||
# matplotlib
|
||||
# mujoco
|
||||
# numba
|
||||
# opencv-python
|
||||
# robosuite
|
||||
# scipy
|
||||
# torchvision
|
||||
opencv-python==4.6.0.66
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# robosuite
|
||||
packaging==24.2
|
||||
# via matplotlib
|
||||
pillow==10.4.0
|
||||
# via
|
||||
# imageio
|
||||
# matplotlib
|
||||
# robosuite
|
||||
# torchvision
|
||||
psutil==6.1.0
|
||||
# via imageio
|
||||
pygments==2.18.0
|
||||
# via rich
|
||||
pynput==1.7.7
|
||||
# via robosuite
|
||||
pyopengl==3.1.7
|
||||
# via mujoco
|
||||
pyparsing==3.1.4
|
||||
# via matplotlib
|
||||
python-dateutil==2.9.0.post0
|
||||
# via matplotlib
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pyyaml==6.0.2
|
||||
# via -r examples/libero/requirements.in
|
||||
requests==2.32.3
|
||||
# via torchvision
|
||||
rich==13.9.4
|
||||
# via tyro
|
||||
robosuite==1.4.1
|
||||
# via -r examples/libero/requirements.in
|
||||
scipy==1.10.1
|
||||
# via robosuite
|
||||
setuptools==75.3.0
|
||||
# via
|
||||
# imageio-ffmpeg
|
||||
# numba
|
||||
shtab==1.7.1
|
||||
# via tyro
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
termcolor==2.4.0
|
||||
# via robosuite
|
||||
torch==1.11.0+cu113
|
||||
# via
|
||||
# -r examples/libero/requirements.in
|
||||
# torchaudio
|
||||
# torchvision
|
||||
torchaudio==0.11.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
torchvision==0.12.0+cu113
|
||||
# via -r examples/libero/requirements.in
|
||||
tqdm==4.67.1
|
||||
# via -r examples/libero/requirements.in
|
||||
typeguard==4.4.0
|
||||
# via tyro
|
||||
typing-extensions==4.12.2
|
||||
# via
|
||||
# etils
|
||||
# rich
|
||||
# torch
|
||||
# torchvision
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.2
|
||||
# via -r examples/libero/requirements.in
|
||||
urllib3==2.2.3
|
||||
# via requests
|
||||
zipp==3.20.2
|
||||
# via
|
||||
# etils
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
134
policy/openpi-InternData-A1/examples/policy_records.ipynb
Normal file
134
policy/openpi-InternData-A1/examples/policy_records.ipynb
Normal file
@@ -0,0 +1,134 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pathlib\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"record_path = pathlib.Path(\"../policy_records\")\n",
|
||||
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
|
||||
"\n",
|
||||
"records = []\n",
|
||||
"for i in range(num_steps):\n",
|
||||
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
|
||||
" records.append(record)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"length of records\", len(records))\n",
|
||||
"print(\"keys in records\", records[0].keys())\n",
|
||||
"\n",
|
||||
"for k in records[0]:\n",
|
||||
" print(f\"{k} shape: {records[0][k].shape}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_image(step: int, idx: int = 0):\n",
|
||||
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
|
||||
" return img[idx].transpose(1, 2, 0)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def show_image(step: int, idx_lst: list[int]):\n",
|
||||
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
|
||||
" return Image.fromarray(np.hstack(imgs))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for i in range(2):\n",
|
||||
" display(show_image(i, [0]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def get_axis(name, axis):\n",
|
||||
" return np.array([record[name][axis] for record in records])\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# qpos is [..., 14] of type float:\n",
|
||||
"# 0-5: left arm joint angles\n",
|
||||
"# 6: left arm gripper\n",
|
||||
"# 7-12: right arm joint angles\n",
|
||||
"# 13: right arm gripper\n",
|
||||
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_data():\n",
|
||||
" cur_dim = 0\n",
|
||||
" in_data = {}\n",
|
||||
" out_data = {}\n",
|
||||
" for name, dim_size in names:\n",
|
||||
" for i in range(dim_size):\n",
|
||||
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
|
||||
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
|
||||
" cur_dim += 1\n",
|
||||
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"in_data, out_data = make_data()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for name in in_data.columns:\n",
|
||||
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
|
||||
" data.plot()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
# Dockerfile for the simple client.
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t simple_client -f examples/simple_client/Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
|
||||
|
||||
FROM python:3.7-slim
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Copy the requirements files so we can install dependencies.
|
||||
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
|
||||
# This strategy is best for development-style usage.
|
||||
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
|
||||
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
|
||||
|
||||
# Install python dependencies.
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
|
||||
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
|
||||
|
||||
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"
|
||||
30
policy/openpi-InternData-A1/examples/simple_client/README.md
Normal file
30
policy/openpi-InternData-A1/examples/simple_client/README.md
Normal file
@@ -0,0 +1,30 @@
|
||||
# Simple Client
|
||||
|
||||
A minimal client that sends observations to the server and prints the inference rate.
|
||||
|
||||
You can specify which runtime environment to use using the `--env` flag. You can see the available options by running:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --help
|
||||
```
|
||||
|
||||
## With Docker
|
||||
|
||||
```bash
|
||||
export SERVER_ARGS="--env ALOHA_SIM"
|
||||
docker compose -f examples/simple_client/compose.yml up --build
|
||||
```
|
||||
|
||||
## Without Docker
|
||||
|
||||
Terminal window 1:
|
||||
|
||||
```bash
|
||||
uv run examples/simple_client/main.py --env DROID
|
||||
```
|
||||
|
||||
Terminal window 2:
|
||||
|
||||
```bash
|
||||
uv run scripts/serve_policy.py --env DROID
|
||||
```
|
||||
@@ -0,0 +1,42 @@
|
||||
# Run with:
|
||||
# docker compose -f examples/simple_client/compose.yml up --build
|
||||
services:
|
||||
runtime:
|
||||
image: simple_client
|
||||
depends_on:
|
||||
- openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/simple_client/Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
187
policy/openpi-InternData-A1/examples/simple_client/main.py
Normal file
187
policy/openpi-InternData-A1/examples/simple_client/main.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from openpi_client import websocket_client_policy as _websocket_client_policy
|
||||
import polars as pl
|
||||
import rich
|
||||
import tqdm
|
||||
import tyro
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
"""Command line arguments."""
|
||||
|
||||
# Host and port to connect to the server.
|
||||
host: str = "0.0.0.0"
|
||||
# Port to connect to the server. If None, the server will use the default port.
|
||||
port: int | None = 8000
|
||||
# API key to use for the server.
|
||||
api_key: str | None = None
|
||||
# Number of steps to run the policy for.
|
||||
num_steps: int = 20
|
||||
# Path to save the timings to a parquet file. (e.g., timing.parquet)
|
||||
timing_file: pathlib.Path | None = None
|
||||
# Environment to run the policy in.
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
|
||||
|
||||
class TimingRecorder:
|
||||
"""Records timing measurements for different keys."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._timings: dict[str, list[float]] = {}
|
||||
|
||||
def record(self, key: str, time_ms: float) -> None:
|
||||
"""Record a timing measurement for the given key."""
|
||||
if key not in self._timings:
|
||||
self._timings[key] = []
|
||||
self._timings[key].append(time_ms)
|
||||
|
||||
def get_stats(self, key: str) -> dict[str, float]:
|
||||
"""Get statistics for the given key."""
|
||||
times = self._timings[key]
|
||||
return {
|
||||
"mean": float(np.mean(times)),
|
||||
"std": float(np.std(times)),
|
||||
"p25": float(np.quantile(times, 0.25)),
|
||||
"p50": float(np.quantile(times, 0.50)),
|
||||
"p75": float(np.quantile(times, 0.75)),
|
||||
"p90": float(np.quantile(times, 0.90)),
|
||||
"p95": float(np.quantile(times, 0.95)),
|
||||
"p99": float(np.quantile(times, 0.99)),
|
||||
}
|
||||
|
||||
def print_all_stats(self) -> None:
|
||||
"""Print statistics for all keys in a concise format."""
|
||||
|
||||
table = rich.table.Table(
|
||||
title="[bold blue]Timing Statistics[/bold blue]",
|
||||
show_header=True,
|
||||
header_style="bold white",
|
||||
border_style="blue",
|
||||
title_justify="center",
|
||||
)
|
||||
|
||||
# Add metric column with custom styling
|
||||
table.add_column("Metric", style="cyan", justify="left", no_wrap=True)
|
||||
|
||||
# Add statistical columns with consistent styling
|
||||
stat_columns = [
|
||||
("Mean", "yellow", "mean"),
|
||||
("Std", "yellow", "std"),
|
||||
("P25", "magenta", "p25"),
|
||||
("P50", "magenta", "p50"),
|
||||
("P75", "magenta", "p75"),
|
||||
("P90", "magenta", "p90"),
|
||||
("P95", "magenta", "p95"),
|
||||
("P99", "magenta", "p99"),
|
||||
]
|
||||
|
||||
for name, style, _ in stat_columns:
|
||||
table.add_column(name, justify="right", style=style, no_wrap=True)
|
||||
|
||||
# Add rows for each metric with formatted values
|
||||
for key in sorted(self._timings.keys()):
|
||||
stats = self.get_stats(key)
|
||||
values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
|
||||
table.add_row(key, *values)
|
||||
|
||||
# Print with custom console settings
|
||||
console = rich.console.Console(width=None, highlight=True)
|
||||
console.print(table)
|
||||
|
||||
def write_parquet(self, path: pathlib.Path) -> None:
|
||||
"""Save the timings to a parquet file."""
|
||||
logger.info(f"Writing timings to {path}")
|
||||
frame = pl.DataFrame(self._timings)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
frame.write_parquet(path)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
obs_fn = {
|
||||
EnvMode.ALOHA: _random_observation_aloha,
|
||||
EnvMode.ALOHA_SIM: _random_observation_aloha,
|
||||
EnvMode.DROID: _random_observation_droid,
|
||||
EnvMode.LIBERO: _random_observation_libero,
|
||||
}[args.env]
|
||||
|
||||
policy = _websocket_client_policy.WebsocketClientPolicy(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
api_key=args.api_key,
|
||||
)
|
||||
logger.info(f"Server metadata: {policy.get_server_metadata()}")
|
||||
|
||||
# Send a few observations to make sure the model is loaded.
|
||||
for _ in range(2):
|
||||
policy.infer(obs_fn())
|
||||
|
||||
timing_recorder = TimingRecorder()
|
||||
|
||||
for _ in tqdm.trange(args.num_steps, desc="Running policy"):
|
||||
inference_start = time.time()
|
||||
action = policy.infer(obs_fn())
|
||||
timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
|
||||
for key, value in action.get("server_timing", {}).items():
|
||||
timing_recorder.record(f"server_{key}", value)
|
||||
for key, value in action.get("policy_timing", {}).items():
|
||||
timing_recorder.record(f"policy_{key}", value)
|
||||
|
||||
timing_recorder.print_all_stats()
|
||||
|
||||
if args.timing_file is not None:
|
||||
timing_recorder.write_parquet(args.timing_file)
|
||||
|
||||
|
||||
def _random_observation_aloha() -> dict:
|
||||
return {
|
||||
"state": np.ones((14,)),
|
||||
"images": {
|
||||
"cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
"cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
|
||||
},
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_droid() -> dict:
|
||||
return {
|
||||
"observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/joint_position": np.random.rand(7),
|
||||
"observation/gripper_position": np.random.rand(1),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
def _random_observation_libero() -> dict:
|
||||
return {
|
||||
"observation/state": np.random.rand(8),
|
||||
"observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
|
||||
"prompt": "do something",
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main(tyro.cli(Args))
|
||||
@@ -0,0 +1,5 @@
|
||||
numpy>=1.22.4,<2.0.0
|
||||
rich
|
||||
tqdm
|
||||
tyro
|
||||
polars
|
||||
@@ -0,0 +1,30 @@
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
|
||||
docstring-parser==0.16
|
||||
# via tyro
|
||||
markdown-it-py==3.0.0
|
||||
# via rich
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
numpy==1.26.4
|
||||
# via -r examples/simple_client/requirements.in
|
||||
polars==1.30.0
|
||||
# via -r examples/simple_client/requirements.in
|
||||
pygments==2.19.1
|
||||
# via rich
|
||||
rich==14.0.0
|
||||
# via
|
||||
# -r examples/simple_client/requirements.in
|
||||
# tyro
|
||||
shtab==1.7.2
|
||||
# via tyro
|
||||
tqdm==4.67.1
|
||||
# via -r examples/simple_client/requirements.in
|
||||
typeguard==4.4.2
|
||||
# via tyro
|
||||
typing-extensions==4.13.2
|
||||
# via
|
||||
# typeguard
|
||||
# tyro
|
||||
tyro==0.9.22
|
||||
# via -r examples/simple_client/requirements.in
|
||||
142
policy/openpi-InternData-A1/examples/ur5/README.md
Normal file
142
policy/openpi-InternData-A1/examples/ur5/README.md
Normal file
@@ -0,0 +1,142 @@
|
||||
# UR5 Example
|
||||
|
||||
Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
|
||||
|
||||
First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model and vice versa. Check the corresponding files in `src/openpi/policies/libero_policy.py` for comments explaining each line.
|
||||
|
||||
```python
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UR5Inputs(transforms.DataTransformFn):
|
||||
|
||||
model_type: _model.ModelType = _model.ModelType.PI0
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
# First, concatenate the joints and gripper into the state vector.
|
||||
state = np.concatenate([data["joints"], data["gripper"]])
|
||||
|
||||
# Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
|
||||
# stores as float32 (C,H,W), gets skipped for policy inference.
|
||||
base_image = _parse_image(data["base_rgb"])
|
||||
wrist_image = _parse_image(data["wrist_rgb"])
|
||||
|
||||
# Create inputs dict.
|
||||
inputs = {
|
||||
"state": state,
|
||||
"image": {
|
||||
"base_0_rgb": base_image,
|
||||
"left_wrist_0_rgb": wrist_image,
|
||||
# Since there is no right wrist, replace with zeros
|
||||
"right_wrist_0_rgb": np.zeros_like(base_image),
|
||||
},
|
||||
"image_mask": {
|
||||
"base_0_rgb": np.True_,
|
||||
"left_wrist_0_rgb": np.True_,
|
||||
# Since the "slot" for the right wrist is not used, this mask is set
|
||||
# to False
|
||||
"right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
|
||||
},
|
||||
}
|
||||
|
||||
if "actions" in data:
|
||||
inputs["actions"] = data["actions"]
|
||||
|
||||
# Pass the prompt (aka language instruction) to the model.
|
||||
if "prompt" in data:
|
||||
inputs["prompt"] = data["prompt"]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class UR5Outputs(transforms.DataTransformFn):
|
||||
|
||||
def __call__(self, data: dict) -> dict:
|
||||
# Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims
|
||||
return {"actions": np.asarray(data["actions"][:, :7])}
|
||||
|
||||
```
|
||||
|
||||
Next, we will define the `UR5DataConfig` class, which defines how to process raw UR5 data from LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
|
||||
|
||||
```python
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class LeRobotUR5DataConfig(DataConfigFactory):
|
||||
|
||||
@override
|
||||
def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
|
||||
# Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming needed here.
|
||||
repack_transform = _transforms.Group(
|
||||
inputs=[
|
||||
_transforms.RepackTransform(
|
||||
{
|
||||
"base_rgb": "image",
|
||||
"wrist_rgb": "wrist_image",
|
||||
"joints": "joints",
|
||||
"gripper": "gripper",
|
||||
"prompt": "prompt",
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# These transforms are the ones we wrote earlier.
|
||||
data_transforms = _transforms.Group(
|
||||
inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
|
||||
outputs=[UR5Outputs()],
|
||||
)
|
||||
|
||||
# Convert absolute actions to delta actions.
|
||||
# By convention, we do not convert the gripper action (7th dimension).
|
||||
delta_action_mask = _transforms.make_bool_mask(6, -1)
|
||||
data_transforms = data_transforms.push(
|
||||
inputs=[_transforms.DeltaActions(delta_action_mask)],
|
||||
outputs=[_transforms.AbsoluteActions(delta_action_mask)],
|
||||
)
|
||||
|
||||
# Model transforms include things like tokenizing the prompt and action targets
|
||||
# You do not need to change anything here for your own dataset.
|
||||
model_transforms = ModelTransformFactory()(model_config)
|
||||
|
||||
# We return all data transforms for training and inference. No need to change anything here.
|
||||
return dataclasses.replace(
|
||||
self.create_base_config(assets_dirs),
|
||||
repack_transforms=repack_transform,
|
||||
data_transforms=data_transforms,
|
||||
model_transforms=model_transforms,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
Finally, we define the TrainConfig for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
|
||||
|
||||
```python
|
||||
TrainConfig(
|
||||
name="pi0_ur5",
|
||||
model=pi0.Pi0Config(),
|
||||
data=LeRobotUR5DataConfig(
|
||||
repo_id="your_username/ur5_dataset",
|
||||
# This config lets us reload the UR5 normalization stats from the base model checkpoint.
|
||||
# Reloading normalization stats can help transfer pre-trained models to new environments.
|
||||
# See the [norm_stats.md](../docs/norm_stats.md) file for more details.
|
||||
assets=AssetsConfig(
|
||||
assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
|
||||
asset_id="ur5e",
|
||||
),
|
||||
base_config=DataConfig(
|
||||
# This flag determines whether we load the prompt (i.e. the task instruction) from the
|
||||
# ``task`` field in the LeRobot dataset. The recommended setting is True.
|
||||
prompt_from_task=True,
|
||||
),
|
||||
),
|
||||
# Load the pi0 base model checkpoint.
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
|
||||
num_train_steps=30_000,
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
[project]
|
||||
name = "openpi-client"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.7"
|
||||
dependencies = [
|
||||
"dm-tree>=0.1.8",
|
||||
"msgpack>=1.0.5",
|
||||
"numpy>=1.22.4,<2.0.0",
|
||||
"pillow>=9.0.0",
|
||||
"tree>=0.2.4",
|
||||
"websockets>=11.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.uv]
|
||||
dev-dependencies = ["pytest>=8.3.4"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py37"
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = "0.1.0"
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import tree
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
|
||||
|
||||
class ActionChunkBroker(_base_policy.BasePolicy):
|
||||
"""Wraps a policy to return action chunks one-at-a-time.
|
||||
|
||||
Assumes that the first dimension of all action fields is the chunk size.
|
||||
|
||||
A new inference call to the inner policy is only made when the current
|
||||
list of chunks is exhausted.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
|
||||
self._policy = policy
|
||||
self._action_horizon = action_horizon
|
||||
self._cur_step: int = 0
|
||||
|
||||
self._last_results: Dict[str, np.ndarray] | None = None
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
if self._last_results is None:
|
||||
self._last_results = self._policy.infer(obs)
|
||||
self._cur_step = 0
|
||||
|
||||
def slicer(x):
|
||||
if isinstance(x, np.ndarray):
|
||||
return x[self._cur_step, ...]
|
||||
else:
|
||||
return x
|
||||
|
||||
results = tree.map_structure(slicer, self._last_results)
|
||||
self._cur_step += 1
|
||||
|
||||
if self._cur_step >= self._action_horizon:
|
||||
self._last_results = None
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
self._last_results = None
|
||||
self._cur_step = 0
|
||||
@@ -0,0 +1,12 @@
|
||||
import abc
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class BasePolicy(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def infer(self, obs: Dict) -> Dict:
|
||||
"""Infer actions from observations."""
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the policy to its initial state."""
|
||||
pass
|
||||
@@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
|
||||
"""Converts an image to uint8 if it is a float image.
|
||||
|
||||
This is important for reducing the size of the image when sending it over the network.
|
||||
"""
|
||||
if np.issubdtype(img.dtype, np.floating):
|
||||
img = (255 * img).astype(np.uint8)
|
||||
return img
|
||||
|
||||
|
||||
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
|
||||
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
|
||||
|
||||
Args:
|
||||
images: A batch of images in [..., height, width, channel] format.
|
||||
height: The target height of the image.
|
||||
width: The target width of the image.
|
||||
method: The interpolation method to use. Default is bilinear.
|
||||
|
||||
Returns:
|
||||
The resized images in [..., height, width, channel].
|
||||
"""
|
||||
# If the images are already the correct size, return them as is.
|
||||
if images.shape[-3:-1] == (height, width):
|
||||
return images
|
||||
|
||||
original_shape = images.shape
|
||||
|
||||
images = images.reshape(-1, *original_shape[-3:])
|
||||
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
|
||||
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
|
||||
|
||||
|
||||
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
|
||||
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
|
||||
width without distortion by padding with zeros.
|
||||
|
||||
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
|
||||
"""
|
||||
cur_width, cur_height = image.size
|
||||
if cur_width == width and cur_height == height:
|
||||
return image # No need to resize if the image is already the correct size.
|
||||
|
||||
ratio = max(cur_width / width, cur_height / height)
|
||||
resized_height = int(cur_height / ratio)
|
||||
resized_width = int(cur_width / ratio)
|
||||
resized_image = image.resize((resized_width, resized_height), resample=method)
|
||||
|
||||
zero_image = Image.new(resized_image.mode, (width, height), 0)
|
||||
pad_height = max(0, int((height - resized_height) / 2))
|
||||
pad_width = max(0, int((width - resized_width) / 2))
|
||||
zero_image.paste(resized_image, (pad_width, pad_height))
|
||||
assert zero_image.size == (width, height)
|
||||
return zero_image
|
||||
@@ -0,0 +1,37 @@
|
||||
import numpy as np
|
||||
|
||||
import openpi_client.image_tools as image_tools
|
||||
|
||||
|
||||
def test_resize_with_pad_shapes():
|
||||
# Test case 1: Resize image with larger dimensions
|
||||
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
|
||||
height = 20
|
||||
width = 20
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (2, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 2: Resize image with smaller dimensions
|
||||
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
|
||||
height = 15
|
||||
width = 15
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (3, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with the same dimensions
|
||||
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
|
||||
height = 50
|
||||
width = 50
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
|
||||
# Test case 3: Resize image with odd-numbered padding
|
||||
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
|
||||
height = 60
|
||||
width = 80
|
||||
resized_images = image_tools.resize_with_pad(images, height, width)
|
||||
assert resized_images.shape == (1, height, width, 3)
|
||||
assert np.all(resized_images == 0)
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Adds NumPy array support to msgpack.
|
||||
|
||||
msgpack is good for (de)serializing data over a network for multiple reasons:
|
||||
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
|
||||
- msgpack is widely used and has good cross-language support
|
||||
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
|
||||
languages like Python and JavaScript
|
||||
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
|
||||
than pickle for serializing large arrays using the below strategy
|
||||
|
||||
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
|
||||
that it falls back to pickle for object arrays.
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
import msgpack
|
||||
import numpy as np
|
||||
|
||||
|
||||
def pack_array(obj):
|
||||
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
|
||||
raise ValueError(f"Unsupported dtype: {obj.dtype}")
|
||||
|
||||
if isinstance(obj, np.ndarray):
|
||||
return {
|
||||
b"__ndarray__": True,
|
||||
b"data": obj.tobytes(),
|
||||
b"dtype": obj.dtype.str,
|
||||
b"shape": obj.shape,
|
||||
}
|
||||
|
||||
if isinstance(obj, np.generic):
|
||||
return {
|
||||
b"__npgeneric__": True,
|
||||
b"data": obj.item(),
|
||||
b"dtype": obj.dtype.str,
|
||||
}
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def unpack_array(obj):
|
||||
if b"__ndarray__" in obj:
|
||||
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
|
||||
|
||||
if b"__npgeneric__" in obj:
|
||||
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
Packer = functools.partial(msgpack.Packer, default=pack_array)
|
||||
packb = functools.partial(msgpack.packb, default=pack_array)
|
||||
|
||||
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
|
||||
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
|
||||
@@ -0,0 +1,45 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tree
|
||||
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
def _check(expected, actual):
|
||||
if isinstance(expected, np.ndarray):
|
||||
assert expected.shape == actual.shape
|
||||
assert expected.dtype == actual.dtype
|
||||
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
|
||||
else:
|
||||
assert expected == actual
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"data",
|
||||
[
|
||||
1, # int
|
||||
1.0, # float
|
||||
"hello", # string
|
||||
np.bool_(True), # boolean scalar
|
||||
np.array([1, 2, 3])[0], # int scalar
|
||||
np.str_("asdf"), # string scalar
|
||||
[1, 2, 3], # list
|
||||
{"key": "value"}, # dict
|
||||
{"key": [1, 2, 3]}, # nested dict
|
||||
np.array(1.0), # 0D array
|
||||
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
|
||||
np.array(["asdf", "qwer"]), # string array
|
||||
np.array([True, False]), # boolean array
|
||||
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
|
||||
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
|
||||
np.array([np.nan, np.inf, -np.inf]), # special float values
|
||||
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
|
||||
[np.array([1, 2]), np.array([3, 4])], # list of arrays
|
||||
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
|
||||
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
|
||||
],
|
||||
)
|
||||
def test_pack_unpack(data):
|
||||
packed = msgpack_numpy.packb(data)
|
||||
unpacked = msgpack_numpy.unpackb(packed)
|
||||
tree.map_structure(_check, data, unpacked)
|
||||
@@ -0,0 +1,17 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Agent(abc.ABC):
|
||||
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
|
||||
|
||||
Agents receive observations about the state of the world, and return actions
|
||||
to take in response.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
"""Query the agent for the next action."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the agent to its initial state."""
|
||||
@@ -0,0 +1,18 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client.runtime import agent as _agent
|
||||
|
||||
|
||||
class PolicyAgent(_agent.Agent):
|
||||
"""An agent that uses a policy to determine actions."""
|
||||
|
||||
def __init__(self, policy: _base_policy.BasePolicy) -> None:
|
||||
self._policy = policy
|
||||
|
||||
@override
|
||||
def get_action(self, observation: dict) -> dict:
|
||||
return self._policy.infer(observation)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._policy.reset()
|
||||
@@ -0,0 +1,32 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Environment(abc.ABC):
|
||||
"""An Environment represents the robot and the environment it inhabits.
|
||||
|
||||
The primary contract of environments is that they can be queried for observations
|
||||
about their state, and have actions applied to them to change that state.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the environment to its initial state.
|
||||
|
||||
This will be called once before starting each episode.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_episode_complete(self) -> bool:
|
||||
"""Allow the environment to signal that the episode is complete.
|
||||
|
||||
This will be called after each step. It should return `True` if the episode is
|
||||
complete (either successfully or unsuccessfully), and `False` otherwise.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_observation(self) -> dict:
|
||||
"""Query the environment for the current state."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def apply_action(self, action: dict) -> None:
|
||||
"""Take an action in the environment."""
|
||||
@@ -0,0 +1,92 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from openpi_client.runtime import agent as _agent
|
||||
from openpi_client.runtime import environment as _environment
|
||||
from openpi_client.runtime import subscriber as _subscriber
|
||||
|
||||
|
||||
class Runtime:
|
||||
"""The core module orchestrating interactions between key components of the system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
environment: _environment.Environment,
|
||||
agent: _agent.Agent,
|
||||
subscribers: list[_subscriber.Subscriber],
|
||||
max_hz: float = 0,
|
||||
num_episodes: int = 1,
|
||||
max_episode_steps: int = 0,
|
||||
) -> None:
|
||||
self._environment = environment
|
||||
self._agent = agent
|
||||
self._subscribers = subscribers
|
||||
self._max_hz = max_hz
|
||||
self._num_episodes = num_episodes
|
||||
self._max_episode_steps = max_episode_steps
|
||||
|
||||
self._in_episode = False
|
||||
self._episode_steps = 0
|
||||
|
||||
def run(self) -> None:
|
||||
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
|
||||
for _ in range(self._num_episodes):
|
||||
self._run_episode()
|
||||
|
||||
# Final reset, this is important for real environments to move the robot to its home position.
|
||||
self._environment.reset()
|
||||
|
||||
def run_in_new_thread(self) -> threading.Thread:
|
||||
"""Runs the runtime loop in a new thread."""
|
||||
thread = threading.Thread(target=self.run)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
def mark_episode_complete(self) -> None:
|
||||
"""Marks the end of an episode."""
|
||||
self._in_episode = False
|
||||
|
||||
def _run_episode(self) -> None:
|
||||
"""Runs a single episode."""
|
||||
logging.info("Starting episode...")
|
||||
self._environment.reset()
|
||||
self._agent.reset()
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_start()
|
||||
|
||||
self._in_episode = True
|
||||
self._episode_steps = 0
|
||||
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
|
||||
last_step_time = time.time()
|
||||
|
||||
while self._in_episode:
|
||||
self._step()
|
||||
self._episode_steps += 1
|
||||
|
||||
# Sleep to maintain the desired frame rate
|
||||
now = time.time()
|
||||
dt = now - last_step_time
|
||||
if dt < step_time:
|
||||
time.sleep(step_time - dt)
|
||||
last_step_time = time.time()
|
||||
else:
|
||||
last_step_time = now
|
||||
|
||||
logging.info("Episode completed.")
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_episode_end()
|
||||
|
||||
def _step(self) -> None:
|
||||
"""A single step of the runtime loop."""
|
||||
observation = self._environment.get_observation()
|
||||
action = self._agent.get_action(observation)
|
||||
self._environment.apply_action(action)
|
||||
|
||||
for subscriber in self._subscribers:
|
||||
subscriber.on_step(observation, action)
|
||||
|
||||
if self._environment.is_episode_complete() or (
|
||||
self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
|
||||
):
|
||||
self.mark_episode_complete()
|
||||
@@ -0,0 +1,20 @@
|
||||
import abc
|
||||
|
||||
|
||||
class Subscriber(abc.ABC):
|
||||
"""Subscribes to events in the runtime.
|
||||
|
||||
Subscribers can be used to save data, visualize, etc.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_start(self) -> None:
|
||||
"""Called when an episode starts."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_step(self, observation: dict, action: dict) -> None:
|
||||
"""Append a step to the episode."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def on_episode_end(self) -> None:
|
||||
"""Called when an episode ends."""
|
||||
@@ -0,0 +1,55 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from typing_extensions import override
|
||||
import websockets.sync.client
|
||||
|
||||
from openpi_client import base_policy as _base_policy
|
||||
from openpi_client import msgpack_numpy
|
||||
|
||||
|
||||
class WebsocketClientPolicy(_base_policy.BasePolicy):
|
||||
"""Implements the Policy interface by communicating with a server over websocket.
|
||||
|
||||
See WebsocketPolicyServer for a corresponding server implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
|
||||
self._uri = f"ws://{host}"
|
||||
if port is not None:
|
||||
self._uri += f":{port}"
|
||||
self._packer = msgpack_numpy.Packer()
|
||||
self._api_key = api_key
|
||||
self._ws, self._server_metadata = self._wait_for_server()
|
||||
|
||||
def get_server_metadata(self) -> Dict:
|
||||
return self._server_metadata
|
||||
|
||||
def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
|
||||
logging.info(f"Waiting for server at {self._uri}...")
|
||||
while True:
|
||||
try:
|
||||
headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
|
||||
conn = websockets.sync.client.connect(
|
||||
self._uri, compression=None, max_size=None, additional_headers=headers
|
||||
)
|
||||
metadata = msgpack_numpy.unpackb(conn.recv())
|
||||
return conn, metadata
|
||||
except ConnectionRefusedError:
|
||||
logging.info("Still waiting for server...")
|
||||
time.sleep(5)
|
||||
|
||||
@override
|
||||
def infer(self, obs: Dict) -> Dict: # noqa: UP006
|
||||
data = self._packer.pack(obs)
|
||||
self._ws.send(data)
|
||||
response = self._ws.recv()
|
||||
if isinstance(response, str):
|
||||
# we're expecting bytes; if the server sends a string, it's an error.
|
||||
raise RuntimeError(f"Error in inference server:\n{response}")
|
||||
return msgpack_numpy.unpackb(response)
|
||||
|
||||
@override
|
||||
def reset(self) -> None:
|
||||
pass
|
||||
136
policy/openpi-InternData-A1/pyproject.toml
Normal file
136
policy/openpi-InternData-A1/pyproject.toml
Normal file
@@ -0,0 +1,136 @@
|
||||
[project]
|
||||
name = "openpi"
|
||||
version = "0.1.0"
|
||||
description = "Physical Intelligence open source repo"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"augmax>=0.3.4",
|
||||
"dm-tree>=0.1.8",
|
||||
"einops>=0.8.0",
|
||||
"equinox>=0.11.8",
|
||||
"flatbuffers>=24.3.25",
|
||||
"flax==0.10.2",
|
||||
"fsspec[gcs]>=2024.6.0",
|
||||
"gym-aloha>=0.1.1",
|
||||
"imageio>=2.36.1",
|
||||
"jax[cuda12]==0.5.3",
|
||||
"jaxtyping==0.2.36",
|
||||
"ml_collections==1.0.0",
|
||||
"numpy>=1.22.4,<2.0.0",
|
||||
"numpydantic>=1.6.6",
|
||||
"opencv-python>=4.10.0.84",
|
||||
"openpi-client",
|
||||
"orbax-checkpoint==0.11.13",
|
||||
"pillow>=11.0.0",
|
||||
"sentencepiece>=0.2.0",
|
||||
"torch==2.7.1",
|
||||
"tqdm-loggable>=0.2",
|
||||
"typing-extensions>=4.12.2",
|
||||
"tyro>=0.9.5",
|
||||
"wandb>=0.19.1",
|
||||
"filelock>=3.16.1",
|
||||
"beartype==0.19.0",
|
||||
"treescope>=0.1.7",
|
||||
"transformers==4.53.2",
|
||||
"rich>=14.0.0",
|
||||
"polars>=1.30.0",
|
||||
]
|
||||
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/Physical-Intelligence/openpi"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.3.4",
|
||||
"ruff>=0.8.6",
|
||||
"pre-commit>=4.0.1",
|
||||
"ipykernel>=6.29.5",
|
||||
"ipywidgets>=8.1.5",
|
||||
"matplotlib>=3.10.0",
|
||||
"pynvml>=12.0.0",
|
||||
]
|
||||
rlds = [
|
||||
"dlimp",
|
||||
"tensorflow-cpu==2.15.0",
|
||||
"tensorflow-datasets==4.9.9",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
|
||||
|
||||
[tool.uv.sources]
|
||||
openpi-client = { workspace = true }
|
||||
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
|
||||
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/*"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py311"
|
||||
extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
# https://docs.astral.sh/ruff/rules/
|
||||
select = [
|
||||
"B",
|
||||
"C4",
|
||||
"DTZ",
|
||||
"E4",
|
||||
"E7",
|
||||
"E9",
|
||||
"F",
|
||||
"FBT",
|
||||
"FURB",
|
||||
"I",
|
||||
"ICN",
|
||||
"ISC",
|
||||
"LOG",
|
||||
"N",
|
||||
"PD",
|
||||
"PERF",
|
||||
"PIE",
|
||||
"PLC",
|
||||
"PLE",
|
||||
"PLR1",
|
||||
"PLR5",
|
||||
"PLW",
|
||||
"PT",
|
||||
"Q",
|
||||
"RET",
|
||||
"RUF",
|
||||
"SIM",
|
||||
"SLF",
|
||||
"T10",
|
||||
"T20",
|
||||
"UP",
|
||||
"W",
|
||||
]
|
||||
ignore = [
|
||||
"F722", # Conflicts with array typing.
|
||||
"T201", # We use print statements.
|
||||
"PD008", # Lots of false positives.
|
||||
"ISC001", # Disabling to support ruff format.
|
||||
"LOG015", # Use logger.info.
|
||||
]
|
||||
unfixable = [
|
||||
"B905", # Fix defaults to strict=False, which is not what we want.
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
force-single-line = true
|
||||
force-sort-within-sections = true
|
||||
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
|
||||
known-third-party = ["wandb"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = ["manual: should be run manually."]
|
||||
testpaths = ["src", "scripts", "packages"]
|
||||
@@ -0,0 +1,36 @@
|
||||
augmax>=0.3.4
|
||||
dm-tree>=0.1.8
|
||||
einops>=0.8.0
|
||||
equinox>=0.11.8
|
||||
flatbuffers>=24.3.25
|
||||
flax==0.10.2
|
||||
fsspec[gcs]>=2024.6.0
|
||||
gym-aloha>=0.1.1
|
||||
imageio>=2.36.1
|
||||
jax[cuda12]==0.5.3
|
||||
jaxtyping==0.2.36
|
||||
ml_collections==1.0.0
|
||||
numpy>=1.22.4,<2.0.0
|
||||
numpydantic>=1.6.6
|
||||
opencv-python>=4.10.0.84
|
||||
orbax-checkpoint==0.11.13
|
||||
pillow>=11.0.0
|
||||
sentencepiece>=0.2.0
|
||||
tqdm-loggable>=0.2
|
||||
typing-extensions>=4.12.2
|
||||
tyro>=0.9.5
|
||||
wandb>=0.19.1
|
||||
filelock>=3.16.1
|
||||
beartype==0.19.0
|
||||
treescope>=0.1.7
|
||||
transformers==4.53.2
|
||||
rich>=14.0.0
|
||||
polars>=1.30.0
|
||||
ml-dtypes==0.5.3
|
||||
tensorstore==0.1.74
|
||||
# tensorflow==2.20.0
|
||||
tensorflow-datasets==4.9.9
|
||||
lmdb==1.7.3
|
||||
pytest==8.4.1
|
||||
nvidia-cudnn-cu12==9.10.2.21
|
||||
# dlimp
|
||||
0
policy/openpi-InternData-A1/scripts/__init__.py
Normal file
0
policy/openpi-InternData-A1/scripts/__init__.py
Normal file
218
policy/openpi-InternData-A1/scripts/compute_norm_stats_real.py
Normal file
218
policy/openpi-InternData-A1/scripts/compute_norm_stats_real.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Compute normalization statistics for real-world tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for a given real-world task. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiLeRobotReala2dDataConfig, MultiLeRobotRealArxLift2DataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
|
||||
def __call__(self, x: dict) -> dict:
|
||||
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
|
||||
|
||||
|
||||
def create_torch_dataloader(
|
||||
data_config: _config.DataConfig,
|
||||
action_horizon: int,
|
||||
batch_size: int,
|
||||
model_config: _model.BaseModelConfig,
|
||||
num_workers: int,
|
||||
max_frames: int | None = None,
|
||||
) -> tuple[_data_loader.Dataset, int]:
|
||||
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
|
||||
dataset = _mixture_dataset.TransformedDataset(
|
||||
dataset,
|
||||
[
|
||||
*data_config[0].repack_transforms.inputs,
|
||||
*data_config[0].data_transforms.inputs,
|
||||
RemoveStrings(),
|
||||
],
|
||||
)
|
||||
if max_frames is not None and max_frames < len(dataset):
|
||||
num_batches = max_frames // batch_size
|
||||
shuffle = True
|
||||
else:
|
||||
num_batches = len(dataset) // batch_size
|
||||
shuffle = False
|
||||
data_loader = _data_loader.TorchDataLoader(
|
||||
dataset,
|
||||
local_batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=shuffle,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return data_loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, robot_name, task_name, save_path):
|
||||
if robot_name == "lift2" or robot_name == "split_aloha" or robot_name == "acone":
|
||||
config = TrainConfig(
|
||||
name="lift2",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiLeRobotRealArxLift2DataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=False,
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"left_joint": "states.left_joint.position",
|
||||
"right_joint": "states.right_joint.position",
|
||||
"left_gripper": "states.left_gripper.position",
|
||||
"right_gripper": "states.right_gripper.position"
|
||||
},
|
||||
"action_dict": {
|
||||
"left_joint": "actions.left_joint.position",
|
||||
"right_joint": "actions.right_joint.position",
|
||||
"left_gripper": "actions.left_gripper.position",
|
||||
"right_gripper": "actions.right_gripper.position"
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
# pretrain model path
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
elif robot_name == "genie1":
|
||||
config = TrainConfig(
|
||||
name="genie1",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiLeRobotReala2dDataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=False,
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"joint": "observation.states.joint.position",
|
||||
"gripper": "observation.states.effector.position",
|
||||
},
|
||||
"action_dict": {
|
||||
"joint": "actions.joint.position",
|
||||
"gripper": "actions.effector.position",
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
# pretrain model path
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
data_config = config.data[0].create(config.model)
|
||||
print("done")
|
||||
output_path = os.path.join(save_path, robot_name, task_name)
|
||||
stats_json_path = os.path.join(output_path, "norm_stats.json")
|
||||
if os.path.isfile(stats_json_path):
|
||||
with open(stats_json_path, 'r', encoding='utf-8') as f:
|
||||
json.load(f)
|
||||
return True
|
||||
|
||||
data_loader, num_batches = create_torch_dataloader(
|
||||
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
|
||||
)
|
||||
|
||||
keys = ["state", "actions"]
|
||||
stats = {key: normalize.RunningStats() for key in keys}
|
||||
|
||||
step_id = 0
|
||||
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
|
||||
step_id += 1
|
||||
for key in keys:
|
||||
stats[key].update(np.asarray(batch[key]))
|
||||
|
||||
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
||||
|
||||
print(f"Writing stats to: {output_path}")
|
||||
normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
|
||||
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
|
||||
print(repo_dir, "true")
|
||||
return True
|
||||
else:
|
||||
print(repo_dir, "false")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task_path", type=str, default="data/InternData-A1/real/genie1/Put_the_pen_from_the_table_into_the_pen_holder/*")
|
||||
parser.add_argument("--robot_name", type=str, default="genie1")
|
||||
parser.add_argument("--save_path", type=str, default="stats/real")
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
dataset_path=args.task_path
|
||||
save_path = args.save_path
|
||||
parts = dataset_path.split("/")
|
||||
robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
|
||||
if robot_idx is None:
|
||||
raise ValueError(
|
||||
f"Cannot find robot name in path. Expected {args.robot_name}, "
|
||||
f"but got path: {dataset_path}"
|
||||
)
|
||||
|
||||
if robot_idx + 1 >= len(parts):
|
||||
raise ValueError(
|
||||
f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {local_path}"
|
||||
)
|
||||
robot_name = parts[robot_idx]
|
||||
task_name = parts[robot_idx + 1]
|
||||
try:
|
||||
main(dataset_path, robot_name, task_name, save_path)
|
||||
except:
|
||||
print(dataset_path)
|
||||
|
||||
314
policy/openpi-InternData-A1/scripts/compute_norm_stats_sim.py
Normal file
314
policy/openpi-InternData-A1/scripts/compute_norm_stats_sim.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""Compute normalization statistics for interndata-a1 sim tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for interndata-a1 sim tasks. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config assets directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiSimGenieDataConfig, MultiSimSplitAlohaDataConfig, MultiSimFrankaDataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
|
||||
def __call__(self, x: dict) -> dict:
|
||||
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
|
||||
|
||||
|
||||
def create_torch_dataloader(
|
||||
data_config: _config.DataConfig,
|
||||
action_horizon: int,
|
||||
batch_size: int,
|
||||
model_config: _model.BaseModelConfig,
|
||||
num_workers: int,
|
||||
max_frames: int | None = None,
|
||||
) -> tuple[_data_loader.Dataset, int]:
|
||||
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
|
||||
dataset = _mixture_dataset.TransformedDataset(
|
||||
dataset,
|
||||
[
|
||||
*data_config[0].repack_transforms.inputs,
|
||||
*data_config[0].data_transforms.inputs,
|
||||
RemoveStrings(),
|
||||
],
|
||||
)
|
||||
if max_frames is not None and max_frames < len(dataset):
|
||||
num_batches = max_frames // batch_size
|
||||
shuffle = True
|
||||
else:
|
||||
num_batches = len(dataset) // batch_size
|
||||
shuffle = False
|
||||
data_loader = _data_loader.TorchDataLoader(
|
||||
dataset,
|
||||
local_batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=shuffle,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return data_loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, task_category, robot_name, task_name, collect_name, save_path):
|
||||
if robot_name == "lift2" or robot_name == "split_aloha":
|
||||
config = TrainConfig(
|
||||
name="lift2",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiSimSplitAlohaDataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=True,
|
||||
gripper_aug_config={
|
||||
"gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
|
||||
"gripper_dim": -1,
|
||||
"gripper_threshold_method": "std_multiplier",
|
||||
"gripper_threshold_multiplier": 1.0,
|
||||
"gripper_min_threshold": 0.001,
|
||||
"gripper_max_threshold": 1.0,
|
||||
},
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"left_joint": "states.left_joint.position",
|
||||
"right_joint": "states.right_joint.position",
|
||||
"left_gripper": "states.left_gripper.position",
|
||||
"right_gripper": "states.right_gripper.position"
|
||||
},
|
||||
"action_dict": {
|
||||
"left_joint": "actions.left_joint.position",
|
||||
"right_joint": "actions.right_joint.position",
|
||||
"left_gripper": "actions.left_gripper.position",
|
||||
"right_gripper": "actions.right_gripper.position",
|
||||
"left_gripper_openness": "master_actions.left_gripper.openness",
|
||||
"right_gripper_openness": "master_actions.right_gripper.openness"
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
elif robot_name == "genie1":
|
||||
config = TrainConfig(
|
||||
name="genie1",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiSimGenieDataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=True,
|
||||
gripper_aug_config={
|
||||
"gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
|
||||
"gripper_dim": -1,
|
||||
"gripper_threshold_method": "std_multiplier",
|
||||
"gripper_threshold_multiplier": 1.0,
|
||||
"gripper_min_threshold": 0.001,
|
||||
"gripper_max_threshold": 1.0,
|
||||
},
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"left_joint": "states.left_joint.position",
|
||||
"right_joint": "states.right_joint.position",
|
||||
"left_gripper": "states.left_gripper.position",
|
||||
"right_gripper": "states.right_gripper.position"
|
||||
},
|
||||
"action_dict": {
|
||||
"left_joint": "actions.left_joint.position",
|
||||
"right_joint": "actions.right_joint.position",
|
||||
"left_gripper": "actions.left_gripper.position",
|
||||
"right_gripper": "actions.right_gripper.position",
|
||||
"left_gripper_openness": "master_actions.left_gripper.openness",
|
||||
"right_gripper_openness": "master_actions.right_gripper.openness"
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
elif "franka" in robot_name:
|
||||
config = TrainConfig(
|
||||
name="franka",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiSimFrankaDataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=True,
|
||||
gripper_aug_config={
|
||||
"gripper_action_keys": ["actions.gripper.openness"],
|
||||
"gripper_dim": -1,
|
||||
"gripper_threshold_method": "std_multiplier",
|
||||
"gripper_threshold_multiplier": 1.0,
|
||||
"gripper_min_threshold": 0.001,
|
||||
"gripper_max_threshold": 1.0,
|
||||
},
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"joint_position": "states.joint.position",
|
||||
"gripper_pose": "states.gripper.pose",
|
||||
"gripper_position": "states.gripper.position",
|
||||
},
|
||||
"action_dict": {
|
||||
"gripper_pose": "actions.gripper.pose",
|
||||
"gripper_position": "actions.gripper.position",
|
||||
"gripper_openness": "actions.gripper.openness",
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
|
||||
data_config = config.data[0].create(config.model)
|
||||
print("done")
|
||||
output_path = os.path.join(save_path, task_category, robot_name, task_name, collect_name)
|
||||
stats_json_path = os.path.join(output_path, "norm_stats.json")
|
||||
if os.path.isfile(stats_json_path):
|
||||
with open(stats_json_path, 'r', encoding='utf-8') as f:
|
||||
json.load(f)
|
||||
return True
|
||||
|
||||
data_loader, num_batches = create_torch_dataloader(
|
||||
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
|
||||
)
|
||||
|
||||
keys = ["state", "actions"]
|
||||
stats = {key: normalize.RunningStats() for key in keys}
|
||||
|
||||
step_id = 0
|
||||
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
|
||||
step_id += 1
|
||||
for key in keys:
|
||||
stats[key].update(np.asarray(batch[key]))
|
||||
if step_id > 10000:
|
||||
break
|
||||
|
||||
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
||||
|
||||
print(f"Writing stats to: {output_path}")
|
||||
normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
|
||||
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
|
||||
print(repo_dir, "true")
|
||||
return True
|
||||
else:
|
||||
print(repo_dir, "false")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--root_data_dir", type=str, default="data/InternData-A1/sim")
|
||||
parser.add_argument("--task_category", type=str, default="pick_and_place_tasks")
|
||||
parser.add_argument("--save_path", type=str, default="stats/sim")
|
||||
parser.add_argument("--start_ratio", type=float, default=0.0)
|
||||
parser.add_argument("--end_ratio", type=float, default=1)
|
||||
args, unknown = parser.parse_known_args()
|
||||
root_data_dir = os.path.join(args.root_data_dir, args.task_category)
|
||||
|
||||
dataset_paths = glob.glob(os.path.join(root_data_dir, "*", "*"))
|
||||
dataset_paths.sort()
|
||||
valid_paths = [
|
||||
p for p in dataset_paths
|
||||
if check_lerobot_repo(p)
|
||||
]
|
||||
|
||||
start_idx = int(len(valid_paths) * args.start_ratio)
|
||||
end_idx = int(len(valid_paths) * args.end_ratio) + 1
|
||||
valid_paths = valid_paths[start_idx:end_idx]
|
||||
for dataset_path in tqdm.tqdm(valid_paths):
|
||||
task_category = dataset_path.split('/')[-3]
|
||||
robot_name = dataset_path.split('/')[-2]
|
||||
task_name = dataset_path.split('/')[-1]
|
||||
collect_name = ""
|
||||
try:
|
||||
main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
|
||||
except:
|
||||
print(dataset_path)
|
||||
|
||||
dataset_paths_w_subtask = glob.glob(os.path.join(root_data_dir, "*", "*","*"))
|
||||
dataset_paths_w_subtask.sort()
|
||||
valid_paths_w_subtask = [
|
||||
p for p in dataset_paths_w_subtask
|
||||
if check_lerobot_repo(p)
|
||||
]
|
||||
start_idx = int(len(valid_paths_w_subtask) * args.start_ratio)
|
||||
end_idx = int(len(valid_paths_w_subtask) * args.end_ratio) + 1
|
||||
valid_paths_w_subtask = valid_paths_w_subtask[start_idx:end_idx]
|
||||
for dataset_path in tqdm.tqdm(valid_paths_w_subtask):
|
||||
task_category = dataset_path.split('/')[-4]
|
||||
robot_name = dataset_path.split('/')[-3]
|
||||
task_name = dataset_path.split('/')[-2]
|
||||
collect_name = dataset_path.split('/')[-1]
|
||||
try:
|
||||
main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
|
||||
except:
|
||||
print(dataset_path)
|
||||
@@ -0,0 +1,181 @@
|
||||
"""Compute normalization statistics for real-world tasks.
|
||||
|
||||
This script is used to compute the normalization statistics for a given real-world task. It
|
||||
will compute the mean and standard deviation of the data in the dataset and save it
|
||||
to the config directory.
|
||||
"""
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import tqdm
|
||||
import tyro
|
||||
import json
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.normalize as normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.mixture_dataset as _mixture_dataset
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.transforms as transforms
|
||||
|
||||
### training config ###
|
||||
import openpi.training.weight_loaders as weight_loaders
|
||||
import openpi.models.pi0_config as pi0_config
|
||||
from openpi.training.config import MultiSim2RealSplitAlohaDataConfig, MultiDataConfig, DataConfig, TrainConfig
|
||||
|
||||
from pdb import set_trace
|
||||
|
||||
class RemoveStrings(transforms.DataTransformFn):
|
||||
def __call__(self, x: dict) -> dict:
|
||||
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
|
||||
|
||||
|
||||
def create_torch_dataloader(
|
||||
data_config: _config.DataConfig,
|
||||
action_horizon: int,
|
||||
batch_size: int,
|
||||
model_config: _model.BaseModelConfig,
|
||||
num_workers: int,
|
||||
max_frames: int | None = None,
|
||||
) -> tuple[_data_loader.Dataset, int]:
|
||||
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
|
||||
dataset = _mixture_dataset.TransformedDataset(
|
||||
dataset,
|
||||
[
|
||||
*data_config[0].repack_transforms.inputs,
|
||||
*data_config[0].data_transforms.inputs,
|
||||
RemoveStrings(),
|
||||
],
|
||||
)
|
||||
if max_frames is not None and max_frames < len(dataset):
|
||||
num_batches = max_frames // batch_size
|
||||
shuffle = True
|
||||
else:
|
||||
num_batches = len(dataset) // batch_size
|
||||
shuffle = False
|
||||
data_loader = _data_loader.TorchDataLoader(
|
||||
dataset,
|
||||
local_batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=shuffle,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return data_loader, num_batches
|
||||
|
||||
|
||||
def main(dataset_path, robot_name, task_name, save_path):
|
||||
if robot_name == "lift2":
|
||||
config = TrainConfig(
|
||||
name="lift2",
|
||||
model=pi0_config.Pi0Config(),
|
||||
data=[
|
||||
MultiSim2RealSplitAlohaDataConfig(
|
||||
repo_dir=dataset_path,
|
||||
task_id=None,
|
||||
use_gripper_aug=False,
|
||||
stats_dir='',
|
||||
base_config=MultiDataConfig(
|
||||
prompt_from_task=True,
|
||||
),
|
||||
asset_id=robot_name,
|
||||
robot_name=robot_name,
|
||||
repack_transforms=transforms.Group(
|
||||
inputs=[
|
||||
transforms.RepackTransform(
|
||||
{
|
||||
"state_dict": {
|
||||
"left_joint": "states.left_joint.position",
|
||||
"right_joint": "states.right_joint.position",
|
||||
"left_gripper": "states.left_gripper.position",
|
||||
"right_gripper": "states.right_gripper.position"
|
||||
},
|
||||
"action_dict": {
|
||||
"left_joint": "actions.left_joint.position",
|
||||
"right_joint": "actions.right_joint.position",
|
||||
"left_gripper": "actions.left_gripper.position",
|
||||
"right_gripper": "actions.right_gripper.position",
|
||||
"left_gripper_openness": "master_actions.left_gripper.openness",
|
||||
"right_gripper_openness": "master_actions.right_gripper.openness"
|
||||
},
|
||||
"prompt": "task"
|
||||
}
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
],
|
||||
# pretrain model path
|
||||
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
|
||||
pytorch_weight_path="checkpoints/pytorch/pi0_base",
|
||||
num_train_steps=30_000,
|
||||
num_workers=4,
|
||||
fsdp_devices=4,
|
||||
batch_size=8,
|
||||
)
|
||||
|
||||
data_config = config.data[0].create(config.model)
|
||||
print("done")
|
||||
output_path = os.path.join(save_path, robot_name, task_name)
|
||||
stats_json_path = os.path.join(output_path, "norm_stats.json")
|
||||
if os.path.isfile(stats_json_path):
|
||||
with open(stats_json_path, 'r', encoding='utf-8') as f:
|
||||
json.load(f)
|
||||
return True
|
||||
|
||||
data_loader, num_batches = create_torch_dataloader(
|
||||
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
|
||||
)
|
||||
|
||||
keys = ["state", "actions"]
|
||||
stats = {key: normalize.RunningStats() for key in keys}
|
||||
|
||||
step_id = 0
|
||||
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
|
||||
step_id += 1
|
||||
for key in keys:
|
||||
stats[key].update(np.asarray(batch[key]))
|
||||
if step_id > 10000:
|
||||
break
|
||||
|
||||
norm_stats = {key: stats.get_statistics() for key, stats in stats.items()}
|
||||
|
||||
print(f"Writing stats to: {output_path}")
|
||||
normalize.save(output_path, norm_stats)
|
||||
|
||||
def check_lerobot_repo(repo_dir: str):
|
||||
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
|
||||
print(repo_dir, "true")
|
||||
return True
|
||||
else:
|
||||
print(repo_dir, "false")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task_path", type=str, default="data/InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/*")
|
||||
parser.add_argument("--robot_name", type=str, default="lift2")
|
||||
parser.add_argument("--save_path", type=str, default="stats/sim2real")
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
dataset_path=args.task_path
|
||||
save_path = args.save_path
|
||||
parts = dataset_path.split("/")
|
||||
robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
|
||||
if robot_idx is None:
|
||||
raise ValueError(
|
||||
f"Cannot find robot name in path. Expected {args.robot_name}, "
|
||||
f"but got path: {dataset_path}"
|
||||
)
|
||||
|
||||
if robot_idx + 1 >= len(parts):
|
||||
raise ValueError(
|
||||
f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {local_path}"
|
||||
)
|
||||
robot_name = parts[robot_idx]
|
||||
task_name = parts[robot_idx + 1]
|
||||
try:
|
||||
main(dataset_path, robot_name, task_name, save_path)
|
||||
except:
|
||||
print(dataset_path)
|
||||
|
||||
29
policy/openpi-InternData-A1/scripts/docker/compose.yml
Normal file
29
policy/openpi-InternData-A1/scripts/docker/compose.yml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Run with:
|
||||
# docker compose -f scripts/docker/compose.yml up --build
|
||||
services:
|
||||
openpi_server:
|
||||
image: openpi_server
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: scripts/docker/serve_policy.Dockerfile
|
||||
init: true
|
||||
tty: true
|
||||
network_mode: host
|
||||
# Populate configured openpi data home to /openpi_assets inside the container.
|
||||
# Populate aws credential inside the container.
|
||||
volumes:
|
||||
- $PWD:/app
|
||||
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
|
||||
environment:
|
||||
- SERVER_ARGS
|
||||
- OPENPI_DATA_HOME=/openpi_assets
|
||||
- IS_DOCKER=true
|
||||
|
||||
# Comment out this block if not running on a machine with GPUs.
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: 1
|
||||
capabilities: [gpu]
|
||||
37
policy/openpi-InternData-A1/scripts/docker/install_docker_ubuntu22.sh
Executable file
37
policy/openpi-InternData-A1/scripts/docker/install_docker_ubuntu22.sh
Executable file
@@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Add Docker's official GPG key:
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y ca-certificates curl
|
||||
sudo install -m 0755 -d /etc/apt/keyrings
|
||||
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
|
||||
sudo chmod a+r /etc/apt/keyrings/docker.asc
|
||||
|
||||
# Add the repository to Apt sources:
|
||||
echo \
|
||||
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
|
||||
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
|
||||
sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
|
||||
sudo apt-get update
|
||||
|
||||
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
|
||||
|
||||
# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
|
||||
# See https://docs.docker.com/engine/install/linux-postinstall/
|
||||
username=$(whoami)
|
||||
sudo usermod -aG docker $username
|
||||
|
||||
# Configure docker to start automatically on system boot.
|
||||
sudo systemctl enable docker.service
|
||||
sudo systemctl enable containerd.service
|
||||
|
||||
# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
|
||||
if [ ~/.docker/config.json ]; then
|
||||
sed -i 's/credsStore/credStore/g' ~/.docker/config.json
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "********************************************************************"
|
||||
echo "**** Restart to allow Docker permission changes to take effect. ****"
|
||||
echo "********************************************************************"
|
||||
echo ""
|
||||
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
|
||||
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
|
||||
|
||||
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
|
||||
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
|
||||
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
|
||||
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
|
||||
# NVIDIA's documenation omits 'sudo' in the following command, but it is required.
|
||||
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y nvidia-container-toolkit
|
||||
|
||||
sudo nvidia-ctk runtime configure --runtime=docker
|
||||
sudo systemctl restart docker
|
||||
@@ -0,0 +1,38 @@
|
||||
# Dockerfile for serving a PI policy.
|
||||
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
|
||||
|
||||
# Build the container:
|
||||
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
|
||||
|
||||
# Run the container:
|
||||
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
|
||||
|
||||
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
|
||||
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Needed because LeRobot uses git-lfs.
|
||||
RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
|
||||
|
||||
# Copy from the cache instead of linking since it's a mounted volume
|
||||
ENV UV_LINK_MODE=copy
|
||||
|
||||
# Write the virtual environment outside of the project directory so it doesn't
|
||||
# leak out of the container when we mount the application code.
|
||||
ENV UV_PROJECT_ENVIRONMENT=/.venv
|
||||
|
||||
# Install the project's dependencies using the lockfile and settings
|
||||
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=uv.lock,target=uv.lock \
|
||||
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
|
||||
--mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
|
||||
GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
|
||||
|
||||
# Copy transformers_replace files while preserving directory structure
|
||||
COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
|
||||
RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace
|
||||
|
||||
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"
|
||||
27
policy/openpi-InternData-A1/scripts/download_paligemma.py
Normal file
27
policy/openpi-InternData-A1/scripts/download_paligemma.py
Normal file
@@ -0,0 +1,27 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
def download_from_gcs(gcs_uri: str, local_path: str):
|
||||
local_path = Path(local_path)
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if os.system("which gsutil > /dev/null 2>&1") == 0:
|
||||
cmd = f"gsutil cp {gcs_uri} {local_path}"
|
||||
else:
|
||||
gcs_http = gcs_uri.replace("gs://", "https://storage.googleapis.com/")
|
||||
cmd = f"wget -O {local_path} {gcs_http}"
|
||||
|
||||
print(f"⬇️ Executing: {cmd}")
|
||||
ret = os.system(cmd)
|
||||
if ret == 0:
|
||||
print("✅ Download complete:", local_path)
|
||||
else:
|
||||
raise RuntimeError(f"Download failed: {gcs_uri}")
|
||||
|
||||
return local_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
gcs_uri = "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz"
|
||||
save_path = "checkpoints/jax/paligemma/pt_224.npz"
|
||||
download_from_gcs(gcs_uri, save_path)
|
||||
122
policy/openpi-InternData-A1/scripts/serve_policy.py
Normal file
122
policy/openpi-InternData-A1/scripts/serve_policy.py
Normal file
@@ -0,0 +1,122 @@
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import socket
|
||||
|
||||
import tyro
|
||||
|
||||
from openpi.policies import policy as _policy
|
||||
from openpi.policies import policy_config as _policy_config
|
||||
from openpi.serving import websocket_policy_server
|
||||
from openpi.training import config as _config
|
||||
|
||||
|
||||
class EnvMode(enum.Enum):
|
||||
"""Supported environments."""
|
||||
|
||||
ALOHA = "aloha"
|
||||
ALOHA_SIM = "aloha_sim"
|
||||
DROID = "droid"
|
||||
LIBERO = "libero"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Checkpoint:
|
||||
"""Load a policy from a trained checkpoint."""
|
||||
|
||||
# Training config name (e.g., "pi0_aloha_sim").
|
||||
config: str
|
||||
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
|
||||
dir: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Default:
|
||||
"""Use the default policy for the given environment."""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Args:
|
||||
"""Arguments for the serve_policy script."""
|
||||
|
||||
# Environment to serve the policy for. This is only used when serving default policies.
|
||||
env: EnvMode = EnvMode.ALOHA_SIM
|
||||
|
||||
# If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
|
||||
# prompt.
|
||||
default_prompt: str | None = None
|
||||
|
||||
# Port to serve the policy on.
|
||||
port: int = 8000
|
||||
# Record the policy's behavior for debugging.
|
||||
record: bool = False
|
||||
|
||||
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
|
||||
policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
|
||||
|
||||
|
||||
# Default checkpoints that should be used for each environment.
|
||||
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
|
||||
EnvMode.ALOHA: Checkpoint(
|
||||
config="pi05_aloha",
|
||||
dir="gs://openpi-assets/checkpoints/pi05_base",
|
||||
),
|
||||
EnvMode.ALOHA_SIM: Checkpoint(
|
||||
config="pi0_aloha_sim",
|
||||
dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
|
||||
),
|
||||
EnvMode.DROID: Checkpoint(
|
||||
config="pi05_droid",
|
||||
dir="gs://openpi-assets/checkpoints/pi05_droid",
|
||||
),
|
||||
EnvMode.LIBERO: Checkpoint(
|
||||
config="pi05_libero",
|
||||
dir="gs://openpi-assets/checkpoints/pi05_libero",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
|
||||
"""Create a default policy for the given environment."""
|
||||
if checkpoint := DEFAULT_CHECKPOINT.get(env):
|
||||
return _policy_config.create_trained_policy(
|
||||
_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
|
||||
)
|
||||
raise ValueError(f"Unsupported environment mode: {env}")
|
||||
|
||||
|
||||
def create_policy(args: Args) -> _policy.Policy:
|
||||
"""Create a policy from the given arguments."""
|
||||
match args.policy:
|
||||
case Checkpoint():
|
||||
return _policy_config.create_trained_policy(
|
||||
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
|
||||
)
|
||||
case Default():
|
||||
return create_default_policy(args.env, default_prompt=args.default_prompt)
|
||||
|
||||
|
||||
def main(args: Args) -> None:
|
||||
policy = create_policy(args)
|
||||
policy_metadata = policy.metadata
|
||||
|
||||
# Record the policy's behavior.
|
||||
if args.record:
|
||||
policy = _policy.PolicyRecorder(policy, "policy_records")
|
||||
|
||||
hostname = socket.gethostname()
|
||||
local_ip = socket.gethostbyname(hostname)
|
||||
logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
|
||||
|
||||
server = websocket_policy_server.WebsocketPolicyServer(
|
||||
policy=policy,
|
||||
host="0.0.0.0",
|
||||
port=args.port,
|
||||
metadata=policy_metadata,
|
||||
)
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, force=True)
|
||||
main(tyro.cli(Args))
|
||||
290
policy/openpi-InternData-A1/scripts/train.py
Normal file
290
policy/openpi-InternData-A1/scripts/train.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
import flax.nnx as nnx
|
||||
from flax.training import common_utils
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
from memory_profiler import profile
|
||||
import psutil
|
||||
from openpi.shared.online_compute_norm_stats import compute_norm_stats
|
||||
|
||||
def init_logging():
|
||||
"""Custom logging format for better readability."""
|
||||
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
record.levelname = level_mapping.get(record.levelname, record.levelname)
|
||||
return super().format(record)
|
||||
|
||||
formatter = CustomFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
|
||||
if not enabled:
|
||||
wandb.init(mode="disabled")
|
||||
return
|
||||
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
if not ckpt_dir.exists():
|
||||
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
|
||||
if resuming:
|
||||
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
|
||||
wandb.init(id=run_id, resume="must", project=config.project_name)
|
||||
else:
|
||||
wandb.init(
|
||||
name=config.exp_name,
|
||||
config=dataclasses.asdict(config),
|
||||
project=config.project_name,
|
||||
)
|
||||
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
|
||||
|
||||
if log_code:
|
||||
wandb.run.log_code(epath.Path(__file__).parent.parent)
|
||||
|
||||
|
||||
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
|
||||
"""Loads and validates the weights. Returns a loaded subset of the weights."""
|
||||
loaded_params = loader.load(params_shape)
|
||||
at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
|
||||
|
||||
# Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
|
||||
return traverse_util.unflatten_dict(
|
||||
{k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
|
||||
)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def init_train_state(
|
||||
config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
|
||||
) -> tuple[training_utils.TrainState, Any]:
|
||||
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
|
||||
|
||||
def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
|
||||
rng, model_rng = jax.random.split(rng)
|
||||
# initialize the model (and its parameters).
|
||||
model = config.model.create(model_rng)
|
||||
|
||||
# Merge the partial params into the model.
|
||||
if partial_params is not None:
|
||||
graphdef, state = nnx.split(model)
|
||||
# This will produce an error if the partial params are not a subset of the state.
|
||||
state.replace_by_pure_dict(partial_params)
|
||||
model = nnx.merge(graphdef, state)
|
||||
|
||||
params = nnx.state(model)
|
||||
# Convert frozen params to bfloat16.
|
||||
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
|
||||
|
||||
return training_utils.TrainState(
|
||||
step=0,
|
||||
params=params,
|
||||
model_def=nnx.graphdef(model),
|
||||
tx=tx,
|
||||
opt_state=tx.init(params.filter(config.trainable_filter)),
|
||||
ema_decay=config.ema_decay,
|
||||
ema_params=None if config.ema_decay is None else params,
|
||||
)
|
||||
|
||||
train_state_shape = jax.eval_shape(init, init_rng)
|
||||
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
|
||||
|
||||
if resume:
|
||||
return train_state_shape, state_sharding
|
||||
|
||||
partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
# Initialize the train state and mix in the partial params.
|
||||
train_state = jax.jit(
|
||||
init,
|
||||
donate_argnums=(1,), # donate the partial params buffer.
|
||||
in_shardings=replicated_sharding,
|
||||
out_shardings=state_sharding,
|
||||
)(init_rng, partial_params)
|
||||
|
||||
return train_state, state_sharding
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def train_step(
|
||||
config: _config.TrainConfig,
|
||||
rng: at.KeyArrayLike,
|
||||
state: training_utils.TrainState,
|
||||
batch: tuple[_model.Observation, _model.Actions],
|
||||
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
|
||||
model = nnx.merge(state.model_def, state.params)
|
||||
model.train()
|
||||
|
||||
@at.typecheck
|
||||
def loss_fn(
|
||||
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
|
||||
):
|
||||
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
|
||||
return jnp.mean(chunked_loss)
|
||||
|
||||
train_rng = jax.random.fold_in(rng, state.step)
|
||||
observation, actions = batch
|
||||
|
||||
# Filter out frozen params.
|
||||
diff_state = nnx.DiffState(0, config.trainable_filter)
|
||||
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
|
||||
|
||||
params = state.params.filter(config.trainable_filter)
|
||||
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
|
||||
new_params = optax.apply_updates(params, updates)
|
||||
|
||||
# Update the model in place and return the new full state.
|
||||
nnx.update(model, new_params)
|
||||
new_params = nnx.state(model)
|
||||
|
||||
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
|
||||
if state.ema_decay is not None:
|
||||
new_state = dataclasses.replace(
|
||||
new_state,
|
||||
ema_params=jax.tree.map(
|
||||
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
|
||||
),
|
||||
)
|
||||
|
||||
# Filter out params that aren't kernels.
|
||||
kernel_params = nnx.state(
|
||||
model,
|
||||
nnx.All(
|
||||
nnx.Param,
|
||||
nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
|
||||
lambda _, x: x.value.ndim > 1,
|
||||
),
|
||||
)
|
||||
info = {
|
||||
"loss": loss,
|
||||
"grad_norm": optax.global_norm(grads),
|
||||
"param_norm": optax.global_norm(kernel_params),
|
||||
}
|
||||
return new_state, info
|
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
|
||||
init_logging()
|
||||
logging.info(f"Running on: {platform.node()}")
|
||||
|
||||
if config.batch_size % jax.device_count() != 0:
|
||||
raise ValueError(
|
||||
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
|
||||
)
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
|
||||
|
||||
rng = jax.random.key(config.seed)
|
||||
train_rng, init_rng = jax.random.split(rng)
|
||||
|
||||
mesh = sharding.make_mesh(config.fsdp_devices)
|
||||
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
|
||||
config.checkpoint_dir,
|
||||
keep_period=config.keep_period,
|
||||
overwrite=config.overwrite,
|
||||
resume=config.resume,
|
||||
)
|
||||
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
|
||||
|
||||
if config.online_compute_norm_stats:
|
||||
global_norm_stats = compute_norm_stats(config.name)
|
||||
else:
|
||||
global_norm_stats = None
|
||||
|
||||
data_loader = _data_loader.create_data_loader_multi(
|
||||
config,
|
||||
sharding=data_sharding,
|
||||
shuffle=True,
|
||||
global_norm_stats=global_norm_stats,
|
||||
)
|
||||
# @profile
|
||||
data_iter = iter(data_loader)
|
||||
batch = next(data_iter)
|
||||
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
|
||||
print(psutil.Process().memory_info().rss/1024**2)
|
||||
# set_trace()
|
||||
# Log images from first batch to sanity check.
|
||||
images_to_log = [
|
||||
wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
|
||||
for i in range(min(5, len(next(iter(batch[0].images.values())))))
|
||||
]
|
||||
wandb.log({"camera_views": images_to_log}, step=0)
|
||||
|
||||
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
|
||||
jax.block_until_ready(train_state)
|
||||
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
|
||||
|
||||
if resuming:
|
||||
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
|
||||
|
||||
ptrain_step = jax.jit(
|
||||
functools.partial(train_step, config),
|
||||
in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
|
||||
out_shardings=(train_state_sharding, replicated_sharding),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
start_step = int(train_state.step)
|
||||
pbar = tqdm.tqdm(
|
||||
range(start_step, config.num_train_steps),
|
||||
initial=start_step,
|
||||
total=config.num_train_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
infos = []
|
||||
for step in pbar:
|
||||
with sharding.set_mesh(mesh):
|
||||
train_state, info = ptrain_step(train_rng, train_state, batch)
|
||||
infos.append(info)
|
||||
if step % config.log_interval == 0:
|
||||
stacked_infos = common_utils.stack_forest(infos)
|
||||
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
|
||||
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
|
||||
pbar.write(f"Step {step}: {info_str}")
|
||||
wandb.log(reduced_info, step=step)
|
||||
infos = []
|
||||
batch = next(data_iter)
|
||||
|
||||
if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
|
||||
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
|
||||
|
||||
logging.info("Waiting for checkpoint manager to finish")
|
||||
checkpoint_manager.wait_until_finished()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(_config.cli())
|
||||
341
policy/openpi-InternData-A1/scripts/train_jax_multinode.py
Executable file
341
policy/openpi-InternData-A1/scripts/train_jax_multinode.py
Executable file
@@ -0,0 +1,341 @@
|
||||
"""
|
||||
Multi-host training entrypoint (JAX).
|
||||
|
||||
How to run multi-host (example: 2 nodes):
|
||||
# node0
|
||||
export JAX_COORDINATOR_ADDRESS=node0:12345
|
||||
export JAX_PROCESS_COUNT=2
|
||||
export JAX_PROCESS_INDEX=0
|
||||
uv run python scripts/train.py <config_name> --exp_name <exp>
|
||||
|
||||
# node1
|
||||
export JAX_COORDINATOR_ADDRESS=node0:12345
|
||||
export JAX_PROCESS_COUNT=2
|
||||
export JAX_PROCESS_INDEX=1
|
||||
uv run python scripts/train.py <config_name> --exp_name <exp>
|
||||
|
||||
Notes:
|
||||
- Initialize distributed BEFORE any device query.
|
||||
- Only process_index==0 performs side-effects (wandb, checkpoints, progress bar).
|
||||
- Total devices across hosts must be divisible by config.fsdp_devices.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import logging
|
||||
import platform
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import etils.epath as epath
|
||||
import flax.nnx as nnx
|
||||
from flax.training import common_utils
|
||||
import flax.traverse_util as traverse_util
|
||||
import jax
|
||||
import jax.experimental
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import tqdm_loggable.auto as tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.model as _model
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.shared.nnx_utils as nnx_utils
|
||||
import openpi.training.checkpoints as _checkpoints
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data_loader
|
||||
import openpi.training.optimizer as _optimizer
|
||||
import openpi.training.sharding as sharding
|
||||
import openpi.training.utils as training_utils
|
||||
import openpi.training.weight_loaders as _weight_loaders
|
||||
from pdb import set_trace
|
||||
|
||||
|
||||
def maybe_initialize_distributed() -> bool:
|
||||
coordinator = os.environ.get("JAX_COORDINATOR_ADDRESS")
|
||||
process_count = int(os.environ.get("JAX_PROCESS_COUNT", "1"))
|
||||
process_index = int(os.environ.get("JAX_PROCESS_INDEX", "0"))
|
||||
if process_count > 1 and coordinator:
|
||||
jax.distributed.initialize(
|
||||
coordinator_address=coordinator,
|
||||
num_processes=process_count,
|
||||
process_id=process_index,
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def init_logging():
|
||||
"""Custom logging format for better readability."""
|
||||
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
record.levelname = level_mapping.get(record.levelname, record.levelname)
|
||||
return super().format(record)
|
||||
|
||||
formatter = CustomFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
if not logger.handlers:
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
else:
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
|
||||
if not enabled:
|
||||
wandb.init(mode="disabled")
|
||||
return
|
||||
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
if not ckpt_dir.exists():
|
||||
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
|
||||
if resuming:
|
||||
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
|
||||
wandb.init(id=run_id, resume="must", project=config.project_name)
|
||||
else:
|
||||
wandb.init(
|
||||
name=config.exp_name,
|
||||
config=dataclasses.asdict(config),
|
||||
project=config.project_name,
|
||||
)
|
||||
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
|
||||
|
||||
if log_code:
|
||||
wandb.run.log_code(epath.Path(__file__).parent.parent)
|
||||
|
||||
|
||||
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
|
||||
"""Loads and validates the weights. Returns a loaded subset of the weights."""
|
||||
loaded_params = loader.load(params_shape)
|
||||
at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
|
||||
|
||||
# Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
|
||||
return traverse_util.unflatten_dict(
|
||||
{k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
|
||||
)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def init_train_state(
|
||||
config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
|
||||
) -> tuple[training_utils.TrainState, Any]:
|
||||
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
|
||||
|
||||
def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
|
||||
rng, model_rng = jax.random.split(rng)
|
||||
# initialize the model (and its parameters).
|
||||
model = config.model.create(model_rng)
|
||||
|
||||
# Merge the partial params into the model.
|
||||
if partial_params is not None:
|
||||
graphdef, state = nnx.split(model)
|
||||
# This will produce an error if the partial params are not a subset of the state.
|
||||
state.replace_by_pure_dict(partial_params)
|
||||
model = nnx.merge(graphdef, state)
|
||||
|
||||
params = nnx.state(model)
|
||||
# Convert frozen params to bfloat16.
|
||||
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
|
||||
|
||||
return training_utils.TrainState(
|
||||
step=0,
|
||||
params=params,
|
||||
model_def=nnx.graphdef(model),
|
||||
tx=tx,
|
||||
opt_state=tx.init(params.filter(config.trainable_filter)),
|
||||
ema_decay=config.ema_decay,
|
||||
ema_params=None if config.ema_decay is None else params,
|
||||
)
|
||||
|
||||
train_state_shape = jax.eval_shape(init, init_rng)
|
||||
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
|
||||
|
||||
if resume:
|
||||
return train_state_shape, state_sharding
|
||||
|
||||
partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
# Initialize the train state and mix in the partial params.
|
||||
train_state = jax.jit(
|
||||
init,
|
||||
donate_argnums=(1,), # donate the partial params buffer.
|
||||
in_shardings=replicated_sharding,
|
||||
out_shardings=state_sharding,
|
||||
)(init_rng, partial_params)
|
||||
|
||||
return train_state, state_sharding
|
||||
|
||||
|
||||
@at.typecheck
|
||||
def train_step(
|
||||
config: _config.TrainConfig,
|
||||
rng: at.KeyArrayLike,
|
||||
state: training_utils.TrainState,
|
||||
batch: tuple[_model.Observation, _model.Actions],
|
||||
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
|
||||
model = nnx.merge(state.model_def, state.params)
|
||||
model.train()
|
||||
|
||||
@at.typecheck
|
||||
def loss_fn(
|
||||
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
|
||||
):
|
||||
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
|
||||
return jnp.mean(chunked_loss)
|
||||
# set_trace()
|
||||
train_rng = jax.random.fold_in(rng, state.step)
|
||||
observation, actions = batch
|
||||
|
||||
# Filter out frozen params.
|
||||
diff_state = nnx.DiffState(0, config.trainable_filter)
|
||||
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
|
||||
|
||||
params = state.params.filter(config.trainable_filter)
|
||||
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
|
||||
new_params = optax.apply_updates(params, updates)
|
||||
|
||||
# Update the model in place and return the new full state.
|
||||
nnx.update(model, new_params)
|
||||
new_params = nnx.state(model)
|
||||
|
||||
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
|
||||
if state.ema_decay is not None:
|
||||
new_state = dataclasses.replace(
|
||||
new_state,
|
||||
ema_params=jax.tree.map(
|
||||
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
|
||||
),
|
||||
)
|
||||
|
||||
# Filter out params that aren't kernels.
|
||||
kernel_params = nnx.state(
|
||||
model,
|
||||
nnx.All(
|
||||
nnx.Param,
|
||||
nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
|
||||
lambda _, x: x.value.ndim > 1,
|
||||
),
|
||||
)
|
||||
info = {
|
||||
"loss": loss,
|
||||
"grad_norm": optax.global_norm(grads),
|
||||
"param_norm": optax.global_norm(kernel_params),
|
||||
}
|
||||
return new_state, info
|
||||
|
||||
|
||||
def main(config: _config.TrainConfig):
|
||||
init_logging()
|
||||
logging.info(f"Running on: {platform.node()}")
|
||||
|
||||
# Initialize multi-host distributed if environment variables are set
|
||||
distributed_initialized = maybe_initialize_distributed()
|
||||
is_main = jax.process_index() == 0
|
||||
|
||||
if config.batch_size % jax.device_count() != 0:
|
||||
raise ValueError(
|
||||
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
|
||||
)
|
||||
|
||||
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
|
||||
|
||||
rng = jax.random.key(config.seed)
|
||||
train_rng, init_rng = jax.random.split(rng)
|
||||
|
||||
mesh = sharding.make_mesh(config.fsdp_devices)
|
||||
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
|
||||
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
|
||||
config.checkpoint_dir,
|
||||
keep_period=config.keep_period,
|
||||
overwrite=config.overwrite,
|
||||
resume=config.resume,
|
||||
)
|
||||
init_wandb(config, resuming=resuming, enabled=(config.wandb_enabled and is_main))
|
||||
|
||||
data_loader = _data_loader.create_data_loader_multi(
|
||||
config,
|
||||
sharding=data_sharding,
|
||||
shuffle=True,
|
||||
)
|
||||
data_iter = iter(data_loader)
|
||||
batch = next(data_iter)
|
||||
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
|
||||
|
||||
# Note: Wandb image logging is disabled in multi-node setup to avoid potential hanging issues
|
||||
# caused by concurrent access to sharded arrays across processes.
|
||||
|
||||
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
|
||||
jax.block_until_ready(train_state)
|
||||
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
|
||||
|
||||
if resuming:
|
||||
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
|
||||
|
||||
ptrain_step = jax.jit(
|
||||
functools.partial(train_step, config),
|
||||
in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
|
||||
out_shardings=(train_state_sharding, replicated_sharding),
|
||||
donate_argnums=(1,),
|
||||
)
|
||||
|
||||
start_step = int(train_state.step)
|
||||
step_iter = range(start_step, config.num_train_steps)
|
||||
pbar = (
|
||||
tqdm.tqdm(
|
||||
step_iter,
|
||||
initial=start_step,
|
||||
total=config.num_train_steps,
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
if is_main
|
||||
else None
|
||||
)
|
||||
|
||||
infos = []
|
||||
for step in step_iter:
|
||||
with sharding.set_mesh(mesh):
|
||||
train_state, info = ptrain_step(train_rng, train_state, batch)
|
||||
if is_main and pbar is not None:
|
||||
pbar.update(1)
|
||||
infos.append(info)
|
||||
if step % config.log_interval == 0:
|
||||
# print("log!")
|
||||
stacked_infos = common_utils.stack_forest(infos)
|
||||
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
|
||||
if is_main:
|
||||
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
|
||||
if pbar is not None:
|
||||
pbar.write(f"Step {step}: {info_str}")
|
||||
else:
|
||||
logging.info(f"Step {step}: {info_str}")
|
||||
if config.wandb_enabled:
|
||||
wandb.log(reduced_info, step=step)
|
||||
infos = []
|
||||
batch = next(data_iter)
|
||||
if ((step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1):
|
||||
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
|
||||
|
||||
if is_main:
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
logging.info("Waiting for checkpoint manager to finish")
|
||||
checkpoint_manager.wait_until_finished()
|
||||
|
||||
if distributed_initialized:
|
||||
jax.distributed.shutdown()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(_config.cli())
|
||||
632
policy/openpi-InternData-A1/scripts/train_pytorch.py
Normal file
632
policy/openpi-InternData-A1/scripts/train_pytorch.py
Normal file
@@ -0,0 +1,632 @@
|
||||
"""
|
||||
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
|
||||
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
|
||||
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
|
||||
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
|
||||
|
||||
Usage
|
||||
Single GPU:
|
||||
python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
|
||||
Example:
|
||||
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
|
||||
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
|
||||
Multi-GPU (single node):
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
|
||||
Example:
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
|
||||
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
|
||||
Multi-Node Training:
|
||||
torchrun \
|
||||
--nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
|
||||
--master_addr=<master_ip> --master_port=<port> \
|
||||
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
|
||||
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import shutil
|
||||
import time
|
||||
|
||||
import jax
|
||||
import numpy as np
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.parallel
|
||||
import tqdm
|
||||
import wandb
|
||||
|
||||
import openpi.models.pi0_config
|
||||
import openpi.models_pytorch.pi0_pytorch
|
||||
import openpi.shared.normalize as _normalize
|
||||
import openpi.training.config as _config
|
||||
import openpi.training.data_loader as _data
|
||||
|
||||
|
||||
def init_logging():
|
||||
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
|
||||
|
||||
class CustomFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
record.levelname = level_mapping.get(record.levelname, record.levelname)
|
||||
return super().format(record)
|
||||
|
||||
formatter = CustomFormatter(
|
||||
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger()
|
||||
logger.setLevel(logging.INFO)
|
||||
if not logger.handlers:
|
||||
ch = logging.StreamHandler()
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
else:
|
||||
logger.handlers[0].setFormatter(formatter)
|
||||
|
||||
|
||||
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
|
||||
"""Initialize wandb logging."""
|
||||
if not enabled:
|
||||
wandb.init(mode="disabled")
|
||||
return
|
||||
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
if not ckpt_dir.exists():
|
||||
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
|
||||
|
||||
if resuming:
|
||||
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
|
||||
wandb.init(id=run_id, resume="must", project=config.project_name)
|
||||
else:
|
||||
wandb.init(
|
||||
name=config.exp_name,
|
||||
config=dataclasses.asdict(config),
|
||||
project=config.project_name,
|
||||
)
|
||||
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
|
||||
|
||||
|
||||
def setup_ddp():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
||||
use_ddp = world_size > 1
|
||||
if use_ddp and not torch.distributed.is_initialized():
|
||||
backend = "nccl" if torch.cuda.is_available() else "gloo"
|
||||
torch.distributed.init_process_group(backend=backend, init_method="env://")
|
||||
|
||||
# Set up debugging environment variables for DDP issues
|
||||
if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
|
||||
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", "0")))
|
||||
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(device)
|
||||
return use_ddp, local_rank, device
|
||||
|
||||
|
||||
def cleanup_ddp():
|
||||
if torch.distributed.is_initialized():
|
||||
torch.distributed.barrier()
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def set_seed(seed: int, local_rank: int):
|
||||
torch.manual_seed(seed + local_rank)
|
||||
np.random.seed(seed + local_rank)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed + local_rank)
|
||||
|
||||
|
||||
def build_datasets(config: _config.TrainConfig):
|
||||
# Use the unified data loader with PyTorch framework
|
||||
data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
|
||||
return data_loader, data_loader.data_config()
|
||||
|
||||
|
||||
def get_model_state_dict(model):
|
||||
"""Get state dict from model, handling DDP wrapper."""
|
||||
return (
|
||||
model.module.state_dict()
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
|
||||
else model.state_dict()
|
||||
)
|
||||
|
||||
|
||||
def get_model_parameters(model):
|
||||
"""Get parameters from model, handling DDP wrapper."""
|
||||
return (
|
||||
model.module.parameters()
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
|
||||
else model.parameters()
|
||||
)
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
|
||||
"""Save a checkpoint with model state, optimizer state, and metadata."""
|
||||
if not is_main:
|
||||
return
|
||||
|
||||
# Only save if it's time to save or if it's the final step
|
||||
if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
|
||||
# Create temporary directory for atomic checkpoint saving
|
||||
final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
|
||||
tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
|
||||
|
||||
# Remove any existing temp directory and create new one
|
||||
if tmp_ckpt_dir.exists():
|
||||
shutil.rmtree(tmp_ckpt_dir)
|
||||
tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save model state using safetensors (handle shared tensors)
|
||||
model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
|
||||
safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
|
||||
|
||||
# Save optimizer state using PyTorch format
|
||||
torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
|
||||
|
||||
# Save training metadata (avoid saving full config to prevent JAX/Flax compatibility issues)
|
||||
metadata = {
|
||||
"global_step": global_step,
|
||||
"config": dataclasses.asdict(config),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
|
||||
|
||||
# save norm stats
|
||||
norm_stats = data_config.norm_stats
|
||||
if norm_stats is not None and data_config.asset_id is not None:
|
||||
_normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
|
||||
|
||||
# Atomically move temp directory to final location
|
||||
if final_ckpt_dir.exists():
|
||||
shutil.rmtree(final_ckpt_dir)
|
||||
tmp_ckpt_dir.rename(final_ckpt_dir)
|
||||
|
||||
logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
|
||||
|
||||
# Log checkpoint to wandb
|
||||
if config.wandb_enabled:
|
||||
wandb.log({"checkpoint_step": global_step}, step=global_step)
|
||||
|
||||
|
||||
def load_checkpoint(model, optimizer, checkpoint_dir, device):
|
||||
"""Load the latest checkpoint and return the global step."""
|
||||
checkpoint_steps = [
|
||||
int(d.name)
|
||||
for d in checkpoint_dir.iterdir()
|
||||
if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
|
||||
]
|
||||
|
||||
if not checkpoint_steps:
|
||||
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
|
||||
|
||||
latest_step = max(checkpoint_steps)
|
||||
ckpt_dir = checkpoint_dir / f"{latest_step}"
|
||||
|
||||
# Clear memory before loading checkpoints
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
log_memory_usage(device, latest_step, "before_loading_checkpoint")
|
||||
|
||||
try:
|
||||
# Load model state with error handling
|
||||
logging.info("Loading model state...")
|
||||
safetensors_path = ckpt_dir / "model.safetensors"
|
||||
|
||||
if safetensors_path.exists():
|
||||
model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
|
||||
safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
|
||||
logging.info("Loaded model state from safetensors format")
|
||||
else:
|
||||
raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
log_memory_usage(device, latest_step, "after_loading_model")
|
||||
|
||||
# Load optimizer state with error handling
|
||||
logging.info("Loading optimizer state...")
|
||||
optimizer_path = ckpt_dir / "optimizer.pt"
|
||||
|
||||
if optimizer_path.exists():
|
||||
optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
|
||||
logging.info("Loaded optimizer state from pt format")
|
||||
else:
|
||||
raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
|
||||
|
||||
optimizer.load_state_dict(optimizer_state_dict)
|
||||
del optimizer_state_dict
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
log_memory_usage(device, latest_step, "after_loading_optimizer")
|
||||
|
||||
# Load metadata
|
||||
logging.info("Loading metadata...")
|
||||
metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
|
||||
global_step = metadata.get("global_step", latest_step)
|
||||
del metadata
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
log_memory_usage(device, latest_step, "after_loading_metadata")
|
||||
|
||||
logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
|
||||
return global_step
|
||||
|
||||
except RuntimeError as e:
|
||||
if "out of memory" in str(e):
|
||||
# Clear memory and provide detailed error message
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
logging.error(f"Out of memory error while loading checkpoint: {e!s}")
|
||||
log_memory_usage(device, latest_step, "after_oom_error")
|
||||
raise RuntimeError(
|
||||
"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
|
||||
) from e
|
||||
raise
|
||||
|
||||
|
||||
def get_latest_checkpoint_step(checkpoint_dir):
|
||||
"""Get the latest checkpoint step number from a checkpoint directory."""
|
||||
checkpoint_steps = [
|
||||
int(d.name)
|
||||
for d in checkpoint_dir.iterdir()
|
||||
if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
|
||||
]
|
||||
return max(checkpoint_steps) if checkpoint_steps else None
|
||||
|
||||
|
||||
def log_memory_usage(device, step, phase="unknown"):
|
||||
"""Log detailed memory usage information."""
|
||||
if not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
memory_allocated = torch.cuda.memory_allocated(device) / 1e9
|
||||
memory_reserved = torch.cuda.memory_reserved(device) / 1e9
|
||||
memory_free = torch.cuda.memory_reserved(device) - torch.cuda.memory_allocated(device)
|
||||
memory_free = memory_free / 1e9
|
||||
|
||||
# Get more detailed memory info
|
||||
memory_stats = torch.cuda.memory_stats(device)
|
||||
max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
|
||||
max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
|
||||
|
||||
# Get DDP info if available
|
||||
ddp_info = ""
|
||||
if dist.is_initialized():
|
||||
ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
|
||||
|
||||
logging.info(
|
||||
f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
|
||||
)
|
||||
|
||||
|
||||
def train_loop(config: _config.TrainConfig):
|
||||
use_ddp, local_rank, device = setup_ddp()
|
||||
is_main = (not use_ddp) or (dist.get_rank() == 0)
|
||||
set_seed(config.seed, local_rank)
|
||||
|
||||
# Initialize checkpoint directory and wandb
|
||||
resuming = False
|
||||
if config.resume:
|
||||
# Find checkpoint directory based on experiment name
|
||||
exp_checkpoint_dir = config.checkpoint_dir
|
||||
if exp_checkpoint_dir.exists():
|
||||
# Use validation to find the latest working checkpoint
|
||||
latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
|
||||
if latest_step is not None:
|
||||
resuming = True
|
||||
logging.info(
|
||||
f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
|
||||
)
|
||||
else:
|
||||
raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
|
||||
else:
|
||||
raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
|
||||
elif config.overwrite and config.checkpoint_dir.exists():
|
||||
shutil.rmtree(config.checkpoint_dir)
|
||||
logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
|
||||
|
||||
# Create checkpoint directory with experiment name
|
||||
if not resuming:
|
||||
# For new runs, create experiment-specific checkpoint directory
|
||||
exp_checkpoint_dir = config.checkpoint_dir
|
||||
exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
|
||||
else:
|
||||
# For resume, checkpoint_dir is already set to the experiment directory
|
||||
logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
|
||||
|
||||
# Initialize wandb (only on main process)
|
||||
if is_main:
|
||||
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
|
||||
|
||||
# Build data loader using the unified data loader
|
||||
# Calculate effective batch size per GPU for DDP
|
||||
# For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
|
||||
world_size = torch.distributed.get_world_size() if use_ddp else 1
|
||||
effective_batch_size = config.batch_size // world_size
|
||||
logging.info(
|
||||
f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
|
||||
)
|
||||
|
||||
# Pass the original batch size to data loader - it will handle DDP splitting internally
|
||||
loader, data_config = build_datasets(config)
|
||||
|
||||
# Log sample images to wandb on first batch
|
||||
if is_main and config.wandb_enabled and not resuming:
|
||||
# Create a separate data loader for sample batch to avoid consuming the main loader
|
||||
sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
|
||||
sample_batch = next(iter(sample_data_loader))
|
||||
# Convert observation and actions to torch tensors
|
||||
observation, actions = sample_batch
|
||||
sample_batch = observation.to_dict()
|
||||
sample_batch["actions"] = actions
|
||||
|
||||
# Create sample images for wandb
|
||||
images_to_log = []
|
||||
# Get batch size from the first image tensor
|
||||
batch_size = next(iter(sample_batch["image"].values())).shape[0]
|
||||
for i in range(min(5, batch_size)):
|
||||
# Concatenate all camera views horizontally for this batch item
|
||||
# Convert from NCHW to NHWC format for wandb
|
||||
img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], axis=1)
|
||||
img_concatenated = img_concatenated.cpu().numpy()
|
||||
images_to_log.append(wandb.Image(img_concatenated))
|
||||
|
||||
wandb.log({"camera_views": images_to_log}, step=0)
|
||||
|
||||
# Clear sample batch from memory aggressively
|
||||
del sample_batch, observation, actions, images_to_log, img_concatenated
|
||||
del sample_data_loader # Also delete the sample data loader
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logging.info("Cleared sample batch and data loader from memory")
|
||||
|
||||
# Build model
|
||||
if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
|
||||
# Convert dataclass to Pi0Config if needed
|
||||
model_cfg = openpi.models.pi0_config.Pi0Config(
|
||||
dtype=config.pytorch_training_precision,
|
||||
action_dim=config.model.action_dim,
|
||||
action_horizon=config.model.action_horizon,
|
||||
max_token_len=config.model.max_token_len,
|
||||
paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
|
||||
action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
|
||||
pi05=getattr(config.model, "pi05", False),
|
||||
)
|
||||
else:
|
||||
model_cfg = config.model
|
||||
# Update dtype to match pytorch_training_precision
|
||||
object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
|
||||
|
||||
model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
|
||||
|
||||
if hasattr(model, "gradient_checkpointing_enable"):
|
||||
enable_gradient_checkpointing = True
|
||||
model.gradient_checkpointing_enable()
|
||||
logging.info("Enabled gradient checkpointing for memory optimization")
|
||||
else:
|
||||
enable_gradient_checkpointing = False
|
||||
logging.info("Gradient checkpointing is not supported for this model")
|
||||
|
||||
# Log initial memory usage after model creation
|
||||
if is_main and torch.cuda.is_available():
|
||||
log_memory_usage(device, 0, "after_model_creation")
|
||||
|
||||
# Enable memory optimizations for large-scale training
|
||||
if world_size >= 8:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
# Set memory allocation configuration
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
|
||||
logging.info("Enabled memory optimizations for 8+ GPU training")
|
||||
|
||||
if use_ddp:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[device.index] if device.type == "cuda" else None,
|
||||
find_unused_parameters=False, # Disable for memory efficiency
|
||||
gradient_as_bucket_view=True, # Enable for memory efficiency
|
||||
static_graph=world_size >= 8, # Enable for 8+ GPUs
|
||||
)
|
||||
|
||||
# Load weights from weight_loader if specified (for fine-tuning)
|
||||
# if config.pytorch_weight_path is not None:
|
||||
# logging.info(f"Loading weights from: {config.pytorch_weight_path}")
|
||||
|
||||
# model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
|
||||
# safetensors.torch.load_model(
|
||||
# (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
|
||||
# )
|
||||
# logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
|
||||
|
||||
# Optimizer + learning rate schedule from config
|
||||
warmup_steps = config.lr_schedule.warmup_steps
|
||||
peak_lr = config.lr_schedule.peak_lr
|
||||
decay_steps = config.lr_schedule.decay_steps
|
||||
end_lr = config.lr_schedule.decay_lr
|
||||
|
||||
# Create optimizer with config parameters
|
||||
optim = torch.optim.AdamW(
|
||||
model.parameters(),
|
||||
lr=peak_lr,
|
||||
betas=(config.optimizer.b1, config.optimizer.b2),
|
||||
eps=config.optimizer.eps,
|
||||
weight_decay=config.optimizer.weight_decay,
|
||||
)
|
||||
|
||||
# Load checkpoint if resuming
|
||||
global_step = 0
|
||||
if resuming:
|
||||
global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
|
||||
logging.info(f"Resumed training from step {global_step}")
|
||||
|
||||
def lr_schedule(step: int):
|
||||
if step < warmup_steps:
|
||||
# Match JAX behavior: start from peak_lr / (warmup_steps + 1)
|
||||
init_lr = peak_lr / (warmup_steps + 1)
|
||||
return init_lr + (peak_lr - init_lr) * step / warmup_steps
|
||||
# cosine decay
|
||||
progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
|
||||
cos = 0.5 * (1 + np.cos(np.pi * progress))
|
||||
return end_lr + (peak_lr - end_lr) * cos
|
||||
|
||||
model.train()
|
||||
start_time = time.time()
|
||||
infos = [] # Collect stats over log interval
|
||||
if is_main:
|
||||
logging.info(
|
||||
f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
|
||||
)
|
||||
logging.info(
|
||||
f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
|
||||
)
|
||||
logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
|
||||
logging.info(
|
||||
f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
|
||||
)
|
||||
logging.info(
|
||||
f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
|
||||
)
|
||||
logging.info("EMA is not supported for PyTorch training")
|
||||
logging.info(f"Training precision: {model_cfg.dtype}")
|
||||
|
||||
# Training loop - iterate until we reach num_train_steps
|
||||
pbar = (
|
||||
tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
|
||||
if is_main
|
||||
else None
|
||||
)
|
||||
|
||||
while global_step < config.num_train_steps:
|
||||
# Set epoch for distributed training
|
||||
if use_ddp and hasattr(loader, "set_epoch"):
|
||||
loader.set_epoch(global_step // len(loader))
|
||||
|
||||
for observation, actions in loader:
|
||||
# Check if we've reached the target number of steps
|
||||
if global_step >= config.num_train_steps:
|
||||
break
|
||||
|
||||
# The unified data loader returns (observation, actions) tuple
|
||||
observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
|
||||
actions = actions.to(torch.float32) # noqa: PLW2901
|
||||
actions = actions.to(device) # noqa: PLW2901
|
||||
|
||||
# Update LR
|
||||
for pg in optim.param_groups:
|
||||
pg["lr"] = lr_schedule(global_step)
|
||||
|
||||
# Forward pass
|
||||
losses = model(observation, actions)
|
||||
# Ensure losses is a tensor and handle different return types
|
||||
if isinstance(losses, list | tuple):
|
||||
losses = torch.stack(losses)
|
||||
elif not isinstance(losses, torch.Tensor):
|
||||
losses = torch.tensor(losses, device=device, dtype=torch.float32)
|
||||
|
||||
loss = losses.mean()
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Log memory usage after backward pass
|
||||
if global_step < 5 and is_main and torch.cuda.is_available():
|
||||
log_memory_usage(device, global_step, "after_backward")
|
||||
|
||||
# Gradient clipping
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
|
||||
|
||||
# Optimizer step
|
||||
optim.step()
|
||||
optim.zero_grad(set_to_none=True)
|
||||
|
||||
# Clear gradients more aggressively
|
||||
for param in model.parameters():
|
||||
if param.grad is not None:
|
||||
param.grad.detach_()
|
||||
param.grad = None
|
||||
|
||||
# Collect stats
|
||||
if is_main:
|
||||
infos.append(
|
||||
{
|
||||
"loss": loss.item(),
|
||||
"learning_rate": optim.param_groups[0]["lr"],
|
||||
"grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
|
||||
}
|
||||
)
|
||||
|
||||
if is_main and (global_step % config.log_interval == 0):
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Average stats over log interval
|
||||
avg_loss = sum(info["loss"] for info in infos) / len(infos)
|
||||
avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
|
||||
|
||||
avg_grad_norm = None
|
||||
if any("grad_norm" in info for info in infos):
|
||||
vals = [
|
||||
info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
|
||||
]
|
||||
if len(vals) > 0:
|
||||
avg_grad_norm = sum(vals) / len(vals)
|
||||
logging.info(
|
||||
f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
|
||||
if avg_grad_norm is not None
|
||||
else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
|
||||
)
|
||||
|
||||
# Log to wandb
|
||||
if config.wandb_enabled and len(infos) > 0:
|
||||
log_payload = {
|
||||
"loss": avg_loss,
|
||||
"learning_rate": avg_lr,
|
||||
"step": global_step,
|
||||
"time_per_step": elapsed / config.log_interval,
|
||||
}
|
||||
if avg_grad_norm is not None:
|
||||
log_payload["grad_norm"] = avg_grad_norm
|
||||
wandb.log(log_payload, step=global_step)
|
||||
|
||||
start_time = time.time()
|
||||
infos = [] # Reset stats collection
|
||||
|
||||
global_step += 1
|
||||
# Save checkpoint using the new mechanism
|
||||
save_checkpoint(model, optim, global_step, config, is_main, data_config)
|
||||
|
||||
# Update progress bar
|
||||
if pbar is not None:
|
||||
pbar.update(1)
|
||||
pbar.set_postfix(
|
||||
{"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
|
||||
)
|
||||
|
||||
# Close progress bar
|
||||
if pbar is not None:
|
||||
pbar.close()
|
||||
|
||||
# Finish wandb run
|
||||
if is_main and config.wandb_enabled:
|
||||
wandb.finish()
|
||||
|
||||
cleanup_ddp()
|
||||
|
||||
|
||||
def main():
|
||||
init_logging()
|
||||
config = _config.cli()
|
||||
train_loop(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
30
policy/openpi-InternData-A1/scripts/train_test.py
Normal file
30
policy/openpi-InternData-A1/scripts/train_test.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dataclasses
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
from openpi.training import config as _config
|
||||
|
||||
from . import train
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_name", ["debug"])
|
||||
def test_train(tmp_path: pathlib.Path, config_name: str):
|
||||
config = dataclasses.replace(
|
||||
_config._CONFIGS_DICT[config_name], # noqa: SLF001
|
||||
batch_size=2,
|
||||
checkpoint_base_dir=str(tmp_path / "checkpoint"),
|
||||
exp_name="test",
|
||||
overwrite=False,
|
||||
resume=False,
|
||||
num_train_steps=2,
|
||||
log_interval=1,
|
||||
)
|
||||
train.main(config)
|
||||
|
||||
# test resuming
|
||||
config = dataclasses.replace(config, resume=True, num_train_steps=4)
|
||||
train.main(config)
|
||||
209
policy/openpi-InternData-A1/scripts/training_scripts/multi_node.sh
Executable file
209
policy/openpi-InternData-A1/scripts/training_scripts/multi_node.sh
Executable file
@@ -0,0 +1,209 @@
|
||||
#!/usr/bin/env bash
|
||||
set -ex
|
||||
|
||||
cd YOUR_PATH/openpi
|
||||
|
||||
export USE_TF=0
|
||||
export USE_TORCH=0
|
||||
export USE_JAX=1
|
||||
export IMAGEIO_FFMPEG_EXE=ffmpeg
|
||||
# JAX GPU memory fraction
|
||||
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}"
|
||||
|
||||
# ============================================================================
|
||||
# NCCL Configuration
|
||||
# ============================================================================
|
||||
export NCCL_ASYNC_ERROR_HANDLING=1
|
||||
export NCCL_TIMEOUT=3600
|
||||
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
|
||||
|
||||
# ============================================================================
|
||||
# Platform-Injected Configuration
|
||||
# ============================================================================
|
||||
# The platform automatically injects these when DISTRIBUTED_JOB=true:
|
||||
# - NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME
|
||||
# - NODE_RANK, NODE_COUNT, MASTER_ADDR, PROC_PER_NODE
|
||||
# - CUDA_VISIBLE_DEVICES
|
||||
# We trust and use these platform configurations directly.
|
||||
# ============================================================================
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Platform Configuration"
|
||||
echo "=========================================="
|
||||
echo "NODE_RANK: ${NODE_RANK:-<not set>}"
|
||||
echo "NODE_COUNT: ${NODE_COUNT:-<not set>}"
|
||||
echo "MASTER_ADDR: ${MASTER_ADDR:-<not set>}"
|
||||
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
|
||||
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
|
||||
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# NCCL Transport Configuration
|
||||
# ============================================================================
|
||||
# Use platform-injected configuration if available, otherwise fallback
|
||||
# ============================================================================
|
||||
|
||||
if [ -n "${NCCL_IB_HCA:-}" ]; then
|
||||
# Platform has configured InfiniBand
|
||||
echo "[NCCL] ✓ Using platform-injected InfiniBand configuration"
|
||||
|
||||
# Only set NCCL_NET if not already set
|
||||
if [ -z "${NCCL_NET:-}" ]; then
|
||||
export NCCL_NET="IB"
|
||||
fi
|
||||
|
||||
# Set IB timeout if not already set
|
||||
if [ -z "${NCCL_IB_TIMEOUT:-}" ]; then
|
||||
export NCCL_IB_TIMEOUT=23
|
||||
fi
|
||||
|
||||
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
|
||||
echo "[NCCL] NCCL_IB_HCA: ${NCCL_IB_HCA}"
|
||||
echo "[NCCL] NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX}"
|
||||
echo "[NCCL] NCCL_IB_TIMEOUT: ${NCCL_IB_TIMEOUT}"
|
||||
|
||||
elif [ -n "${NCCL_SOCKET_IFNAME:-}" ]; then
|
||||
# Platform has configured Socket
|
||||
echo "[NCCL] ✓ Using platform-injected Socket configuration"
|
||||
|
||||
if [ -z "${NCCL_NET:-}" ]; then
|
||||
export NCCL_NET="Socket"
|
||||
fi
|
||||
|
||||
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
|
||||
echo "[NCCL] NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME}"
|
||||
|
||||
else
|
||||
# No platform injection - use OPENPI_NCCL_NET preference
|
||||
echo "[NCCL] ⚠️ No platform-injected NCCL configuration"
|
||||
|
||||
if [ "${OPENPI_NCCL_NET:-IB}" = "IB" ]; then
|
||||
echo "[NCCL] ✗ InfiniBand requested but not configured by platform"
|
||||
echo "[NCCL] ✗ Falling back to Socket transport"
|
||||
export NCCL_NET="Socket"
|
||||
export NCCL_IB_DISABLE=1
|
||||
else
|
||||
export NCCL_NET="Socket"
|
||||
export NCCL_IB_DISABLE=1
|
||||
echo "[NCCL] Using Socket transport"
|
||||
fi
|
||||
fi
|
||||
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# JAX Distributed Configuration
|
||||
# ============================================================================
|
||||
# Map platform variables to JAX variables
|
||||
# ============================================================================
|
||||
|
||||
echo "=========================================="
|
||||
echo "JAX Distributed Configuration"
|
||||
echo "=========================================="
|
||||
|
||||
JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-12345}"
|
||||
|
||||
# Set JAX coordinator address
|
||||
if [ -z "${JAX_COORDINATOR_ADDRESS:-}" ] && [ -n "${MASTER_ADDR:-}" ]; then
|
||||
export JAX_COORDINATOR_ADDRESS="${MASTER_ADDR}:${JAX_COORDINATOR_PORT}"
|
||||
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS} (from MASTER_ADDR)"
|
||||
elif [ -n "${JAX_COORDINATOR_ADDRESS:-}" ]; then
|
||||
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS}"
|
||||
else
|
||||
echo "[JAX] ✗ WARNING: No coordinator address set!"
|
||||
fi
|
||||
|
||||
# Set JAX process count
|
||||
if [ -z "${JAX_PROCESS_COUNT:-}" ] && [ -n "${NODE_COUNT:-}" ]; then
|
||||
export JAX_PROCESS_COUNT="${NODE_COUNT}"
|
||||
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT} (from NODE_COUNT)"
|
||||
elif [ -n "${JAX_PROCESS_COUNT:-}" ]; then
|
||||
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT}"
|
||||
fi
|
||||
|
||||
# Set JAX process index
|
||||
if [ -z "${JAX_PROCESS_INDEX:-}" ] && [ -n "${NODE_RANK:-}" ]; then
|
||||
export JAX_PROCESS_INDEX="${NODE_RANK}"
|
||||
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX} (from NODE_RANK)"
|
||||
elif [ -n "${JAX_PROCESS_INDEX:-}" ]; then
|
||||
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX}"
|
||||
fi
|
||||
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# Python Environment
|
||||
# ============================================================================
|
||||
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH}
|
||||
conda activate pi0
|
||||
|
||||
# ============================================================================
|
||||
# Configuration Summary
|
||||
# ============================================================================
|
||||
|
||||
echo "=========================================="
|
||||
echo "Configuration Summary"
|
||||
echo "=========================================="
|
||||
echo "NCCL_NET: ${NCCL_NET:-<not set>}"
|
||||
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
|
||||
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
|
||||
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
|
||||
echo "JAX_COORDINATOR: ${JAX_COORDINATOR_ADDRESS:-<not set>}"
|
||||
echo "JAX_PROCESS_COUNT: ${JAX_PROCESS_COUNT:-<not set>}"
|
||||
echo "JAX_PROCESS_INDEX: ${JAX_PROCESS_INDEX:-<not set>}"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
# ============================================================================
|
||||
# Display Host Information
|
||||
# ============================================================================
|
||||
|
||||
python - <<'EOF'
|
||||
import socket
|
||||
import os
|
||||
import jax
|
||||
hostname = socket.gethostname()
|
||||
devices = jax.local_devices()
|
||||
device_count = len(devices)
|
||||
device_ids = [d.id for d in devices]
|
||||
print(f"[JAX] host={hostname}, devices={device_count}xgpu, ids={device_ids}")
|
||||
print(f"[JAX] JAX_COORDINATOR_ADDRESS={os.environ.get('JAX_COORDINATOR_ADDRESS', '<not set>')}")
|
||||
print(f"[JAX] JAX_PROCESS_COUNT={os.environ.get('JAX_PROCESS_COUNT', '<not set>')}")
|
||||
print(f"[JAX] JAX_PROCESS_INDEX={os.environ.get('JAX_PROCESS_INDEX', '<not set>')}")
|
||||
EOF
|
||||
|
||||
# ============================================================================
|
||||
# Launch Training
|
||||
# ============================================================================
|
||||
|
||||
# Determine experiment name based on transport
|
||||
if [ "${OPENPI_DEBUG_SINGLE_GPU:-0}" = "1" ]; then
|
||||
EXP_NAME="${EXP_NAME:-dev_jax_single_gpu}"
|
||||
echo "[DEBUG] Running in single-GPU mode"
|
||||
else
|
||||
EXP_NAME="${EXP_NAME:-dev_jax_multinode_ib}"
|
||||
fi
|
||||
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo "Starting Training"
|
||||
echo "=========================================="
|
||||
echo "Experiment: $EXP_NAME"
|
||||
echo "=========================================="
|
||||
echo ""
|
||||
|
||||
ulimit -n 1000000
|
||||
|
||||
python scripts/train_jax_multinode.py \
|
||||
pretrain-interndata-a1 \
|
||||
--exp-name=pretrain-interndata-a1 \
|
||||
--num_workers=12 \
|
||||
--fsdp_devices=8 \
|
||||
--batch_size=512 \
|
||||
--num_train_steps=2000000 \
|
||||
--save_interval=5000
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
set -ex
|
||||
|
||||
export IMAGEIO_FFMPEG_EXE=ffmpeg
|
||||
export OMP_NUM_THREADS=128
|
||||
|
||||
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH}
|
||||
conda activate pi0
|
||||
|
||||
cd YOUR_PATH/openpi
|
||||
ulimit -n 1000000
|
||||
config_name=$1
|
||||
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python scripts/train.py ${config_name} \
|
||||
--exp-name=${config_name}
|
||||
0
policy/openpi-InternData-A1/src/openpi/__init__.py
Normal file
0
policy/openpi-InternData-A1/src/openpi/__init__.py
Normal file
17
policy/openpi-InternData-A1/src/openpi/conftest.py
Normal file
17
policy/openpi-InternData-A1/src/openpi/conftest.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
|
||||
import pynvml
|
||||
import pytest
|
||||
|
||||
|
||||
def set_jax_cpu_backend_if_no_gpu() -> None:
|
||||
try:
|
||||
pynvml.nvmlInit()
|
||||
pynvml.nvmlShutdown()
|
||||
except pynvml.NVMLError:
|
||||
# No GPU found.
|
||||
os.environ["JAX_PLATFORMS"] = "cpu"
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
set_jax_cpu_backend_if_no_gpu()
|
||||
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
459
policy/openpi-InternData-A1/src/openpi/models/gemma.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Gemma adaptation for Pi, taken from big_vision.
|
||||
|
||||
We follow this einsum axis naming convention:
|
||||
B: batch
|
||||
T: query length
|
||||
S: k/v length
|
||||
N: num query heads
|
||||
K: num k/v heads
|
||||
G: num query heads per k/v head
|
||||
H: head dim
|
||||
D: d_model ("features")
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
import openpi.training.sharding as sharding
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257_152
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Config:
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
|
||||
|
||||
|
||||
Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant: Variant) -> Config:
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "dummy":
|
||||
return Config(
|
||||
width=64,
|
||||
depth=4,
|
||||
mlp_dim=128,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=16,
|
||||
)
|
||||
if variant == "gemma_300m":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
)
|
||||
if variant == "gemma_2b_lora":
|
||||
return Config(
|
||||
width=2048,
|
||||
depth=18,
|
||||
mlp_dim=16_384,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
|
||||
)
|
||||
if variant == "gemma_300m_lora":
|
||||
# 311M params
|
||||
return Config(
|
||||
width=1024,
|
||||
depth=18,
|
||||
mlp_dim=4096,
|
||||
num_heads=8,
|
||||
num_kv_heads=1,
|
||||
head_dim=256,
|
||||
lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class RMSNorm(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x, cond):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
|
||||
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
|
||||
if cond is None:
|
||||
# regular RMSNorm
|
||||
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
|
||||
normed_inputs = normed_inputs * (
|
||||
1 + scale
|
||||
) # scale by learned parameter in float32 (matches Flax implementation)
|
||||
return normed_inputs.astype(dtype), None # return in original dtype
|
||||
|
||||
# adaptive RMSNorm
|
||||
modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
|
||||
scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
|
||||
normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32
|
||||
return normed_inputs.astype(dtype), gate
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Embedder(nn.Module):
|
||||
"""Embedder module."""
|
||||
|
||||
vocab_size: int
|
||||
embed_dim: int
|
||||
|
||||
def setup(self):
|
||||
self.input_embedding_table = self.param(
|
||||
"input_embedding",
|
||||
nn.initializers.normal(),
|
||||
(self.vocab_size, self.embed_dim),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.input_embedding_table[(x,)]
|
||||
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Attention(nn.Module):
|
||||
"""Attention module."""
|
||||
|
||||
configs: Sequence[Config]
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, positions, attn_mask, kv_cache):
|
||||
# all experts must share the same head dim, num heads, and num kv heads for self-attention to work
|
||||
assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
|
||||
assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
|
||||
assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
|
||||
|
||||
dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
|
||||
|
||||
qkvs = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is None:
|
||||
continue
|
||||
if config.num_kv_heads == config.num_heads:
|
||||
qkv_einsum = lora.Einsum(
|
||||
shape=(3, config.num_heads, config.width, config.head_dim),
|
||||
name=_name("qkv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
|
||||
else:
|
||||
q_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.width, config.head_dim),
|
||||
name=_name("q_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
q = q_einsum("BTD,NDH->BTNH", x)
|
||||
kv_einsum = lora.Einsum(
|
||||
shape=(2, config.num_kv_heads, config.width, config.head_dim),
|
||||
name=_name("kv_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
k, v = kv_einsum("BSD,2KDH->2BSKH", x)
|
||||
qkvs.append((q, k, v))
|
||||
|
||||
q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
|
||||
|
||||
q = _apply_rope(q, positions=positions)
|
||||
q *= self.configs[0].head_dim ** -0.5
|
||||
|
||||
k = _apply_rope(k, positions=positions)
|
||||
|
||||
# should still be half-precision here (if input was half-precision)
|
||||
assert q.dtype == k.dtype == v.dtype == dtype
|
||||
|
||||
if kv_cache is not None:
|
||||
cache_k, cache_v = kv_cache
|
||||
k = jnp.concatenate([cache_k, k], axis=1)
|
||||
v = jnp.concatenate([cache_v, v], axis=1)
|
||||
|
||||
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
|
||||
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
|
||||
|
||||
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
|
||||
raise ValueError(
|
||||
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
|
||||
)
|
||||
|
||||
# big_neg = jnp.finfo(logits.dtype).min
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
|
||||
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
|
||||
|
||||
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
|
||||
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
|
||||
|
||||
out = []
|
||||
start = 0
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
end = start + x.shape[1]
|
||||
out_einsum = lora.Einsum(
|
||||
shape=(config.num_heads, config.head_dim, config.width),
|
||||
name=_name("attn_vec_einsum", i),
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
|
||||
lora_config=config.lora_configs.get("attn"),
|
||||
)
|
||||
out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
|
||||
start = end
|
||||
else:
|
||||
out.append(None)
|
||||
|
||||
return out, (k, v)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
(2, self.features, self.hidden_dim),
|
||||
).astype(dtype)
|
||||
ff_gate = jnp.dot(x, w_gating[0])
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = jnp.dot(x, w_gating[1])
|
||||
activations = gate_value * ff1
|
||||
|
||||
w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
|
||||
(self.hidden_dim, self.features),
|
||||
).astype(dtype)
|
||||
outputs = jnp.dot(activations, w_linear)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
|
||||
configs: tuple[Config, ...]
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
|
||||
|
||||
attn = Attention(configs=self.configs, name="attn")
|
||||
|
||||
pre_attn = []
|
||||
gates = []
|
||||
for i, x in enumerate(xs):
|
||||
if x is not None:
|
||||
x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
||||
pre_attn.append(x)
|
||||
gates.append(gate if x is not None else None)
|
||||
|
||||
pre_attn = sharding.activation_sharding_constraint(pre_attn)
|
||||
post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
|
||||
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
|
||||
post_attn = sharding.activation_sharding_constraint(post_attn)
|
||||
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
out = []
|
||||
gates = []
|
||||
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
|
||||
if x is not None:
|
||||
x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
|
||||
x = lora.FeedForward( # noqa: PLW2901
|
||||
features=config.width,
|
||||
hidden_dim=config.mlp_dim,
|
||||
name=_name("mlp", i),
|
||||
lora_config=config.lora_configs.get("ffn"),
|
||||
)(x)
|
||||
out.append(x)
|
||||
gates.append(gate if x is not None else None)
|
||||
|
||||
out = sharding.activation_sharding_constraint(out)
|
||||
out = jax.tree.map(lambda x: drop(x, deterministic), out)
|
||||
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
|
||||
xs = sharding.activation_sharding_constraint(xs)
|
||||
|
||||
return xs, kv_cache
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""Transformer model, supporting a mixture of different weights for different tokens."""
|
||||
|
||||
configs: Sequence[Config] # list of configs, one for each expert
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
adarms: bool = False
|
||||
|
||||
def setup(self):
|
||||
# all experts must have the same depth
|
||||
assert all(config.depth == self.configs[0].depth for config in self.configs)
|
||||
|
||||
self.embedder = Embedder(
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
embed_dim=self.configs[0].width, # embedder for first expert only
|
||||
name="embedder",
|
||||
)
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=False,
|
||||
static_argnums=(5,), # 0=self, 6=deterministic
|
||||
policy=jax.checkpoint_policies.nothing_saveable,
|
||||
)
|
||||
self.layers = nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(
|
||||
0,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
nn.broadcast,
|
||||
), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
|
||||
length=self.configs[0].depth,
|
||||
)(
|
||||
configs=self.configs,
|
||||
dropout=self.dropout,
|
||||
dropout_bdims=self.dropout_bdims,
|
||||
)
|
||||
self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
|
||||
|
||||
@at.typecheck
|
||||
def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
|
||||
return self.embedder.encode(tokens).astype(self.embed_dtype)
|
||||
|
||||
@at.typecheck
|
||||
def __call__(
|
||||
self,
|
||||
# list of token arrays, one for each expert, or None if that expert should not be run
|
||||
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
|
||||
positions: at.Int[at.Array, "b t"],
|
||||
mask: at.Bool[at.Array, "b t s"],
|
||||
adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
|
||||
*,
|
||||
kv_cache: KVCache | None = None,
|
||||
deterministic: bool = True,
|
||||
) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
|
||||
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
|
||||
mask = jnp.asarray(mask)[:, None, :, :]
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None] * len(self.configs)
|
||||
|
||||
embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
|
||||
|
||||
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
|
||||
|
||||
return [
|
||||
f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
|
||||
], kv_cache
|
||||
|
||||
def init(self, use_adarms: Sequence[bool]):
|
||||
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
||||
self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
|
||||
self(
|
||||
[jnp.zeros((1, 1, c.width)) for c in self.configs],
|
||||
jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
|
||||
jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
|
||||
adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
|
||||
)
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
|
||||
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
|
||||
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
|
||||
# here.
|
||||
return res.astype(x.dtype)
|
||||
|
||||
|
||||
def _name(name, i):
|
||||
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
|
||||
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
|
||||
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
|
||||
# and the action expert.
|
||||
if i == 0:
|
||||
return name
|
||||
return f"{name}_{i}"
|
||||
|
||||
|
||||
def _gated_residual(x, y, gate):
|
||||
assert (x is None) == (y is None)
|
||||
if x is None:
|
||||
return None
|
||||
if gate is None:
|
||||
return x + y
|
||||
return x + y * gate
|
||||
437
policy/openpi-InternData-A1/src/openpi/models/gemma_fast.py
Normal file
437
policy/openpi-InternData-A1/src/openpi/models/gemma_fast.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# Copyright 2024 Big Vision Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
|
||||
Used for FAST autoregressive policies.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
import einops
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_collections
|
||||
|
||||
import openpi.models.lora as lora
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
Variant = Literal["gemma_2b", "gemma_2b_lora"]
|
||||
|
||||
|
||||
def get_config(variant):
|
||||
"""Returns config for specified gemma variant."""
|
||||
if variant == "gemma_2b":
|
||||
return ml_collections.ConfigDict(
|
||||
{
|
||||
"variant": variant,
|
||||
"width": 2048,
|
||||
"depth": 18,
|
||||
"mlp_dim": 16_384,
|
||||
"num_heads": 8,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 256,
|
||||
"norm_eps": 1e-6,
|
||||
"vocab_size": 257_152,
|
||||
"scan": True,
|
||||
"remat_policy": "nothing_saveable",
|
||||
}
|
||||
)
|
||||
if variant == "gemma_2b_lora":
|
||||
return ml_collections.ConfigDict(
|
||||
{
|
||||
"variant": variant,
|
||||
"width": 2048,
|
||||
"depth": 18,
|
||||
"mlp_dim": 16_384,
|
||||
"num_heads": 8,
|
||||
"num_kv_heads": 1,
|
||||
"head_dim": 256,
|
||||
"norm_eps": 1e-6,
|
||||
"vocab_size": 257_152,
|
||||
"scan": True,
|
||||
"remat_policy": "nothing_saveable",
|
||||
"lora_configs": {
|
||||
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
|
||||
},
|
||||
}
|
||||
)
|
||||
raise ValueError(f"Unknown variant: {variant}")
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Einsum(nn.Module):
|
||||
shape: tuple[int, ...]
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, eqn, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
|
||||
return jnp.einsum(eqn, x, w)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class RMSNorm(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1]))
|
||||
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
|
||||
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
|
||||
normed_inputs = normed_inputs * (
|
||||
1 + scale
|
||||
) # scale by learned parameter in float32 (matches Flax implementation)
|
||||
return normed_inputs.astype(dtype) # return in original dtype
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Embedder(nn.Module):
|
||||
"""Embedder module."""
|
||||
|
||||
vocab_size: int
|
||||
embed_dim: int
|
||||
|
||||
def setup(self):
|
||||
self.input_embedding_table = self.param(
|
||||
"input_embedding",
|
||||
nn.initializers.zeros_init(),
|
||||
(self.vocab_size, self.embed_dim),
|
||||
)
|
||||
|
||||
def encode(self, x):
|
||||
x = self.input_embedding_table[(x,)]
|
||||
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return jnp.dot(x, self.input_embedding_table.T)
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Attention(nn.Module):
|
||||
"""Attention module."""
|
||||
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
features: int
|
||||
head_dim: int
|
||||
|
||||
cache_dtype: str | None = None
|
||||
|
||||
lora_config: lora.LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
if self.num_kv_heads == self.num_heads:
|
||||
self.qkv_einsum = lora.Einsum(
|
||||
shape=(3, self.num_heads, self.features, self.head_dim),
|
||||
name="qkv_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
else:
|
||||
self.q_einsum = lora.Einsum(
|
||||
shape=(self.num_heads, self.features, self.head_dim),
|
||||
name="q_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
self.kv_einsum = lora.Einsum(
|
||||
shape=(2, self.num_kv_heads, self.features, self.head_dim),
|
||||
name="kv_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
self.attn_vec_einsum = lora.Einsum(
|
||||
shape=(self.num_heads, self.head_dim, self.features),
|
||||
name="attn_vec_einsum",
|
||||
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
lora_config=self.lora_config,
|
||||
)
|
||||
|
||||
def _init_cache(self, k, v, cache_size):
|
||||
"""Initialize KV cache"""
|
||||
prefill_len = k.shape[1]
|
||||
pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
|
||||
cache_dtype = self.cache_dtype or k.dtype
|
||||
k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
|
||||
v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
|
||||
idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
|
||||
return idx, k_cache, v_cache
|
||||
|
||||
def _update_cache(self, k, v, idx, k_cache, v_cache):
|
||||
"""Update KV cache with new values"""
|
||||
assert k.shape[1] == 1, "Only support kv-cache updates of length 1"
|
||||
indices = (0, idx[0], 0, 0)
|
||||
cache_dtype = self.cache_dtype or k.dtype
|
||||
k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
|
||||
v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
|
||||
idx_new = idx + 1
|
||||
return idx_new, k_new, v_new
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
if self.num_kv_heads == self.num_heads:
|
||||
q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
|
||||
else:
|
||||
q = self.q_einsum("BTD,NDH->BTNH", x)
|
||||
k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)
|
||||
|
||||
q = _apply_rope(q, positions=positions) # promotes to float32
|
||||
q *= self.head_dim**-0.5
|
||||
|
||||
k = _apply_rope(k, positions=positions) # promotes to float32
|
||||
|
||||
if kv_cache is None:
|
||||
idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
|
||||
else:
|
||||
idx, k_cache, v_cache = kv_cache
|
||||
idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
|
||||
|
||||
k, v = k_cache, v_cache
|
||||
kv_cache = (idx, k_cache, v_cache)
|
||||
|
||||
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
|
||||
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
|
||||
|
||||
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
|
||||
raise ValueError(
|
||||
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
|
||||
)
|
||||
|
||||
# big_neg = jnp.finfo(logits.dtype).min
|
||||
big_neg = -2.3819763e38 # See gemma/modules.py
|
||||
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
|
||||
|
||||
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
|
||||
|
||||
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
|
||||
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
|
||||
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Block(nn.Module):
|
||||
"""Transformer block."""
|
||||
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
embed_dim: int
|
||||
head_dim: int
|
||||
hidden_dim: int
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = ()
|
||||
cache_dtype: str | None = None
|
||||
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||
|
||||
def setup(self):
|
||||
self.pre_attention_norm = RMSNorm()
|
||||
self.attn = Attention(
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
features=self.embed_dim,
|
||||
head_dim=self.head_dim,
|
||||
cache_dtype=self.cache_dtype,
|
||||
lora_config=self.lora_configs.get("attn"),
|
||||
)
|
||||
self.pre_ffw_norm = RMSNorm()
|
||||
self.mlp = lora.FeedForward(
|
||||
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
|
||||
)
|
||||
if self.dropout:
|
||||
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
|
||||
else:
|
||||
self.drop = lambda x, _: x
|
||||
|
||||
def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
|
||||
x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
|
||||
inputs_normalized = self.pre_attention_norm(x)
|
||||
attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
|
||||
attn_output = self.drop(attn_output, deterministic)
|
||||
attn_output += x
|
||||
residual = attn_output
|
||||
attn_output = self.pre_ffw_norm(attn_output)
|
||||
outputs = self.mlp(attn_output)
|
||||
outputs = self.drop(outputs, deterministic)
|
||||
outputs = residual + outputs
|
||||
return outputs, kv_cache
|
||||
|
||||
|
||||
KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
|
||||
|
||||
|
||||
@at.typecheck
|
||||
class Module(nn.Module):
|
||||
"""gemma model."""
|
||||
|
||||
variant: str
|
||||
|
||||
width: int
|
||||
depth: int
|
||||
mlp_dim: int
|
||||
num_heads: int
|
||||
num_kv_heads: int
|
||||
head_dim: int
|
||||
norm_eps: float
|
||||
vocab_size: int
|
||||
embed_dtype: str
|
||||
|
||||
dropout: float = 0.0
|
||||
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
|
||||
cache_dtype: str | None = None
|
||||
|
||||
scan: bool = False
|
||||
remat_policy: str = "none"
|
||||
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
|
||||
|
||||
@nn.compact
|
||||
def __call__(
|
||||
self,
|
||||
tokens=None,
|
||||
embedded_prefix=None,
|
||||
embed_only=False, # noqa: FBT002
|
||||
pre_logits=None,
|
||||
positions=None,
|
||||
mask=None,
|
||||
decode=False, # noqa: FBT002
|
||||
kv_cache=None,
|
||||
deterministic=True, # noqa: FBT002
|
||||
return_prelogits=False, # noqa: FBT002
|
||||
):
|
||||
"""Embed only, or complete forward pass.
|
||||
|
||||
Args:
|
||||
tokens: Embedded, then and appended to `embedded_prefix`. Can be None.
|
||||
embedded_prefix: Optional prefix that is already embedded.
|
||||
embed_only: Whether to compute embeddings only.
|
||||
pre_logits: If present computes logits from pre_logits and returns.
|
||||
positions: Optional `[B, T]` allows to specify the absolute position of
|
||||
the tokens.
|
||||
mask: Optional attention mask `[B, T, S]`.
|
||||
decode: Whether to use kv-cache. Caller must pass masks and positions.
|
||||
deterministic: Forwarded to all dropout layers.
|
||||
return_prelogits: Whether to return the pre-logits.
|
||||
|
||||
Returns:
|
||||
If `embed_only=False`, then `(logits, out)` will be returned.
|
||||
If `embed_only=True`, then the embeddings will be returned.
|
||||
If `return_prelogits=True`, then the pre-logits will be returned.
|
||||
"""
|
||||
out = {}
|
||||
|
||||
embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
|
||||
|
||||
if pre_logits is not None:
|
||||
x = out["pre_logits"] = pre_logits
|
||||
logits = out["logits"] = embedder.decode(x)
|
||||
return logits, out
|
||||
|
||||
x = []
|
||||
if embedded_prefix is not None:
|
||||
x.append(embedded_prefix)
|
||||
if tokens is not None:
|
||||
x.append(embedder.encode(tokens))
|
||||
|
||||
x = jnp.concatenate(x, axis=-2)
|
||||
x = x.astype(self.embed_dtype)
|
||||
batch_size, seq_len, width = x.shape
|
||||
|
||||
if embed_only:
|
||||
return x
|
||||
|
||||
if decode:
|
||||
assert positions is not None and mask is not None, ( # noqa: PT018
|
||||
"Must explicitly pass positions and mask for decoding."
|
||||
)
|
||||
|
||||
if positions is None:
|
||||
positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
|
||||
assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
|
||||
|
||||
if mask is None:
|
||||
mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
|
||||
if mask.ndim == 3:
|
||||
mask = mask[:, None, :, :]
|
||||
cache_size = max(seq_len, mask.shape[-1])
|
||||
assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
|
||||
|
||||
if self.remat_policy == "none":
|
||||
block_cls = Block
|
||||
else:
|
||||
block_cls = nn.remat(
|
||||
Block,
|
||||
prevent_cse=not self.scan,
|
||||
static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
|
||||
policy=getattr(jax.checkpoint_policies, self.remat_policy),
|
||||
)
|
||||
|
||||
block_kw = {
|
||||
"num_heads": self.num_heads,
|
||||
"head_dim": self.head_dim,
|
||||
"num_kv_heads": self.num_kv_heads,
|
||||
"embed_dim": width,
|
||||
"hidden_dim": self.mlp_dim,
|
||||
"dropout": self.dropout,
|
||||
"dropout_bdims": self.dropout_bdims,
|
||||
"cache_dtype": self.cache_dtype,
|
||||
"lora_configs": self.lora_configs,
|
||||
}
|
||||
layers = self.scope.push("layers")
|
||||
blocks = [
|
||||
nn.scan(
|
||||
block_cls,
|
||||
variable_axes={"params": 0},
|
||||
split_rngs={"params": True, "dropout": True},
|
||||
in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask
|
||||
length=self.depth,
|
||||
)(parent=layers, **block_kw)
|
||||
]
|
||||
for block in blocks:
|
||||
x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
|
||||
|
||||
assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
|
||||
out["encoded"] = x
|
||||
|
||||
x = RMSNorm(name="final_norm")(x)
|
||||
out["pre_logits"] = x
|
||||
if return_prelogits:
|
||||
return x, kv_cache, out
|
||||
|
||||
x = embedder.decode(x)
|
||||
out["logits"] = x
|
||||
|
||||
return x, kv_cache, out
|
||||
|
||||
def init(self):
|
||||
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
|
||||
self(jnp.zeros((1, 1), dtype=jnp.int32))
|
||||
|
||||
|
||||
def _apply_rope(x, *, positions, max_wavelength=10_000):
|
||||
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
|
||||
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
|
||||
timescale = max_wavelength**freq_exponents
|
||||
radians = positions[..., None] / timescale[None, None, :]
|
||||
radians = radians[..., None, :]
|
||||
assert radians.dtype == jnp.float32
|
||||
# radians.shape = [...,L,1,d=D/2]
|
||||
sin, cos = jnp.sin(radians), jnp.cos(radians)
|
||||
x1, x2 = jnp.split(x, 2, axis=-1)
|
||||
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
|
||||
assert res.dtype == jnp.float32
|
||||
return res
|
||||
148
policy/openpi-InternData-A1/src/openpi/models/lora.py
Normal file
148
policy/openpi-InternData-A1/src/openpi/models/lora.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import math
|
||||
import re
|
||||
|
||||
import flax.linen as nn
|
||||
import flax.struct as struct
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
|
||||
@struct.dataclass
|
||||
class LoRAConfig:
|
||||
"""Configuration for LoRA."""
|
||||
|
||||
# LoRA rank.
|
||||
rank: int
|
||||
# LoRA scaling factor.
|
||||
alpha: float = 1.0
|
||||
# Initialization function for LoRA parameters.
|
||||
init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
|
||||
# Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
|
||||
rslora: bool = False
|
||||
# Axes in the weight to apply LoRA to. Should typically be the last two axes.
|
||||
axes: tuple[int, int] = (-2, -1)
|
||||
# Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
|
||||
label: str = "L"
|
||||
|
||||
@property
|
||||
def scaling_value(self) -> float:
|
||||
return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
|
||||
|
||||
|
||||
class Einsum(nn.Module):
|
||||
"""Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
|
||||
|
||||
# Shape of the weight.
|
||||
shape: tuple[int, ...]
|
||||
# Initialization function for the weight.
|
||||
init_fn: nn.initializers.Initializer = nn.initializers.zeros
|
||||
# If not None, apply LoRA to the weight.
|
||||
lora_config: LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
self.w = self.param("w", self.init_fn, self.shape)
|
||||
|
||||
if config := self.lora_config:
|
||||
# Setup LoRA parameters.
|
||||
shape_a, shape_b = list(self.shape), list(self.shape)
|
||||
shape_a[config.axes[1]] = config.rank
|
||||
shape_b[config.axes[0]] = config.rank
|
||||
self.w_a = self.param("lora_a", config.init_fn, shape_a)
|
||||
self.w_b = self.param("lora_b", config.init_fn, shape_b)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, eqn: str, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
result = jnp.einsum(eqn, x, self.w.astype(dtype))
|
||||
|
||||
if config := self.lora_config:
|
||||
eqn_a, eqn_b = self._make_lora_eqns(eqn)
|
||||
lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
|
||||
lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
|
||||
result = result + lora * config.scaling_value
|
||||
|
||||
return result
|
||||
|
||||
def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
|
||||
if "L" in eqn:
|
||||
raise ValueError(f"L already in eqn: {eqn}")
|
||||
if not (m := re.match("(.*),(.*)->(.*)", eqn)):
|
||||
raise ValueError(f"Unsupported einsum eqn: {eqn}")
|
||||
lhs, rhs, out = m.groups()
|
||||
|
||||
assert self.lora_config is not None
|
||||
a_label, b_label = (rhs[x] for x in self.lora_config.axes)
|
||||
label = self.lora_config.label
|
||||
|
||||
a_rhs = rhs.replace(b_label, label)
|
||||
a_out = out.replace(b_label, label)
|
||||
eqn_a = f"{lhs},{a_rhs}->{a_out}"
|
||||
|
||||
b_rhs = rhs.replace(a_label, label)
|
||||
eqn_b = f"{a_out},{b_rhs}->{out}"
|
||||
|
||||
return eqn_a, eqn_b
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
"""Feed forward module."""
|
||||
|
||||
features: int
|
||||
hidden_dim: int
|
||||
# If not None, apply LoRA to the weight.
|
||||
lora_config: LoRAConfig | None = None
|
||||
|
||||
def setup(self):
|
||||
self.w_gating = self.param(
|
||||
"gating_einsum",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
|
||||
(2, self.features, self.hidden_dim),
|
||||
)
|
||||
self.w_linear = self.param(
|
||||
"linear",
|
||||
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
|
||||
(self.hidden_dim, self.features),
|
||||
)
|
||||
self.w_gating_lora = None
|
||||
self.w_linear_lora = None
|
||||
if self.lora_config:
|
||||
# Setup LoRA parameters.
|
||||
# TODO: follow up with a simplified init_fn api.
|
||||
self.w_gating_lora = (
|
||||
self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
|
||||
self.param(
|
||||
"gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
|
||||
),
|
||||
)
|
||||
self.w_linear_lora = (
|
||||
self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
|
||||
self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
|
||||
)
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
dtype = x.dtype # original dtype, could be half-precision
|
||||
ff_gate = self._dot(
|
||||
x,
|
||||
self.w_gating[0],
|
||||
None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
|
||||
)
|
||||
gate_value = nn.gelu(ff_gate)
|
||||
|
||||
ff1 = self._dot(
|
||||
x,
|
||||
self.w_gating[1],
|
||||
None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
|
||||
)
|
||||
activations = gate_value * ff1
|
||||
|
||||
outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
|
||||
assert outputs.dtype == dtype
|
||||
return outputs
|
||||
|
||||
def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
|
||||
base = jnp.dot(x, w.astype(x.dtype))
|
||||
if lora_weights is None:
|
||||
return base
|
||||
return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))
|
||||
94
policy/openpi-InternData-A1/src/openpi/models/lora_test.py
Normal file
94
policy/openpi-InternData-A1/src/openpi/models/lora_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
import openpi.models.lora as lora
|
||||
|
||||
|
||||
def test_lora_einsum_params_shape():
|
||||
shape = (3, 8, 32, 4) # (3KDH)
|
||||
einsum = lora.Einsum(shape)
|
||||
lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
|
||||
lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
|
||||
eqn = "BSD,3KDH->3BSKH"
|
||||
|
||||
# Ensure that lora parameters are not initialized when LoRA is not used.
|
||||
params = einsum.init(key, eqn, x)
|
||||
assert "lora_a" not in params["params"]
|
||||
assert "lora_b" not in params["params"]
|
||||
|
||||
# Check that default axes work.
|
||||
params_lora0 = lora0.init(key, eqn, x)
|
||||
assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
|
||||
assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
|
||||
|
||||
# Check that user provided axes work.
|
||||
params_lora1 = lora1.init(key, eqn, x)
|
||||
assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
|
||||
assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
|
||||
|
||||
|
||||
def test_lora_einsum_same_output():
|
||||
shape = (3, 8, 32, 4) # (3KDH)
|
||||
einsum = lora.Einsum(shape)
|
||||
einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
|
||||
eqn = "BSD,3KDH->3BSKH"
|
||||
|
||||
params = einsum.init(key, eqn, x)
|
||||
output = einsum.apply(params, eqn, x)
|
||||
|
||||
params_lora = einsum_lora.init(key, eqn, x)
|
||||
output_lora = einsum_lora.apply(params_lora, eqn, x)
|
||||
|
||||
# Results are the same since the LoRA parameters are initialized to zeros.
|
||||
assert jnp.allclose(output, output_lora)
|
||||
|
||||
|
||||
def test_lora_ffn_params_shape():
|
||||
ffn = lora.FeedForward(features=8, hidden_dim=32)
|
||||
ffn_lora = lora.FeedForward(
|
||||
features=8,
|
||||
hidden_dim=32,
|
||||
lora_config=lora.LoRAConfig(rank=2),
|
||||
)
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (2, 8))
|
||||
|
||||
params = ffn.init(key, x)
|
||||
assert params["params"]["gating_einsum"].shape == (2, 8, 32)
|
||||
assert params["params"]["linear"].shape == (32, 8)
|
||||
|
||||
params_lora = ffn_lora.init(key, x)
|
||||
assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
|
||||
assert params_lora["params"]["linear"].shape == (32, 8)
|
||||
assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
|
||||
assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
|
||||
assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
|
||||
assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
|
||||
|
||||
|
||||
def test_lora_ffn_same_output():
|
||||
ffn = lora.FeedForward(features=8, hidden_dim=32)
|
||||
ffn_lora = lora.FeedForward(
|
||||
features=8,
|
||||
hidden_dim=32,
|
||||
lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
|
||||
)
|
||||
|
||||
key = jax.random.key(0)
|
||||
x = jax.random.normal(key, (2, 8))
|
||||
|
||||
params = ffn.init(key, x)
|
||||
output = ffn.apply(params, x)
|
||||
|
||||
params_lora = ffn_lora.init(key, x)
|
||||
output_lora = ffn_lora.apply(params_lora, x)
|
||||
|
||||
assert jnp.allclose(output, output_lora)
|
||||
332
policy/openpi-InternData-A1/src/openpi/models/model.py
Normal file
332
policy/openpi-InternData-A1/src/openpi/models/model.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import abc
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import enum
|
||||
import logging
|
||||
import pathlib
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import augmax
|
||||
from flax import nnx
|
||||
from flax import struct
|
||||
from flax import traverse_util
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import orbax.checkpoint as ocp
|
||||
import safetensors
|
||||
import torch
|
||||
|
||||
from openpi.models_pytorch import pi0_pytorch
|
||||
from openpi.shared import image_tools
|
||||
import openpi.shared.array_typing as at
|
||||
|
||||
logger = logging.getLogger("openpi")
|
||||
|
||||
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
|
||||
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
|
||||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
"""Supported model types."""
|
||||
|
||||
PI0 = "pi0"
|
||||
PI0_FAST = "pi0_fast"
|
||||
PI05 = "pi05"
|
||||
|
||||
|
||||
# The model always expects these images
|
||||
IMAGE_KEYS = (
|
||||
"base_0_rgb",
|
||||
"left_wrist_0_rgb",
|
||||
"right_wrist_0_rgb",
|
||||
)
|
||||
|
||||
|
||||
# This may need change if we release a small model.
|
||||
IMAGE_RESOLUTION = (224, 224)
|
||||
|
||||
|
||||
# Data format
|
||||
#
|
||||
# Data transforms produce the model input as a nested dictionary which is later converted
|
||||
# into `Obesrvation` and `Actions` objects. See below.
|
||||
#
|
||||
# In the dictory form, this data should look like:
|
||||
# {
|
||||
# # Observation data.
|
||||
# "image": {
|
||||
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
|
||||
# ... # Additional camera views
|
||||
# },
|
||||
# "image_mask": {
|
||||
# "base_0_rgb": bool[*b], # True if image is valid
|
||||
# ... # Masks for additional views
|
||||
# },
|
||||
# "state": float32[*b, s], # Low-dimensional robot state
|
||||
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
|
||||
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
|
||||
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
|
||||
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
|
||||
#
|
||||
# # Actions data.
|
||||
# "actions": float32[*b ah ad]
|
||||
# }
|
||||
# where:
|
||||
# *b = batch dimensions
|
||||
# h,w = image height/width
|
||||
# s = state dimension
|
||||
# l = sequence length
|
||||
#
|
||||
@at.typecheck
|
||||
@struct.dataclass
|
||||
class Observation(Generic[ArrayT]):
|
||||
"""Holds observations, i.e., inputs to the model.
|
||||
|
||||
See `Observation.from_dict` to see the expected dictionary form. This is the format
|
||||
that should be produced by the data transforms.
|
||||
"""
|
||||
|
||||
# Images, in [-1, 1] float32.
|
||||
images: dict[str, at.Float[ArrayT, "*b h w c"]]
|
||||
# Image masks, with same keys as images.
|
||||
image_masks: dict[str, at.Bool[ArrayT, "*b"]]
|
||||
# Low-dimensional robot state.
|
||||
state: at.Float[ArrayT, "*b s"]
|
||||
|
||||
# Tokenized prompt.
|
||||
tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
|
||||
# Tokenized prompt mask.
|
||||
tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
|
||||
|
||||
# pi0-fast model specific fields.
|
||||
|
||||
# Token auto-regressive mask (for FAST autoregressive model).
|
||||
token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
|
||||
# Token loss mask (for FAST autoregressive model).
|
||||
token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
|
||||
"""This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
|
||||
# Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
|
||||
if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
|
||||
raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
|
||||
# If images are uint8, convert them to [-1, 1] float32.
|
||||
for key in data["image"]:
|
||||
if data["image"][key].dtype == np.uint8:
|
||||
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
|
||||
elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
|
||||
data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
|
||||
return cls(
|
||||
images=data["image"],
|
||||
image_masks=data["image_mask"],
|
||||
state=data["state"],
|
||||
tokenized_prompt=data.get("tokenized_prompt"),
|
||||
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
|
||||
token_ar_mask=data.get("token_ar_mask"),
|
||||
token_loss_mask=data.get("token_loss_mask"),
|
||||
)
|
||||
|
||||
def to_dict(self) -> at.PyTree[ArrayT]:
|
||||
"""Convert the Observation to a nested dict."""
|
||||
result = dataclasses.asdict(self)
|
||||
result["image"] = result.pop("images")
|
||||
result["image_mask"] = result.pop("image_masks")
|
||||
return result
|
||||
|
||||
|
||||
# Defines the format of the actions. This field is included as "actions" inside the dictionary
|
||||
# produced by the data transforms.
|
||||
Actions = at.Float[ArrayT, "*b ah ad"]
|
||||
|
||||
|
||||
def preprocess_observation(
|
||||
rng: at.KeyArrayLike | None,
|
||||
observation: Observation,
|
||||
*,
|
||||
train: bool = False,
|
||||
image_keys: Sequence[str] = IMAGE_KEYS,
|
||||
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
|
||||
) -> Observation:
|
||||
"""Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
|
||||
filling in a default image mask (if necessary).
|
||||
"""
|
||||
|
||||
if not set(image_keys).issubset(observation.images):
|
||||
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
|
||||
|
||||
batch_shape = observation.state.shape[:-1]
|
||||
|
||||
out_images = {}
|
||||
for key in image_keys:
|
||||
image = observation.images[key]
|
||||
if image.shape[1:3] != image_resolution:
|
||||
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
|
||||
image = image_tools.resize_with_pad(image, *image_resolution)
|
||||
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for augmax.
|
||||
image = image / 2.0 + 0.5
|
||||
|
||||
transforms = []
|
||||
if "wrist" not in key:
|
||||
height, width = image.shape[1:3]
|
||||
transforms += [
|
||||
augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
|
||||
augmax.Resize(width, height),
|
||||
augmax.Rotate((-5, 5)),
|
||||
]
|
||||
transforms += [
|
||||
augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
|
||||
]
|
||||
sub_rngs = jax.random.split(rng, image.shape[0])
|
||||
image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
|
||||
|
||||
# Back to [-1, 1].
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
out_images[key] = image
|
||||
|
||||
# obtain mask
|
||||
out_masks = {}
|
||||
for key in out_images:
|
||||
if key not in observation.image_masks:
|
||||
# do not mask by default
|
||||
out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
|
||||
else:
|
||||
out_masks[key] = jnp.asarray(observation.image_masks[key])
|
||||
|
||||
return Observation(
|
||||
images=out_images,
|
||||
image_masks=out_masks,
|
||||
state=observation.state,
|
||||
tokenized_prompt=observation.tokenized_prompt,
|
||||
tokenized_prompt_mask=observation.tokenized_prompt_mask,
|
||||
token_ar_mask=observation.token_ar_mask,
|
||||
token_loss_mask=observation.token_loss_mask,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class BaseModelConfig(abc.ABC):
|
||||
"""Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
|
||||
method to create the corresponding model.
|
||||
"""
|
||||
|
||||
# Action space dimension.
|
||||
action_dim: int
|
||||
# Action sequence length.
|
||||
action_horizon: int
|
||||
# Tokenized prompt maximum length.
|
||||
max_token_len: int
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_type(self) -> ModelType:
|
||||
"""The model type."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def create(self, rng: at.KeyArrayLike) -> "BaseModel":
|
||||
"""Create a new model, initializing parameters."""
|
||||
|
||||
def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
|
||||
"""Create a model with the given parameters."""
|
||||
model = nnx.eval_shape(self.create, jax.random.key(0))
|
||||
graphdef, state = nnx.split(model)
|
||||
if remove_extra_params:
|
||||
params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
|
||||
at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
|
||||
state.replace_by_pure_dict(params)
|
||||
return nnx.merge(graphdef, state)
|
||||
|
||||
def load_pytorch(self, train_config, weight_path: str):
|
||||
logger.info(f"train_config: {train_config}")
|
||||
model = pi0_pytorch.PI0Pytorch(config=train_config.model)
|
||||
safetensors.torch.load_model(model, weight_path)
|
||||
return model
|
||||
|
||||
@abc.abstractmethod
|
||||
def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
|
||||
"""Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
|
||||
|
||||
def fake_obs(self, batch_size: int = 1) -> Observation:
|
||||
observation_spec, _ = self.inputs_spec(batch_size=batch_size)
|
||||
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)
|
||||
|
||||
def fake_act(self, batch_size: int = 1) -> Actions:
|
||||
_, action_spec = self.inputs_spec(batch_size=batch_size)
|
||||
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BaseModel(nnx.Module, abc.ABC):
|
||||
"""Base class for all model implementations. Specific models should inherit from this class. They should call
|
||||
super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
|
||||
"""
|
||||
|
||||
action_dim: int
|
||||
action_horizon: int
|
||||
max_token_len: int
|
||||
|
||||
@abc.abstractmethod
|
||||
def compute_loss(
|
||||
self,
|
||||
rng: at.KeyArrayLike,
|
||||
observation: Observation,
|
||||
actions: Actions,
|
||||
*,
|
||||
train: bool = False,
|
||||
) -> at.Float[at.Array, "*b ah"]: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...
|
||||
|
||||
|
||||
def restore_params(
|
||||
params_path: pathlib.Path | str,
|
||||
*,
|
||||
restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
|
||||
dtype: jnp.dtype | None = None,
|
||||
sharding: jax.sharding.Sharding | None = None,
|
||||
) -> at.Params:
|
||||
"""Restores unstructured params PyTree from a checkpoint.
|
||||
|
||||
This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
|
||||
well as pre-trained checkpoints released for openpi.
|
||||
|
||||
Args:
|
||||
params_path: The local path to the checkpoint directory.
|
||||
restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
|
||||
dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
|
||||
sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
|
||||
|
||||
Returns:
|
||||
The restored params.
|
||||
"""
|
||||
params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path
|
||||
|
||||
if restore_type is jax.Array and sharding is None:
|
||||
mesh = jax.sharding.Mesh(jax.devices(), ("x",))
|
||||
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
|
||||
|
||||
with ocp.PyTreeCheckpointer() as ckptr:
|
||||
metadata = ckptr.metadata(params_path)
|
||||
item = {"params": metadata["params"]}
|
||||
|
||||
params = ckptr.restore(
|
||||
params_path,
|
||||
ocp.args.PyTreeRestore(
|
||||
item=item,
|
||||
restore_args=jax.tree.map(
|
||||
lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
|
||||
),
|
||||
),
|
||||
)["params"]
|
||||
|
||||
# If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
|
||||
# added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
|
||||
flat_params = traverse_util.flatten_dict(params)
|
||||
if all(kp[-1] == "value" for kp in flat_params):
|
||||
flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
|
||||
return traverse_util.unflatten_dict(flat_params)
|
||||
94
policy/openpi-InternData-A1/src/openpi/models/model_test.py
Normal file
94
policy/openpi-InternData-A1/src/openpi/models/model_test.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from flax import nnx
|
||||
import jax
|
||||
import pytest
|
||||
|
||||
from openpi.models import model as _model
|
||||
from openpi.models import pi0_config
|
||||
from openpi.models import pi0_fast
|
||||
from openpi.shared import download
|
||||
from openpi.shared import nnx_utils
|
||||
|
||||
|
||||
def test_pi0_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_config.Pi0Config()
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_lora_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
|
||||
|
||||
def test_pi0_fast_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_fast.Pi0FASTConfig()
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size,)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
|
||||
assert actions.shape == (batch_size, 256)
|
||||
|
||||
|
||||
def test_pi0_fast_lora_model():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
|
||||
model = config.create(key)
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
|
||||
assert loss.shape == (batch_size,)
|
||||
|
||||
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
|
||||
assert actions.shape == (batch_size, 256)
|
||||
|
||||
lora_filter = nnx_utils.PathRegex(".*lora.*")
|
||||
model_state = nnx.state(model)
|
||||
|
||||
lora_state_elems = list(model_state.filter(lora_filter))
|
||||
assert len(lora_state_elems) > 0
|
||||
|
||||
|
||||
@pytest.mark.manual
|
||||
def test_model_restore():
|
||||
key = jax.random.key(0)
|
||||
config = pi0_config.Pi0Config()
|
||||
|
||||
batch_size = 2
|
||||
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
|
||||
|
||||
model = config.load(
|
||||
_model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
|
||||
)
|
||||
|
||||
loss = model.compute_loss(key, obs, act)
|
||||
assert loss.shape == (batch_size, config.action_horizon)
|
||||
|
||||
actions = model.sample_actions(key, obs, num_steps=10)
|
||||
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user