multi-node openpi commit

This commit is contained in:
Leon998
2026-03-17 23:05:23 +08:00
parent 28833f0c0f
commit 7411e0e004
156 changed files with 33951 additions and 1 deletion

Submodule policy/openpi-InternData-A1 deleted from 10b4b8fd13


@@ -0,0 +1,3 @@
.venv
checkpoints
data

policy/openpi-InternData-A1/.gitignore (vendored)

@@ -0,0 +1,169 @@
# Data directories.
assets/
checkpoints/
data/
wandb/
third_party/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/


@@ -0,0 +1,16 @@
exclude: third_party/
repos:
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.5.14
hooks:
- id: uv-lock
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.6
hooks:
# Run the linter.
- id: ruff
args: [--fix]
- id: ruff-format


@@ -0,0 +1 @@
3.11


@@ -0,0 +1,33 @@
# Contributing to openpi
We welcome contributions, improvements, and modifications. Everyone is welcome to use openpi in accordance with the [license](LICENSE). Contributors are also welcome to submit bug reports, feature requests, and pull requests. We can't promise to approve every pull request, and we are a small team with limited bandwidth to review all requests, but we'll give it our best effort. Specifics are described below.
## Issues and feature requests
You are welcome to use the GitHub [discussion](https://github.com/Physical-Intelligence/openpi/discussions) feature if you would like to discuss something that is not directly reporting an issue or making a feature request. This is suitable for questions about how to use some aspect of openpi, or other topics.
If you found a bug or other issue, please first check that the issue was not already reported (use the search bar on GitHub under Issues). If the issue has not yet been reported, please include this information when filing a GitHub issue:
- Your OS type and version and the version of Python you are using
- Code that allows us to reproduce your bug, including all dependencies
- Traceback of any exception
- Any other information that would help us, such as a screenshot
In order for us to address any issue, we must be able to reproduce it, so if you encountered the issue after making modifications to openpi, please reproduce the issue without any other modifications and provide a code snippet that allows us to quickly reproduce the problem on `main`.
If you would like to submit a feature request, please check that the feature request does not already exist, and please provide the following information:
- The motivation for the feature
- A description of the problem you are trying to solve or your use case
- Enough information for us to understand the nature of the request
- Some information for how you intend to use it (this might help us in understanding the motivation!)
We can't promise to support every feature request, but it is helpful to us to know the use cases that you are interested in!
## Submitting a pull request
If you implemented support for a new robot or environment, or some other new feature, we welcome pull requests (PRs) to openpi. We encourage you to create a [feature request](https://github.com/Physical-Intelligence/openpi/issues) or make a post on the [discussion](https://github.com/Physical-Intelligence/openpi/discussions) board before starting to work on your PR, if you would like to get a sense for whether we are likely to approve your PR if it is submitted. Since we are a small team with limited ability to provide maintenance and support, we may not accept all PRs (e.g., if we believe it would make the code harder to maintain, or if reviewing the PR is out of scope for us), so contacting us in advance is a good way to gauge whether your PR is likely to be approved for merging into openpi directly. But even if it isn't, you are of course more than welcome to maintain your own fork with whatever modifications you would like. When creating PRs, we recommend that every contribution consider the following:
- Make sure that your PR has a clear title and description
- Run `pre-commit` (install using `pre-commit install` first), and run `ruff check .` and `ruff format .`
- Make sure your PR passes all tests


@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


@@ -0,0 +1,9 @@
# openpi-InternData-A1
## Training
For detailed instructions on pretraining with InternData-A1, finetuning on real-world tasks, and sim2real transfer experiments, please refer to [`docs/training.md`](docs/training.md).
## Pretrained Checkpoints
We pretrained the Pi0 model on InternData-A1 for 680k iterations, initialized from the PaliGemma checkpoint. The resulting pretrained checkpoint is available [here](https://huggingface.co/yuyinyang3y/interndata-a1).


@@ -0,0 +1,25 @@
### Docker Setup
All of the examples in this repo provide instructions for running them directly as well as with Docker. Although not required, the Docker option is recommended: it simplifies software installation, produces a more stable environment, and, for the examples that depend on ROS, lets you avoid installing ROS and cluttering your machine.
- Basic Docker installation instructions are [here](https://docs.docker.com/engine/install/).
- Docker must be installed in [rootless mode](https://docs.docker.com/engine/security/rootless/).
- To use your GPU you must also install the [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
- The version of docker installed with `snap` is incompatible with the NVIDIA container toolkit, preventing it from accessing `libnvidia-ml.so` ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/154)). The snap version can be uninstalled with `sudo snap remove docker`.
- Docker Desktop is also incompatible with the NVIDIA runtime ([issue](https://github.com/NVIDIA/nvidia-container-toolkit/issues/229)). Docker Desktop can be uninstalled with `sudo apt remove docker-desktop`.
If starting from scratch and your host machine is Ubuntu 22.04, you can accomplish all of the above with the convenience scripts `scripts/docker/install_docker_ubuntu22.sh` and `scripts/docker/install_nvidia_container_toolkit.sh`.
Build the Docker image and start the container with the following command:
```bash
docker compose -f scripts/docker/compose.yml up --build
```
To build and run the Docker image for a specific example, use the following command:
```bash
docker compose -f examples/<example_name>/compose.yml up --build
```
where `<example_name>` is the name of the example you want to run.
During the first run of any example, Docker will build the images. Go grab a coffee while this happens. Subsequent runs will be faster since the images are cached.


@@ -0,0 +1,179 @@
# Normalization Statistics
Here we provide instructions for computing **normalization statistics** for **real-world**, **simulation (InternData-A1)**, and **sim2real** tasks. The computed statistics are saved in JSON format and are intended to be reused during training and evaluation in the OpenPI pipeline.
Normalization is computed over:
- `state`
- `actions`
and follows the exact data preprocessing and repacking logic used during training.
---
## 1. Simulation Tasks (InternData-A1)
The script `scripts/compute_norm_stats_sim.py` computes normalization statistics for simulation tasks in the InternData-A1 benchmark.
### Supported Robots
- `split_aloha`
- `lift2`
- `genie1`
- `franka`
### Dataset Structure
Download the InternData-A1 datasets from [here](https://huggingface.co/datasets/InternRobotics/InternData-A1).
The structure of the dataset is as follows:
```
InternData-A1/sim/
└── <task_category>/
    └── <robot_name>/
        └── <task_name>/      # no subtask
            ├── data/
            ├── meta/
            └── videos/
```
Some tasks may have subtasks / collections:
```
InternData-A1/sim/
└── <task_category>/
    └── <robot_name>/
        └── <task_name>/
            └── <collect_name>/
                ├── data/
                ├── meta/
                └── videos/
```
### Usage
```
python scripts/compute_norm_stats_sim.py \
--root_data_dir InternData-A1/sim \
--task_category pick_and_place_tasks \
--save_path stats/sim \
--start_ratio 0.0 \
--end_ratio 1.0
```
Arguments
- `root_data_dir`: Root directory of simulation datasets.
- `task_category`: Task category to process (e.g. pick_and_place_tasks).
- `save_path`: Root directory where normalization statistics will be saved.
- `start_ratio`, `end_ratio`: Fraction of tasks to process (useful for sharding large datasets).
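The aggregation these arguments drive can be sketched roughly as follows. This is a minimal illustration, not the actual `scripts/compute_norm_stats_sim.py`: the in-memory episode format and the exact JSON schema below are assumptions.

```python
# Illustrative sketch only: per-dimension mean/std over `state` and
# `actions`, saved as JSON. The episode format and JSON layout are
# assumptions, not the exact openpi format.
import json
import numpy as np

def compute_norm_stats(episodes):
    """Aggregate per-dimension mean/std for `state` and `actions`."""
    stats = {}
    for key in ("state", "actions"):
        # Stack every timestep of every episode into one (N, dim) array.
        values = np.concatenate([ep[key] for ep in episodes], axis=0)
        stats[key] = {
            "mean": values.mean(axis=0).tolist(),
            "std": values.std(axis=0).tolist(),
        }
    return stats

# Dummy episodes of (T, dim) arrays, standing in for loaded LeRobot data.
episodes = [
    {"state": np.zeros((10, 3)), "actions": np.ones((10, 2))},
    {"state": np.ones((5, 3)), "actions": np.ones((5, 2))},
]
stats = compute_norm_stats(episodes)
with open("norm_stats.json", "w") as f:
    json.dump(stats, f, indent=2)
```

The real script additionally applies the training-time repacking logic before aggregating, so the resulting dimensions match what the model sees.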
### Output Structure
```
<save_path>/
└── <task_category>/
    └── <robot_name>/
        └── <task_name>/
            └── <collect_name>/   # empty if no subtask
                └── norm_stats.json
```
During pretraining, set the `stats_dir` argument in `DataConfig` to the `save_path` here.
## 2. Real-World Tasks
The script `scripts/compute_norm_stats_real.py` computes normalization statistics for real-world tasks.
### Supported Robots
- `lift2`
- `split_aloha`
- `acone`
- `genie1`
### Dataset Structure
Real-world datasets are expected to follow the LeRobot repository structure:
```
InternData-A1/real/
└── <robot_name>/
    └── <task_name>/
        └── <collect_name>/   # empty if no subtask
            ├── data/
            ├── meta/
            └── videos/
```
Example task path:
```
InternData-A1/real/genie1/
└── Pick_a_bag_of_bread_with_the_left_arm__then_handover/set_0
```
### Usage
```
python scripts/compute_norm_stats_real.py \
--task_path InternData-A1/real/genie1/Pick_a_bag_of_bread_with_the_left_arm__then_handover/* \
--robot_name genie1 \
--save_path stats/real
```
Arguments
- `task_path`: Path (or glob pattern) to a real-world task dataset, e.g. `InternData-A1/real/genie1/Pick_a_bag_of_bread_with_the_left_arm__then_handover/*`.
- `robot_name`: Robot platform name (must be supported).
- `save_path`: Root directory where normalization statistics will be saved.
### Output Structure
```
<save_path>/
└── <robot_name>/
    └── <task_name>/
        └── norm_stats.json
```
During finetuning, set the `fixed_stats_dir` argument in `DataConfig` to `<save_path>/<robot_name>/<task_name>` here.
## 3. Sim2Real Experiments
The script `scripts/compute_norm_stats_sim2real.py` computes normalization statistics for sim2real experiments.
### Supported Robots
- `lift2`
### Dataset Structure
Datasets from InternData-A1 are expected to follow the LeRobot repository structure:
```
InternData-A1/sim/
└── <task_category>/
    └── <robot_name>/
        └── <task_name>/
            └── <collect_name>/
                ├── data/
                ├── meta/
                └── videos/
```
Example task path:
```
InternData-A1/sim/long_horizon_tasks/lift2/
└── sort_the_rubbish/
    ├── Sort_rubbish_1l2r
    ├── Sort_rubbish_2l1r
    └── Sort_rubbish_2l2r
```
### Usage
```
python scripts/compute_norm_stats_sim2real.py \
--task_path InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/* \
--robot_name lift2 \
--save_path stats/sim2real
```
Arguments
- `task_path`: Path (or glob pattern) to a task dataset, e.g. `InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/*` covers all the collections in the task.
- `robot_name`: Robot platform name (we only support `lift2` for now, but you can try other robots).
- `save_path`: Root directory where normalization statistics will be saved.
### Output Structure
```
<save_path>/
└── <robot_name>/
    └── <task_name>/
        └── norm_stats.json
```
During finetuning, set the `fixed_stats_dir` argument in `DataConfig` to `<save_path>/<robot_name>/<task_name>` here.
## Implementation Notes
For simulation tasks and sim2real experiments, computation may stop early (e.g. after 10k steps) to limit runtime.
For sim2real transfer, we set the gripper dimension in the state vector to zero because the state of the gripper in the real world during inference is not aligned with the state in the simulation. See `src/openpi/policies/sim2real_split_aloha_policy.py` for more details.
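The gripper handling described above amounts to zeroing those state dimensions before the state reaches the policy. A hedged sketch follows; the indices below are hypothetical, and the actual transform lives in `src/openpi/policies/sim2real_split_aloha_policy.py`.

```python
import numpy as np

# Hypothetical gripper indices in the state vector -- the real indices
# depend on the robot's state layout (see the sim2real policy module).
GRIPPER_DIMS = [6, 13]

def mask_gripper_state(state: np.ndarray) -> np.ndarray:
    """Zero out gripper dimensions so real and sim states agree."""
    state = state.copy()  # do not mutate the caller's array
    state[..., GRIPPER_DIMS] = 0.0
    return state

state = np.arange(14.0)        # dummy 14-dim state
masked = mask_gripper_state(state)
```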


@@ -0,0 +1,71 @@
# Running openpi models remotely
We provide utilities for running openpi models remotely. This is useful for running inference on more powerful GPUs off-robot, and also helps keep the robot and policy environments separate (and e.g. avoid dependency hell with robot software).
## Starting a remote policy server
To start a remote policy server, you can simply run the following command:
```bash
uv run scripts/serve_policy.py --env=[DROID | ALOHA | LIBERO]
```
The `env` argument specifies which $\pi_0$ checkpoint should be loaded. Under the hood, this script will execute a command like the following, which you can use to start a policy server, e.g. for checkpoints you trained yourself (here an example for the DROID environment):
```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
```
This will start a policy server that will serve the policy specified by the `config` and `dir` arguments. The policy will be served on the specified port (default: 8000).
## Querying the remote policy server from your robot code
We provide a client utility with minimal dependencies that you can easily embed into any robot codebase.
First, install the `openpi-client` package in your robot environment:
```bash
cd $OPENPI_ROOT/packages/openpi-client
pip install -e .
```
Then, you can use the client to query the remote policy server from your robot code. Here's an example of how to do this:
```python
from openpi_client import image_tools
from openpi_client import websocket_client_policy

# Outside of episode loop, initialize the policy client.
# Point to the host and port of the policy server (localhost and 8000 are the defaults).
client = websocket_client_policy.WebsocketClientPolicy(host="localhost", port=8000)

for step in range(num_steps):
    # Inside the episode loop, construct the observation.
    # Resize images on the client side to minimize bandwidth / latency. Always return images in uint8 format.
    # We provide utilities for resizing images + uint8 conversion so you match the training routines.
    # The typical resize_size for pre-trained pi0 models is 224.
    # Note that the proprioceptive `state` can be passed unnormalized, normalization will be handled on the server side.
    observation = {
        "observation/image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(img, 224, 224)
        ),
        "observation/wrist_image": image_tools.convert_to_uint8(
            image_tools.resize_with_pad(wrist_img, 224, 224)
        ),
        "observation/state": state,
        "prompt": task_instruction,
    }

    # Call the policy server with the current observation.
    # This returns an action chunk of shape (action_horizon, action_dim).
    # Note that you typically only need to call the policy every N steps and execute steps
    # from the predicted action chunk open-loop in the remaining steps.
    action_chunk = client.infer(observation)["actions"]

    # Execute the actions in the environment.
    ...
```
Here, the `host` and `port` arguments specify the IP address and port of the remote policy server. You can also specify these as command-line arguments to your robot code, or hard-code them in your robot codebase. The `observation` is a dictionary of observations and the prompt, following the specification of the policy inputs for the policy you are serving. We have concrete examples of how to construct this dictionary for different environments in the [simple client example](examples/simple_client/main.py).
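The open-loop pattern mentioned in the comments (query the server every N steps, replay the cached chunk in between) can be sketched as below. This is a generic sketch, not openpi code; `client`, `get_observation`, and `execute` are placeholders for your own robot loop, and the chunk is assumed to hold at least `replan_interval` actions.

```python
# Open-loop chunk execution: only call the policy server every
# `replan_interval` steps; in between, replay actions from the
# most recently predicted chunk.
def run_episode(client, get_observation, execute, num_steps, replan_interval=10):
    action_chunk = None
    for step in range(num_steps):
        if step % replan_interval == 0:
            # Re-query the policy server for a fresh action chunk.
            action_chunk = client.infer(get_observation())["actions"]
        # Index into the cached chunk for the current step.
        execute(action_chunk[step % replan_interval])
```

Larger `replan_interval` values reduce server round-trips at the cost of acting open-loop for longer.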


@@ -0,0 +1,102 @@
# Training Instructions
Here we provide instructions for pretraining on InternData-A1, finetuning on real-world tasks, and finetuning on InternData-A1 tasks for sim2real transfer.
Before training, you need to compute the normalization statistics for the tasks you want to train on. Please refer to [norm_stats.md](norm_stats.md) for more details.
---
## 1. Pretraining on InternData-A1
### Write a training config
We provide a `TrainConfig` example named `pretrain-interndata-a1` in `src/openpi/training/config.py`.
InternData-A1 contains four robot embodiments:
- `split_aloha`
- `lift2`
- `genie1`
- `franka`
Accordingly, we define three `MultiDataConfigFactory` classes:
- `MultiSimSplitAlohaDataConfig` for `split_aloha` and `lift2`
- `MultiSimGenieDataConfig` for `genie1`
- `MultiSimFrankaDataConfig` for `franka`
Please either:
- create a soft link from the InternData-A1 dataset to `data/InternData-A1`, or
- modify the `repo_dir` field in all relevant `MultiDataConfig` entries to point to your local InternData-A1 path.
Set `stats_dir` to your local normalization statistics directory. If you use the default setting, ensure that the normalization statistics for simulation tasks are saved under `stats/sim`.
We initialize the model from PaliGemma-3B using:
```
weight_loader=weight_loaders.PaliGemmaWeightLoader("checkpoints/jax/paligemma/pt_224.npz")
```
Please download the PaliGemma-3B checkpoint by running:
```
python scripts/download_paligemma.py
```
You may adjust other training parameters based on your available GPUs and training budget:
- `num_train_steps`: Total number of training steps
- `num_workers`: Number of data loading workers
- `fsdp_devices`: Number of GPUs per node
- `batch_size`: Batch size per GPU
- `save_interval`: Checkpoint saving interval (in steps)
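As a rough sizing aid when tuning these knobs (plain arithmetic, not openpi code): since `batch_size` is per GPU and `fsdp_devices` is GPUs per node, the global batch is their product times the node count, and the total samples seen is that times `num_train_steps`. The numbers below are purely illustrative, not openpi defaults.

```python
# Back-of-the-envelope run sizing; all numbers are illustrative.
def global_batch_size(per_gpu_batch: int, fsdp_devices: int, num_nodes: int = 1) -> int:
    # batch per GPU x GPUs per node x nodes
    return per_gpu_batch * fsdp_devices * num_nodes

def total_samples_seen(per_gpu_batch: int, fsdp_devices: int,
                       num_train_steps: int, num_nodes: int = 1) -> int:
    return global_batch_size(per_gpu_batch, fsdp_devices, num_nodes) * num_train_steps

# e.g. 2 nodes of 8 GPUs, batch 4 per GPU, 680k steps
print(global_batch_size(4, 8, 2))   # 64
print(total_samples_seen(4, 8, 680_000, 2))
```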
### Run training
For multi-node training, run:
```
bash scripts/training_scripts/multi_node.sh
```
For single-node multi-GPU training, run:
```
config_name=pretrain-interndata-a1
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
```
Checkpoints will be saved to `checkpoints/${config_name}`.
## 2. Finetuning on Real-World Tasks
### Write a training config
We provide a `TrainConfig` example named `finetune-a2d-pen` in `src/openpi/training/config.py`.
Key arguments you may need to modify include:
- `MultiDataConfigFactory` class:
- `MultiLeRobotReala2dDataConfig` for `genie1`
- `MultiLeRobotRealArxLift2DataConfig` for `lift2` and `acone`
- `repo_dir`: Path to the real-world task dataset.
- `robot_name`: the robot name in `repo_dir`, e.g. "genie1".
- `fixed_stats_dir`: Path to the normalization statistics for the real-world task. When this is set, statistics from `stats_dir` will not be used.
- `weight_loader`: Pretrained checkpoint used for initialization.
You may download our pretrained checkpoints from [here]().
### Run training
For single-node multi-GPU training, run:
```
config_name=finetune-a2d-pen
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
```
Checkpoints will be saved under `checkpoints/${config_name}`.
## 3. Finetuning on InternData-A1 Tasks for Sim2Real Transfer
### Write a training config
We provide a `TrainConfig` example named `finetune-sim2real-lift2-sort-rubbish` in `src/openpi/training/config.py`.
Key arguments you may need to modify include:
- `MultiDataConfigFactory` class: Currently, sim-to-real transfer is evaluated only on `lift2` tasks:
- `MultiSim2RealSplitAlohaDataConfig` for `lift2`
- `repo_dir`: Path to the corresponding InternData-A1 task.
- `fixed_stats_dir`: Path to the normalization statistics for the sim-to-real task. When specified, statistics from `stats_dir` will not be used.
- `weight_loader`: Pretrained checkpoint used for initialization.
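Across sections 1-3, the README pairs each setup with a specific `MultiDataConfigFactory` subclass; the lookup below just summarizes that pairing (illustrative code, not from the repo, though the class names themselves come from `src/openpi/training/config.py`):

```python
# Which data-config factory the README documents for each (setting, robot) pair.
DATA_CONFIG_BY_SETUP = {
    ("real", "genie1"): "MultiLeRobotReala2dDataConfig",
    ("real", "lift2"): "MultiLeRobotRealArxLift2DataConfig",
    ("real", "acone"): "MultiLeRobotRealArxLift2DataConfig",
    ("sim2real", "lift2"): "MultiSim2RealSplitAlohaDataConfig",
}


def data_config_for(setting: str, robot: str) -> str:
    try:
        return DATA_CONFIG_BY_SETUP[(setting, robot)]
    except KeyError:
        raise ValueError(f"no data config documented for {setting}/{robot}") from None


print(data_config_for("sim2real", "lift2"))  # -> MultiSim2RealSplitAlohaDataConfig
```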
### Run training
For single-node multi-GPU training, run
```bash
config_name=finetune-sim2real-lift2-sort-rubbish
bash scripts/training_scripts/single_node_multi_gpu.sh ${config_name}
```
@@ -0,0 +1,70 @@
# Dockerfile for the Aloha real environment.
# Build the container:
# docker build . -t aloha_real -f examples/aloha_real/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v /dev:/dev -v .:/app --privileged aloha_real /bin/bash
FROM ros:noetic-robot@sha256:7cf0b9f6546abeba308ea42cb7ad3453f3e520e1af57cdf179fe915c939674bc
SHELL ["/bin/bash", "-c"]
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
apt-get install -y --no-install-recommends \
cmake \
curl \
libffi-dev \
python3-rosdep \
python3-rosinstall \
python3-rosinstall-generator \
whiptail \
git \
wget \
openssh-client \
ros-noetic-cv-bridge \
ros-noetic-usb-cam \
ros-noetic-realsense2-camera \
keyboard-configuration
WORKDIR /root
RUN curl 'https://raw.githubusercontent.com/Interbotix/interbotix_ros_manipulators/main/interbotix_ros_xsarms/install/amd64/xsarm_amd64_install.sh' > xsarm_amd64_install.sh
RUN chmod +x xsarm_amd64_install.sh
RUN export TZ='America/Los_Angeles' && ./xsarm_amd64_install.sh -d noetic -n
COPY ./third_party/aloha /root/interbotix_ws/src/aloha
RUN cd /root/interbotix_ws && source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && catkin_make
# Install python 3.10 because this ROS image comes with 3.8
RUN mkdir /python && \
cd /python && \
wget https://www.python.org/ftp/python/3.10.14/Python-3.10.14.tgz && \
tar -zxvf Python-3.10.14.tgz && \
cd Python-3.10.14 && \
ls -lhR && \
./configure --enable-optimizations && \
make install && \
echo 'alias python3="/usr/local/bin/python3.10"' >> ~/.bashrc && \
echo 'alias python="/usr/local/bin/python3.10"' >> ~/.bashrc && \
cd ~ && rm -rf /python && \
rm -rf /var/lib/apt/lists/*
COPY --from=ghcr.io/astral-sh/uv:0.5.6 /uv /bin/uv
ENV UV_HTTP_TIMEOUT=120
ENV UV_LINK_MODE=copy
COPY ./examples/aloha_real/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
RUN uv pip sync --python 3.10 --system /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src:/root/interbotix_ws/src/aloha/aloha_scripts:/root/interbotix_ws/src/aloha
WORKDIR /app
# Create an entrypoint script to run the setup commands, followed by the command passed in.
RUN cat <<'EOF' > /usr/local/bin/entrypoint.sh
#!/bin/bash
source /opt/ros/noetic/setup.sh && source /root/interbotix_ws/devel/setup.sh && "$@"
EOF
RUN chmod +x /usr/local/bin/entrypoint.sh
ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
CMD ["python3", "/app/examples/aloha_real/main.py"]
@@ -0,0 +1,126 @@
# Run Aloha (Real Robot)
This example demonstrates how to run a policy on a real robot using an [ALOHA setup](https://github.com/tonyzhaozh/aloha). See [here](../../docs/remote_inference.md) for instructions on how to load checkpoints and run inference. We list the relevant checkpoint paths for each provided fine-tuned model below.
## Prerequisites
This repo uses a fork of the ALOHA repo with minor modifications to use RealSense cameras.
1. Follow the [hardware installation instructions](https://github.com/tonyzhaozh/aloha?tab=readme-ov-file#hardware-installation) in the ALOHA repo.
1. Modify the `third_party/aloha/aloha_scripts/realsense_publisher.py` file to use serial numbers for your cameras.
## With Docker
```bash
export SERVER_ARGS="--env ALOHA --default_prompt='take the toast out of the toaster'"
docker compose -f examples/aloha_real/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_real/.venv
source examples/aloha_real/.venv/bin/activate
uv pip sync examples/aloha_real/requirements.txt
uv pip install -e packages/openpi-client
# Run the robot
python -m examples.aloha_real.main
```
Terminal window 2:
```bash
roslaunch aloha ros_nodes.launch
```
Terminal window 3:
```bash
uv run scripts/serve_policy.py --env ALOHA --default_prompt='take the toast out of the toaster'
```
## **ALOHA Checkpoint Guide**
The `pi0_base` model can be used zero-shot for a simple task on the ALOHA platform, and we additionally provide two example fine-tuned checkpoints, "fold the towel" and "open the tupperware and put the food on the plate," which can perform more advanced tasks on the ALOHA.
While we've found the policies to work in unseen conditions across multiple ALOHA stations, we provide some pointers here on how best to set up scenes to maximize the chance of policy success. We cover the prompts to use for the policies, objects we've seen them work well on, and well-represented initial state distributions. Running these policies zero-shot is still a very experimental feature, and there is no guarantee that they will work on your robot. The recommended way to use `pi0_base` is by fine-tuning with data from the target robot.
---
### **Toast Task**
This task involves the robot taking two pieces of toast out of a toaster and placing them on a plate.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_base`
- **Prompt**: "take the toast out of the toaster"
- **Objects needed**: Two pieces of toast, a plate, and a standard toaster.
- **Object Distribution**:
- Works on both real toast and rubber fake toast
- Compatible with standard 2-slice toasters
- Works with plates of varying colors
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 06 02 PM" src="https://github.com/user-attachments/assets/3d043d95-9d1c-4dda-9991-e63cae61e02e" />
- The toaster should be positioned in the top-left quadrant of the workspace.
- Both pieces of toast should start inside the toaster, with at least 1 cm of bread sticking out from the top.
- The plate should be placed roughly in the lower-center of the workspace.
- Works with both natural and synthetic lighting, but avoid making the scene too dark (e.g., don't place the setup inside an enclosed space or under a curtain).
### **Towel Task**
This task involves folding a small towel (e.g., roughly the size of a hand towel) into eighths.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_towel`
- **Prompt**: "fold the towel"
- **Object Distribution**:
- Works on towels of varying solid colors
- Performance is worse on heavily textured or striped towels
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 01 15 PM" src="https://github.com/user-attachments/assets/9410090c-467d-4a9c-ac76-96e5b4d00943" />
- The towel should be flattened and roughly centered on the table.
- Choose a towel that does not blend in with the table surface.
### **Tupperware Task**
This task involves opening a tupperware filled with food and pouring the contents onto a plate.
- **Checkpoint path**: `gs://openpi-assets/checkpoints/pi0_aloha_tupperware`
- **Prompt**: "open the tupperware and put the food on the plate"
- **Objects needed**: Tupperware, food (or food-like items), and a plate.
- **Object Distribution**:
- Works on various types of fake food (e.g., fake chicken nuggets, fries, and fried chicken).
- Compatible with tupperware of different lid colors and shapes, with best performance on square tupperware with a corner flap (see images below).
- The policy has seen plates of varying solid colors.
### **Scene Setup Guidelines**
<img width="500" alt="Screenshot 2025-01-31 at 10 02 27 PM" src="https://github.com/user-attachments/assets/60fc1de0-2d64-4076-b903-f427e5e9d1bf" />
- Best performance observed when both the tupperware and plate are roughly centered in the workspace.
- Positioning:
- Tupperware should be on the left.
- Plate should be on the right or bottom.
- The tupperware flap should point toward the plate.
## Training on your own Aloha dataset
1. Convert the dataset to the LeRobot dataset v2.0 format.
We provide a script [convert_aloha_data_to_lerobot.py](./convert_aloha_data_to_lerobot.py) that performs this conversion. As an example, we have converted the `aloha_pen_uncap_diverse_raw` dataset from the [BiPlay repo](https://huggingface.co/datasets/oier-mees/BiPlay/tree/main/aloha_pen_uncap_diverse_raw) and uploaded it to the HuggingFace Hub as [physical-intelligence/aloha_pen_uncap_diverse](https://huggingface.co/datasets/physical-intelligence/aloha_pen_uncap_diverse).
2. Define a training config that uses the custom dataset.
We provide the [pi0_aloha_pen_uncap config](../../src/openpi/training/config.py) as an example. You should refer to the root [README](../../README.md) for how to run training with the new config.
IMPORTANT: Our base checkpoint includes normalization stats from various common robot configurations. When fine-tuning a base checkpoint with a custom dataset from one of these configurations, we recommend using the corresponding normalization stats provided in the base checkpoint. In the example, this is done by specifying the `trossen` asset_id and a path to the pretrained checkpoint's asset directory within the `AssetsConfig`.
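The recommendation above can be sketched with a minimal stand-in for `AssetsConfig` (the field names `assets_dir` and `asset_id`, and the assets path, are assumptions based on the description; check `src/openpi/training/config.py` for the real definition):

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class AssetsConfig:
    """Minimal stand-in for openpi's AssetsConfig; field names are assumptions."""
    assets_dir: Optional[str] = None  # directory to load normalization stats from
    asset_id: Optional[str] = None    # robot configuration id, e.g. "trossen"


# Reuse the base checkpoint's trossen normalization stats when fine-tuning
# (the assets path below is hypothetical).
assets = AssetsConfig(
    assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
    asset_id="trossen",
)
```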
@@ -0,0 +1,66 @@
# Run with:
# docker compose -f examples/aloha_real/compose.yml up --build
services:
runtime:
image: aloha_real
depends_on:
- aloha_ros_nodes
- ros_master
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
aloha_ros_nodes:
image: aloha_real
depends_on:
- ros_master
build:
context: ../..
dockerfile: examples/aloha_real/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- /dev:/dev
command: roslaunch --wait aloha ros_nodes.launch
ros_master:
image: ros:noetic-robot
network_mode: host
privileged: true
command:
- roscore
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
@@ -0,0 +1,71 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
### Task parameters
### ALOHA fixed constants
DT = 0.001
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
# Left finger position limits (qpos[7]), right_finger = -1 * left_finger
MASTER_GRIPPER_POSITION_OPEN = 0.02417
MASTER_GRIPPER_POSITION_CLOSE = 0.01244
PUPPET_GRIPPER_POSITION_OPEN = 0.05800
PUPPET_GRIPPER_POSITION_CLOSE = 0.01844
# Gripper joint limits (qpos[6])
MASTER_GRIPPER_JOINT_OPEN = 0.3083
MASTER_GRIPPER_JOINT_CLOSE = -0.6842
PUPPET_GRIPPER_JOINT_OPEN = 1.4910
PUPPET_GRIPPER_JOINT_CLOSE = -0.6213
############################ Helper functions ############################
MASTER_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_POSITION_CLOSE) / (
MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_POSITION_CLOSE) / (
PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE) + MASTER_GRIPPER_POSITION_CLOSE
)
PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE) + PUPPET_GRIPPER_POSITION_CLOSE
)
MASTER2PUPPET_POSITION_FN = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(MASTER_GRIPPER_POSITION_NORMALIZE_FN(x))
MASTER_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - MASTER_GRIPPER_JOINT_CLOSE) / (
MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_NORMALIZE_FN = lambda x: (x - PUPPET_GRIPPER_JOINT_CLOSE) / (
PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE) + MASTER_GRIPPER_JOINT_CLOSE
)
PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN = (
lambda x: x * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE) + PUPPET_GRIPPER_JOINT_CLOSE
)
MASTER2PUPPET_JOINT_FN = lambda x: PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(MASTER_GRIPPER_JOINT_NORMALIZE_FN(x))
MASTER_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (MASTER_GRIPPER_POSITION_OPEN - MASTER_GRIPPER_POSITION_CLOSE)
PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN = lambda x: x / (PUPPET_GRIPPER_POSITION_OPEN - PUPPET_GRIPPER_POSITION_CLOSE)
MASTER_POS2JOINT = (
lambda x: MASTER_GRIPPER_POSITION_NORMALIZE_FN(x) * (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
+ MASTER_GRIPPER_JOINT_CLOSE
)
MASTER_JOINT2POS = lambda x: MASTER_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - MASTER_GRIPPER_JOINT_CLOSE) / (MASTER_GRIPPER_JOINT_OPEN - MASTER_GRIPPER_JOINT_CLOSE)
)
PUPPET_POS2JOINT = (
lambda x: PUPPET_GRIPPER_POSITION_NORMALIZE_FN(x) * (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
+ PUPPET_GRIPPER_JOINT_CLOSE
)
PUPPET_JOINT2POS = lambda x: PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(
(x - PUPPET_GRIPPER_JOINT_CLOSE) / (PUPPET_GRIPPER_JOINT_OPEN - PUPPET_GRIPPER_JOINT_CLOSE)
)
MASTER_GRIPPER_JOINT_MID = (MASTER_GRIPPER_JOINT_OPEN + MASTER_GRIPPER_JOINT_CLOSE) / 2
@@ -0,0 +1,272 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
use_videos: bool = True
tolerance_s: float = 0.0001
image_writer_processes: int = 10
image_writer_threads: int = 5
video_backend: str | None = None
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
repo_id: str,
robot_type: str,
mode: Literal["video", "image"] = "video",
*,
has_velocity: bool = False,
has_effort: bool = False,
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
) -> LeRobotDataset:
motors = [
"right_waist",
"right_shoulder",
"right_elbow",
"right_forearm_roll",
"right_wrist_angle",
"right_wrist_rotate",
"right_gripper",
"left_waist",
"left_shoulder",
"left_elbow",
"left_forearm_roll",
"left_wrist_angle",
"left_wrist_rotate",
"left_gripper",
]
cameras = [
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
]
features = {
"observation.state": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
"action": {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
},
}
if has_velocity:
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
if has_effort:
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(motors),),
"names": [
motors,
],
}
for cam in cameras:
features[f"observation.images.{cam}"] = {
"dtype": mode,
"shape": (3, 480, 640),
"names": [
"channels",
"height",
"width",
],
}
if Path(LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
return LeRobotDataset.create(
repo_id=repo_id,
fps=50,
robot_type=robot_type,
features=features,
use_videos=dataset_config.use_videos,
tolerance_s=dataset_config.tolerance_s,
image_writer_processes=dataset_config.image_writer_processes,
image_writer_threads=dataset_config.image_writer_threads,
video_backend=dataset_config.video_backend,
)
def get_cameras(hdf5_files: list[Path]) -> list[str]:
with h5py.File(hdf5_files[0], "r") as ep:
# ignore depth channel, not currently handled
return [key for key in ep["/observations/images"].keys() if "depth" not in key] # noqa: SIM118
def has_velocity(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/qvel" in ep
def has_effort(hdf5_files: list[Path]) -> bool:
with h5py.File(hdf5_files[0], "r") as ep:
return "/observations/effort" in ep
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
imgs_per_cam = {}
for camera in cameras:
uncompressed = ep[f"/observations/images/{camera}"].ndim == 4
if uncompressed:
# load all images in RAM
imgs_array = ep[f"/observations/images/{camera}"][:]
else:
import cv2
# load one compressed image after the other in RAM and uncompress
imgs_array = []
for data in ep[f"/observations/images/{camera}"]:
imgs_array.append(cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB))
imgs_array = np.array(imgs_array)
imgs_per_cam[camera] = imgs_array
return imgs_per_cam
def load_raw_episode_data(
ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
with h5py.File(ep_path, "r") as ep:
state = torch.from_numpy(ep["/observations/qpos"][:])
action = torch.from_numpy(ep["/action"][:])
velocity = None
if "/observations/qvel" in ep:
velocity = torch.from_numpy(ep["/observations/qvel"][:])
effort = None
if "/observations/effort" in ep:
effort = torch.from_numpy(ep["/observations/effort"][:])
imgs_per_cam = load_raw_images_per_camera(
ep,
[
"cam_high",
"cam_low",
"cam_left_wrist",
"cam_right_wrist",
],
)
return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
dataset: LeRobotDataset,
hdf5_files: list[Path],
task: str,
episodes: list[int] | None = None,
) -> LeRobotDataset:
if episodes is None:
episodes = range(len(hdf5_files))
for ep_idx in tqdm.tqdm(episodes):
ep_path = hdf5_files[ep_idx]
imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(ep_path)
num_frames = state.shape[0]
for i in range(num_frames):
frame = {
"observation.state": state[i],
"action": action[i],
}
for camera, img_array in imgs_per_cam.items():
frame[f"observation.images.{camera}"] = img_array[i]
if velocity is not None:
frame["observation.velocity"] = velocity[i]
if effort is not None:
frame["observation.effort"] = effort[i]
dataset.add_frame(frame)
dataset.save_episode(task=task)
return dataset
def port_aloha(
raw_dir: Path,
repo_id: str,
raw_repo_id: str | None = None,
task: str = "DEBUG",
*,
episodes: list[int] | None = None,
push_to_hub: bool = True,
is_mobile: bool = False,
mode: Literal["video", "image"] = "image",
dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
if (LEROBOT_HOME / repo_id).exists():
shutil.rmtree(LEROBOT_HOME / repo_id)
if not raw_dir.exists():
if raw_repo_id is None:
raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
download_raw(raw_dir, repo_id=raw_repo_id)
hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
dataset = create_empty_dataset(
repo_id,
robot_type="mobile_aloha" if is_mobile else "aloha",
mode=mode,
has_effort=has_effort(hdf5_files),
has_velocity=has_velocity(hdf5_files),
dataset_config=dataset_config,
)
dataset = populate_dataset(
dataset,
hdf5_files,
task=task,
episodes=episodes,
)
dataset.consolidate()
if push_to_hub:
dataset.push_to_hub()
if __name__ == "__main__":
tyro.cli(port_aloha)
@@ -0,0 +1,57 @@
from typing import List, Optional # noqa: UP035
import einops
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
from examples.aloha_real import real_env as _real_env
class AlohaRealEnvironment(_environment.Environment):
"""An environment for an Aloha robot on real hardware."""
def __init__(
self,
reset_position: Optional[List[float]] = None, # noqa: UP006,UP007
render_height: int = 224,
render_width: int = 224,
) -> None:
self._env = _real_env.make_real_env(init_node=True, reset_position=reset_position)
self._render_height = render_height
self._render_width = render_width
self._ts = None
@override
def reset(self) -> None:
self._ts = self._env.reset()
@override
def is_episode_complete(self) -> bool:
return False
@override
def get_observation(self) -> dict:
if self._ts is None:
raise RuntimeError("Timestep is not set. Call reset() first.")
obs = self._ts.observation
for k in list(obs["images"].keys()):
if "_depth" in k:
del obs["images"][k]
for cam_name in obs["images"]:
img = image_tools.convert_to_uint8(
image_tools.resize_with_pad(obs["images"][cam_name], self._render_height, self._render_width)
)
obs["images"][cam_name] = einops.rearrange(img, "h w c -> c h w")
return {
"state": obs["qpos"],
"images": obs["images"],
}
@override
def apply_action(self, action: dict) -> None:
self._ts = self._env.step(action["actions"])
@@ -0,0 +1,51 @@
import dataclasses
import logging
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import tyro
from examples.aloha_real import env as _env
@dataclasses.dataclass
class Args:
host: str = "0.0.0.0"
port: int = 8000
action_horizon: int = 25
num_episodes: int = 1
max_episode_steps: int = 1000
def main(args: Args) -> None:
ws_client_policy = _websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
)
logging.info(f"Server metadata: {ws_client_policy.get_server_metadata()}")
metadata = ws_client_policy.get_server_metadata()
runtime = _runtime.Runtime(
environment=_env.AlohaRealEnvironment(reset_position=metadata.get("reset_pose")),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=ws_client_policy,
action_horizon=args.action_horizon,
)
),
subscribers=[],
max_hz=50,
num_episodes=args.num_episodes,
max_episode_steps=args.max_episode_steps,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)
@@ -0,0 +1,176 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
import collections
import time
from typing import Optional, List
import dm_env
from interbotix_xs_modules.arm import InterbotixManipulatorXS
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
from examples.aloha_real import constants
from examples.aloha_real import robot_utils
# This is the reset position that is used by the standard Aloha runtime.
DEFAULT_RESET_POSITION = [0, -0.96, 1.16, 0, -0.3, 0]
class RealEnv:
"""
Environment for real robot bi-manual manipulation
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"cam_high": (480x640x3), # h, w, c, dtype='uint8'
"cam_low": (480x640x3), # h, w, c, dtype='uint8'
"cam_left_wrist": (480x640x3), # h, w, c, dtype='uint8'
"cam_right_wrist": (480x640x3)} # h, w, c, dtype='uint8'
"""
def __init__(self, init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True):
# reset_position = START_ARM_POSE[:6]
self._reset_position = reset_position[:6] if reset_position else DEFAULT_RESET_POSITION
self.puppet_bot_left = InterbotixManipulatorXS(
robot_model="vx300s",
group_name="arm",
gripper_name="gripper",
robot_name="puppet_left",
init_node=init_node,
)
self.puppet_bot_right = InterbotixManipulatorXS(
robot_model="vx300s", group_name="arm", gripper_name="gripper", robot_name="puppet_right", init_node=False
)
if setup_robots:
self.setup_robots()
self.recorder_left = robot_utils.Recorder("left", init_node=False)
self.recorder_right = robot_utils.Recorder("right", init_node=False)
self.image_recorder = robot_utils.ImageRecorder(init_node=False)
self.gripper_command = JointSingleCommand(name="gripper")
def setup_robots(self):
robot_utils.setup_puppet_bot(self.puppet_bot_left)
robot_utils.setup_puppet_bot(self.puppet_bot_right)
def get_qpos(self):
left_qpos_raw = self.recorder_left.qpos
right_qpos_raw = self.recorder_right.qpos
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[7])
] # this is position not joint
right_gripper_qpos = [
constants.PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[7])
] # this is position not joint
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
def get_qvel(self):
left_qvel_raw = self.recorder_left.qvel
right_qvel_raw = self.recorder_right.qvel
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[7])]
right_gripper_qvel = [constants.PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[7])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
def get_effort(self):
left_effort_raw = self.recorder_left.effort
right_effort_raw = self.recorder_right.effort
left_robot_effort = left_effort_raw[:7]
right_robot_effort = right_effort_raw[:7]
return np.concatenate([left_robot_effort, right_robot_effort])
def get_images(self):
return self.image_recorder.get_images()
def set_gripper_pose(self, left_gripper_desired_pos_normalized, right_gripper_desired_pos_normalized):
left_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(left_gripper_desired_pos_normalized)
self.gripper_command.cmd = left_gripper_desired_joint
self.puppet_bot_left.gripper.core.pub_single.publish(self.gripper_command)
right_gripper_desired_joint = constants.PUPPET_GRIPPER_JOINT_UNNORMALIZE_FN(
right_gripper_desired_pos_normalized
)
self.gripper_command.cmd = right_gripper_desired_joint
self.puppet_bot_right.gripper.core.pub_single.publish(self.gripper_command)
def _reset_joints(self):
robot_utils.move_arms(
[self.puppet_bot_left, self.puppet_bot_right], [self._reset_position, self._reset_position], move_time=1
)
def _reset_gripper(self):
"""Set to position mode and do position resets: first close then open. Then change back to PWM mode
NOTE: This diverges from the original Aloha code which first opens then closes the gripper. Pi internal aloha data
was collected with the gripper starting in the open position. Leaving the grippers fully closed was also found to
increase the frequency of motor faults.
"""
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_CLOSE] * 2, move_time=1
)
robot_utils.move_grippers(
[self.puppet_bot_left, self.puppet_bot_right], [constants.PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5
)
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qpos()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def get_reward(self):
return 0
def reset(self, *, fake=False):
if not fake:
# Reboot puppet robot gripper motors
self.puppet_bot_left.dxl.robot_reboot_motors("single", "gripper", True)
self.puppet_bot_right.dxl.robot_reboot_motors("single", "gripper", True)
self._reset_joints()
self._reset_gripper()
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def step(self, action):
state_len = int(len(action) / 2)
left_action = action[:state_len]
right_action = action[state_len:]
self.puppet_bot_left.arm.set_joint_positions(left_action[:6], blocking=False)
self.puppet_bot_right.arm.set_joint_positions(right_action[:6], blocking=False)
self.set_gripper_pose(left_action[-1], right_action[-1])
time.sleep(constants.DT)
return dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=self.get_reward(), discount=None, observation=self.get_observation()
)
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14) # 6 joint + 1 gripper, for two arms
# Arm actions
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7 : 7 + 6] = master_bot_right.dxl.joint_states.position[:6]
# Gripper actions
action[6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_left.dxl.joint_states.position[6])
action[7 + 6] = constants.MASTER_GRIPPER_JOINT_NORMALIZE_FN(master_bot_right.dxl.joint_states.position[6])
return action
def make_real_env(init_node, *, reset_position: Optional[List[float]] = None, setup_robots: bool = True) -> RealEnv:
return RealEnv(init_node, reset_position=reset_position, setup_robots=setup_robots)
@@ -0,0 +1,18 @@
Pillow
dm_control
einops
h5py
matplotlib
modern_robotics
msgpack
numpy>=1.22.4,<2.0.0
opencv-python
packaging
pexpect
pyquaternion
pyrealsense2
pyyaml
requests
rospkg
tyro
websockets
@@ -0,0 +1,156 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_real/requirements.in -o examples/aloha_real/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
catkin-pkg==1.0.0
# via rospkg
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
contourpy==1.1.1
# via matplotlib
cycler==0.12.1
# via matplotlib
distro==1.9.0
# via rospkg
dm-control==1.0.23
# via -r examples/aloha_real/requirements.in
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
docutils==0.20.1
# via catkin-pkg
einops==0.8.0
# via -r examples/aloha_real/requirements.in
etils==1.3.0
# via mujoco
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
h5py==3.11.0
# via -r examples/aloha_real/requirements.in
idna==3.10
# via requests
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.7.5
# via -r examples/aloha_real/requirements.in
mdurl==0.1.2
# via markdown-it-py
modern-robotics==1.1.1
# via -r examples/aloha_real/requirements.in
msgpack==1.1.0
# via -r examples/aloha_real/requirements.in
mujoco==3.2.3
# via dm-control
numpy==1.24.4
# via
# -r examples/aloha_real/requirements.in
# contourpy
# dm-control
# dm-env
# h5py
# labmaze
# matplotlib
# modern-robotics
# mujoco
# opencv-python
# pyquaternion
# scipy
opencv-python==4.10.0.84
# via -r examples/aloha_real/requirements.in
packaging==24.2
# via
# -r examples/aloha_real/requirements.in
# matplotlib
pexpect==4.9.0
# via -r examples/aloha_real/requirements.in
pillow==10.4.0
# via
# -r examples/aloha_real/requirements.in
# matplotlib
protobuf==5.29.1
# via dm-control
ptyprocess==0.7.0
# via pexpect
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.1.4
# via
# catkin-pkg
# dm-control
# matplotlib
pyquaternion==0.9.9
# via -r examples/aloha_real/requirements.in
pyrealsense2==2.55.1.6486
# via -r examples/aloha_real/requirements.in
python-dateutil==2.9.0.post0
# via
# catkin-pkg
# matplotlib
pyyaml==6.0.2
# via
# -r examples/aloha_real/requirements.in
# rospkg
requests==2.32.3
# via
# -r examples/aloha_real/requirements.in
# dm-control
rich==13.9.4
# via tyro
rospkg==1.5.1
# via -r examples/aloha_real/requirements.in
scipy==1.10.1
# via dm-control
setuptools==75.3.0
# via
# catkin-pkg
# dm-control
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_real/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_real/requirements.in
zipp==3.20.2
# via etils

View File

@@ -0,0 +1,275 @@
# Ignore lint errors because this file is mostly copied from ACT (https://github.com/tonyzhaozh/act).
# ruff: noqa
from collections import deque
import datetime
import json
import time
from aloha.msg import RGBGrayscaleImage
from cv_bridge import CvBridge
from interbotix_xs_msgs.msg import JointGroupCommand
from interbotix_xs_msgs.msg import JointSingleCommand
import numpy as np
import rospy
from sensor_msgs.msg import JointState
from examples.aloha_real import constants
class ImageRecorder:
def __init__(self, init_node=True, is_debug=False):
self.is_debug = is_debug
self.bridge = CvBridge()
self.camera_names = ["cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist"]
if init_node:
rospy.init_node("image_recorder", anonymous=True)
for cam_name in self.camera_names:
setattr(self, f"{cam_name}_rgb_image", None)
setattr(self, f"{cam_name}_depth_image", None)
setattr(self, f"{cam_name}_timestamp", 0.0)
if cam_name == "cam_high":
callback_func = self.image_cb_cam_high
elif cam_name == "cam_low":
callback_func = self.image_cb_cam_low
elif cam_name == "cam_left_wrist":
callback_func = self.image_cb_cam_left_wrist
elif cam_name == "cam_right_wrist":
callback_func = self.image_cb_cam_right_wrist
else:
raise NotImplementedError
rospy.Subscriber(f"/{cam_name}", RGBGrayscaleImage, callback_func)
if self.is_debug:
setattr(self, f"{cam_name}_timestamps", deque(maxlen=50))
self.cam_last_timestamps = {cam_name: 0.0 for cam_name in self.camera_names}
time.sleep(0.5)
def image_cb(self, cam_name, data):
setattr(
self,
f"{cam_name}_rgb_image",
self.bridge.imgmsg_to_cv2(data.images[0], desired_encoding="bgr8"),
)
# setattr(
# self,
# f"{cam_name}_depth_image",
# self.bridge.imgmsg_to_cv2(data.images[1], desired_encoding="mono16"),
# )
setattr(
self,
f"{cam_name}_timestamp",
data.header.stamp.secs + data.header.stamp.nsecs * 1e-9,
)
# setattr(self, f'{cam_name}_secs', data.images[0].header.stamp.secs)
# setattr(self, f'{cam_name}_nsecs', data.images[0].header.stamp.nsecs)
# cv2.imwrite('/home/lucyshi/Desktop/sample.jpg', cv_image)
if self.is_debug:
getattr(self, f"{cam_name}_timestamps").append(
data.images[0].header.stamp.secs + data.images[0].header.stamp.nsecs * 1e-9
)
def image_cb_cam_high(self, data):
cam_name = "cam_high"
return self.image_cb(cam_name, data)
def image_cb_cam_low(self, data):
cam_name = "cam_low"
return self.image_cb(cam_name, data)
def image_cb_cam_left_wrist(self, data):
cam_name = "cam_left_wrist"
return self.image_cb(cam_name, data)
def image_cb_cam_right_wrist(self, data):
cam_name = "cam_right_wrist"
return self.image_cb(cam_name, data)
def get_images(self):
image_dict = {}
for cam_name in self.camera_names:
while getattr(self, f"{cam_name}_timestamp") <= self.cam_last_timestamps[cam_name]:
time.sleep(0.00001)
rgb_image = getattr(self, f"{cam_name}_rgb_image")
depth_image = getattr(self, f"{cam_name}_depth_image")
self.cam_last_timestamps[cam_name] = getattr(self, f"{cam_name}_timestamp")
image_dict[cam_name] = rgb_image
image_dict[f"{cam_name}_depth"] = depth_image
return image_dict
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
for cam_name in self.camera_names:
image_freq = 1 / dt_helper(getattr(self, f"{cam_name}_timestamps"))
print(f"{cam_name} {image_freq=:.2f}")
print()
class Recorder:
def __init__(self, side, init_node=True, is_debug=False):
self.secs = None
self.nsecs = None
self.qpos = None
self.effort = None
self.arm_command = None
self.gripper_command = None
self.is_debug = is_debug
if init_node:
rospy.init_node("recorder", anonymous=True)
rospy.Subscriber(f"/puppet_{side}/joint_states", JointState, self.puppet_state_cb)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_group",
JointGroupCommand,
self.puppet_arm_commands_cb,
)
rospy.Subscriber(
f"/puppet_{side}/commands/joint_single",
JointSingleCommand,
self.puppet_gripper_commands_cb,
)
if self.is_debug:
self.joint_timestamps = deque(maxlen=50)
self.arm_command_timestamps = deque(maxlen=50)
self.gripper_command_timestamps = deque(maxlen=50)
time.sleep(0.1)
def puppet_state_cb(self, data):
self.qpos = data.position
self.qvel = data.velocity
self.effort = data.effort
self.data = data
if self.is_debug:
self.joint_timestamps.append(time.time())
def puppet_arm_commands_cb(self, data):
self.arm_command = data.cmd
if self.is_debug:
self.arm_command_timestamps.append(time.time())
def puppet_gripper_commands_cb(self, data):
self.gripper_command = data.cmd
if self.is_debug:
self.gripper_command_timestamps.append(time.time())
def print_diagnostics(self):
def dt_helper(l):
l = np.array(l)
diff = l[1:] - l[:-1]
return np.mean(diff)
joint_freq = 1 / dt_helper(self.joint_timestamps)
arm_command_freq = 1 / dt_helper(self.arm_command_timestamps)
gripper_command_freq = 1 / dt_helper(self.gripper_command_timestamps)
print(f"{joint_freq=:.2f}\n{arm_command_freq=:.2f}\n{gripper_command_freq=:.2f}\n")
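The `dt_helper` pattern used by both `print_diagnostics` methods computes a rate as the reciprocal of the mean inter-arrival time of the recorded timestamps. A minimal self-contained sketch with illustrative 50 Hz data:

```python
import numpy as np

# Frequency from timestamps: reciprocal of the mean inter-arrival time,
# mirroring dt_helper above (illustrative, evenly spaced 50 Hz stamps).
stamps = np.array([0.0, 0.02, 0.04, 0.06])
freq = 1.0 / np.mean(np.diff(stamps))
print(round(freq))  # 50
```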
def get_arm_joint_positions(bot):
return bot.arm.core.joint_states.position[:6]
def get_arm_gripper_positions(bot):
return bot.gripper.core.joint_states.position[6]
def move_arms(bot_list, target_pose_list, move_time=1):
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_joint_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
for t in range(num_steps):
for bot_id, bot in enumerate(bot_list):
bot.arm.set_joint_positions(traj_list[bot_id][t], blocking=False)
time.sleep(constants.DT)
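`move_arms` interpolates linearly in joint space: `num_steps = move_time / DT` waypoints between the current and target pose, one per control period. A minimal sketch of that trajectory construction (the `DT` value here is illustrative, not read from `constants`):

```python
import numpy as np

DT = 0.02                        # illustrative control period
move_time = 1.0
num_steps = int(move_time / DT)  # 50 waypoints for a 1-second move
# Each row is one waypoint; each column one joint (2-joint toy example).
traj = np.linspace([0.0, 0.0], [1.0, -1.0], num_steps)
print(traj.shape)  # (50, 2)
```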
def move_grippers(bot_list, target_pose_list, move_time):
print(f"Moving grippers to {target_pose_list=}")
gripper_command = JointSingleCommand(name="gripper")
num_steps = int(move_time / constants.DT)
curr_pose_list = [get_arm_gripper_positions(bot) for bot in bot_list]
traj_list = [
np.linspace(curr_pose, target_pose, num_steps)
for curr_pose, target_pose in zip(curr_pose_list, target_pose_list)
]
with open(f"/data/gripper_traj_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl", "a") as f:
for t in range(num_steps):
d = {}
for bot_id, bot in enumerate(bot_list):
gripper_command.cmd = traj_list[bot_id][t]
bot.gripper.core.pub_single.publish(gripper_command)
d[bot_id] = {"obs": get_arm_gripper_positions(bot), "act": traj_list[bot_id][t]}
f.write(json.dumps(d) + "\n")
time.sleep(constants.DT)
def setup_puppet_bot(bot):
bot.dxl.robot_reboot_motors("single", "gripper", True)
bot.dxl.robot_set_operating_modes("group", "arm", "position")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_on(bot)
def setup_master_bot(bot):
bot.dxl.robot_set_operating_modes("group", "arm", "pwm")
bot.dxl.robot_set_operating_modes("single", "gripper", "current_based_position")
torque_off(bot)
def set_standard_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 800)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def set_low_pid_gains(bot):
bot.dxl.robot_set_motor_registers("group", "arm", "Position_P_Gain", 100)
bot.dxl.robot_set_motor_registers("group", "arm", "Position_I_Gain", 0)
def torque_off(bot):
bot.dxl.robot_torque_enable("group", "arm", False)
bot.dxl.robot_torque_enable("single", "gripper", False)
def torque_on(bot):
bot.dxl.robot_torque_enable("group", "arm", True)
bot.dxl.robot_torque_enable("single", "gripper", True)
# for DAgger
def sync_puppet_to_master(master_bot_left, master_bot_right, puppet_bot_left, puppet_bot_right):
print("\nSyncing!")
# activate master arms
torque_on(master_bot_left)
torque_on(master_bot_right)
# get puppet arm positions
puppet_left_qpos = get_arm_joint_positions(puppet_bot_left)
puppet_right_qpos = get_arm_joint_positions(puppet_bot_right)
# get puppet gripper positions
puppet_left_gripper = get_arm_gripper_positions(puppet_bot_left)
puppet_right_gripper = get_arm_gripper_positions(puppet_bot_right)
# move master arms to puppet positions
move_arms(
[master_bot_left, master_bot_right],
[puppet_left_qpos, puppet_right_qpos],
move_time=1,
)
# move master grippers to puppet positions
move_grippers(
[master_bot_left, master_bot_right],
[puppet_left_gripper, puppet_right_gripper],
move_time=1,
)

View File

@@ -0,0 +1,36 @@
import matplotlib.pyplot as plt
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoDisplay(_subscriber.Subscriber):
"""Displays video frames."""
def __init__(self) -> None:
self._ax: plt.Axes | None = None
self._plt_img: plt.Image | None = None
@override
def on_episode_start(self) -> None:
plt.ion()
self._ax = plt.subplot()
self._plt_img = None
@override
def on_step(self, observation: dict, action: dict) -> None:
assert self._ax is not None
im = observation["image"][0] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
if self._plt_img is None:
self._plt_img = self._ax.imshow(im)
else:
self._plt_img.set_data(im)
plt.pause(0.001)
@override
def on_episode_end(self) -> None:
plt.ioff()
plt.close()

View File

@@ -0,0 +1,41 @@
# Dockerfile for the Aloha simulation environment.
# Build the container:
# docker build . -t aloha_sim -f examples/aloha_sim/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app aloha_sim /bin/bash
FROM python:3.11-slim@sha256:370c586a6ffc8c619e6d652f81c094b34b14b8f2fb9251f092de23f16e299b78
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev
ENV MUJOCO_GL=egl
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/aloha_sim/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/aloha_sim/main.py"]

View File

@@ -0,0 +1,36 @@
# Run Aloha Sim
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/aloha_sim/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.10 examples/aloha_sim/.venv
source examples/aloha_sim/.venv/bin/activate
uv pip sync examples/aloha_sim/requirements.txt
uv pip install -e packages/openpi-client
# Run the simulation
MUJOCO_GL=egl python examples/aloha_sim/main.py
```
Note: If you are seeing EGL errors, you may need to install the following dependencies:
```bash
sudo apt-get install -y libegl1-mesa-dev libgles2-mesa-dev
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env ALOHA_SIM
```

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/aloha_sim/compose.yml up --build
services:
runtime:
image: aloha_sim
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/aloha_sim/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,56 @@
import gym_aloha # noqa: F401
import gymnasium
import numpy as np
from openpi_client import image_tools
from openpi_client.runtime import environment as _environment
from typing_extensions import override
class AlohaSimEnvironment(_environment.Environment):
"""An environment for an Aloha robot in simulation."""
def __init__(self, task: str, obs_type: str = "pixels_agent_pos", seed: int = 0) -> None:
np.random.seed(seed)
self._rng = np.random.default_rng(seed)
self._gym = gymnasium.make(task, obs_type=obs_type)
self._last_obs = None
self._done = True
self._episode_reward = 0.0
@override
def reset(self) -> None:
gym_obs, _ = self._gym.reset(seed=int(self._rng.integers(2**32 - 1)))
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = False
self._episode_reward = 0.0
@override
def is_episode_complete(self) -> bool:
return self._done
@override
def get_observation(self) -> dict:
if self._last_obs is None:
raise RuntimeError("Observation is not set. Call reset() first.")
return self._last_obs # type: ignore
@override
def apply_action(self, action: dict) -> None:
gym_obs, reward, terminated, truncated, info = self._gym.step(action["actions"])
self._last_obs = self._convert_observation(gym_obs) # type: ignore
self._done = terminated or truncated
self._episode_reward = max(self._episode_reward, reward)
def _convert_observation(self, gym_obs: dict) -> dict:
img = gym_obs["pixels"]["top"]
img = image_tools.convert_to_uint8(image_tools.resize_with_pad(img, 224, 224))
# Convert axis order from [H, W, C] --> [C, H, W]
img = np.transpose(img, (2, 0, 1))
return {
"state": gym_obs["agent_pos"],
"images": {"cam_high": img},
}
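The axis reordering in `_convert_observation` is a plain `np.transpose`; a self-contained sketch of the `[H, W, C]` to `[C, H, W]` conversion:

```python
import numpy as np

img_hwc = np.zeros((224, 224, 3), dtype=np.uint8)  # [H, W, C]
img_chw = np.transpose(img_hwc, (2, 0, 1))         # [C, H, W]
print(img_chw.shape)  # (3, 224, 224)
```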

View File

@@ -0,0 +1,55 @@
import dataclasses
import logging
import pathlib
import env as _env
from openpi_client import action_chunk_broker
from openpi_client import websocket_client_policy as _websocket_client_policy
from openpi_client.runtime import runtime as _runtime
from openpi_client.runtime.agents import policy_agent as _policy_agent
import saver as _saver
import tyro
@dataclasses.dataclass
class Args:
out_dir: pathlib.Path = pathlib.Path("data/aloha_sim/videos")
task: str = "gym_aloha/AlohaTransferCube-v0"
seed: int = 0
action_horizon: int = 10
host: str = "0.0.0.0"
port: int = 8000
display: bool = False
def main(args: Args) -> None:
runtime = _runtime.Runtime(
environment=_env.AlohaSimEnvironment(
task=args.task,
seed=args.seed,
),
agent=_policy_agent.PolicyAgent(
policy=action_chunk_broker.ActionChunkBroker(
policy=_websocket_client_policy.WebsocketClientPolicy(
host=args.host,
port=args.port,
),
action_horizon=args.action_horizon,
)
),
subscribers=[
_saver.VideoSaver(args.out_dir),
],
max_hz=50,
)
runtime.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
tyro.cli(main)

View File

@@ -0,0 +1,8 @@
gym-aloha
imageio
matplotlib
msgpack
numpy>=1.22.4,<2.0.0
typing-extensions
tyro
websockets

View File

@@ -0,0 +1,132 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/aloha_sim/requirements.in -o examples/aloha_sim/requirements.txt --python-version 3.10
absl-py==2.1.0
# via
# dm-control
# dm-env
# labmaze
# mujoco
certifi==2024.8.30
# via requests
charset-normalizer==3.4.0
# via requests
cloudpickle==3.1.0
# via gymnasium
contourpy==1.3.1
# via matplotlib
cycler==0.12.1
# via matplotlib
dm-control==1.0.14
# via gym-aloha
dm-env==1.6
# via dm-control
dm-tree==0.1.8
# via
# dm-control
# dm-env
docstring-parser==0.16
# via tyro
farama-notifications==0.0.4
# via gymnasium
fonttools==4.55.2
# via matplotlib
glfw==2.8.0
# via
# dm-control
# mujoco
gym-aloha==0.1.1
# via -r examples/aloha_sim/requirements.in
gymnasium==1.0.0
# via gym-aloha
idna==3.10
# via requests
imageio==2.36.1
# via
# -r examples/aloha_sim/requirements.in
# gym-aloha
imageio-ffmpeg==0.5.1
# via imageio
kiwisolver==1.4.7
# via matplotlib
labmaze==1.0.6
# via dm-control
lxml==5.3.0
# via dm-control
markdown-it-py==3.0.0
# via rich
matplotlib==3.9.3
# via -r examples/aloha_sim/requirements.in
mdurl==0.1.2
# via markdown-it-py
msgpack==1.1.0
# via -r examples/aloha_sim/requirements.in
mujoco==2.3.7
# via
# dm-control
# gym-aloha
numpy==1.26.4
# via
# -r examples/aloha_sim/requirements.in
# contourpy
# dm-control
# dm-env
# gymnasium
# imageio
# labmaze
# matplotlib
# mujoco
# scipy
packaging==24.2
# via matplotlib
pillow==11.0.0
# via
# imageio
# matplotlib
protobuf==5.29.1
# via dm-control
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pyopengl==3.1.7
# via
# dm-control
# mujoco
pyparsing==3.2.0
# via
# dm-control
# matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
requests==2.32.3
# via dm-control
rich==13.9.4
# via tyro
scipy==1.14.1
# via dm-control
setuptools==75.6.0
# via
# dm-control
# imageio-ffmpeg
# labmaze
shtab==1.7.1
# via tyro
six==1.17.0
# via python-dateutil
tqdm==4.67.1
# via dm-control
typeguard==4.4.1
# via tyro
typing-extensions==4.12.2
# via
# -r examples/aloha_sim/requirements.in
# gymnasium
# rich
# typeguard
# tyro
tyro==0.9.2
# via -r examples/aloha_sim/requirements.in
urllib3==2.2.3
# via requests
websockets==14.1
# via -r examples/aloha_sim/requirements.in

View File

@@ -0,0 +1,40 @@
import logging
import pathlib
import imageio
import numpy as np
from openpi_client.runtime import subscriber as _subscriber
from typing_extensions import override
class VideoSaver(_subscriber.Subscriber):
"""Saves episode data."""
def __init__(self, out_dir: pathlib.Path, subsample: int = 1) -> None:
out_dir.mkdir(parents=True, exist_ok=True)
self._out_dir = out_dir
self._images: list[np.ndarray] = []
self._subsample = subsample
@override
def on_episode_start(self) -> None:
self._images = []
@override
def on_step(self, observation: dict, action: dict) -> None:
im = observation["images"]["cam_high"] # [C, H, W]
im = np.transpose(im, (1, 2, 0)) # [H, W, C]
self._images.append(im)
@override
def on_episode_end(self) -> None:
existing = list(self._out_dir.glob("out_[0-9]*.mp4"))
next_idx = max([int(p.stem.split("_")[1]) for p in existing], default=-1) + 1
out_path = self._out_dir / f"out_{next_idx}.mp4"
logging.info(f"Saving video to {out_path}")
imageio.mimwrite(
out_path,
[np.asarray(x) for x in self._images[:: self._subsample]],
fps=50 // max(1, self._subsample),
)
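The episode numbering in `on_episode_end` derives the next index from the stems of already-saved files. That logic can be exercised in isolation (`next_index` is a hypothetical helper, not part of this module):

```python
def next_index(stems: list[str]) -> int:
    # "out_7" -> 7; the next index is one past the current maximum,
    # and an empty directory starts at 0 (max default of -1, plus one).
    return max([int(s.split("_")[1]) for s in stems], default=-1) + 1

print(next_index([]))                           # 0
print(next_index(["out_0", "out_7", "out_3"]))  # 8
```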

View File

@@ -0,0 +1,212 @@
from collections import deque
from typing import List, Dict, Optional, Any, Sequence, Deque, Union
import datasets
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def check_final(
last_states: Union[Deque[Sequence[float]], Sequence[Sequence[float]], torch.Tensor],
*,
    # Indices and initial state
    arm_dofs: int = 6,  # number of left-arm joints (6 here, as given)
    gripper_index: int = -1,  # index of the gripper in the state vector (last dim by default)
    mean_initial_arm_state: Optional[Sequence[float]] = (0.0107, 0.0527, 0.0463, -0.0415, 0.0187, 0.0108),
    mean_initial_gripper_state: float = 4.8438,  # not used in the check yet; kept for future extension
    # Decision thresholds (angles are given in degrees for easy tuning; converted to radians internally)
    stability_window: int = 5,  # number of most recent frames used to test "no large change"
    per_joint_range_deg: float = 2.0,  # max allowed per-joint range (max - min) within the window
    mean_speed_deg: float = 0.5,  # max allowed mean L2 of per-step joint deltas (degrees per step)
    min_change_from_initial_deg: float = 15.0,  # minimum L2 change of the last frame from the initial state
    gripper_closed_thresh: float = 0.8,  # gripper-closed threshold (smaller values mean more closed)
) -> bool:
    """
    Returns True when the arm is "in place": (1) little change over the recent window,
    (2) the gripper is closed, and (3) the last frame differs enough from the initial state.
    All angular thresholds are given in degrees and converted to radians before comparison.
    """
    # --- Arrange the data into an (N, D) tensor ---
if isinstance(last_states, torch.Tensor):
states = last_states
else:
states = torch.as_tensor(list(last_states), dtype=torch.float32)
if states.ndim != 2:
raise ValueError(f"last_states should be 2D, got shape {tuple(states.shape)}")
N, D = states.shape
if D < arm_dofs:
raise ValueError(f"Expected at least {arm_dofs} dims for arm + gripper, got {D}")
if N < 2:
        return False  # too few samples to judge stability
    # Take the most recent window
w = min(N, stability_window)
window = states[-w:] # (w, D)
arm = window[:, :arm_dofs] # (w, 6)
last_arm = arm[-1] # (6,)
last_gripper = float(window[-1, gripper_index])
    # --- 1) "No large change" over the last w frames ---
    # Two metrics: each joint's range (max - min) must be small, and the mean
    # frame-to-frame "speed" must be small.
deg2rad = torch.pi / 180.0
range_tol = per_joint_range_deg * deg2rad
speed_tol = mean_speed_deg * deg2rad
ranges = arm.max(dim=0).values - arm.min(dim=0).values # (6,)
    max_range = float(ranges.abs().max())  # scalar
diffs = arm[1:] - arm[:-1] # (w-1, 6)
    mean_speed = float(torch.linalg.norm(diffs, dim=1).mean())  # mean L2 per step
stable = (max_range <= range_tol) and (mean_speed <= speed_tol)
    # --- 2) Gripper closed ---
gripper_closed = (last_gripper < gripper_closed_thresh)
    # --- 3) The last frame must differ enough from the "initial" state ---
init = torch.as_tensor(mean_initial_arm_state, dtype=last_arm.dtype, device=last_arm.device)
if init.numel() != arm_dofs:
raise ValueError(f"mean_initial_arm_state length {init.numel()} != arm_dofs {arm_dofs}")
dist_from_init = float(torch.linalg.norm(last_arm - init))
far_from_init = (dist_from_init >= (min_change_from_initial_deg * deg2rad))
    # Combined decision
return bool(stable and gripper_closed and far_from_init)
# return bool(gripper_closed and far_from_init)
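The stability test in `check_final` reduces the window to two scalars: the worst per-joint range and the mean per-step L2 norm of frame deltas. A self-contained numpy sketch of the same arithmetic (`window_metrics` is illustrative, not part of this module):

```python
import numpy as np

def window_metrics(arm: np.ndarray) -> tuple[float, float]:
    # arm: (w, dofs) joint positions over the most recent window.
    ranges = arm.max(axis=0) - arm.min(axis=0)
    max_range = float(np.abs(ranges).max())       # worst per-joint excursion
    diffs = np.diff(arm, axis=0)                  # per-step deltas
    mean_speed = float(np.linalg.norm(diffs, axis=1).mean())
    return max_range, mean_speed

# A still arm scores zero on both metrics; a slowly drifting one does not.
still = np.zeros((5, 6))
drift = np.linspace(0.0, 0.1, 5)[:, None] * np.ones((1, 6))
print(window_metrics(still))  # (0.0, 0.0)
print(window_metrics(drift)[0])
```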
def get_last_frames(ds: LeRobotDataset, include_images: bool = False, keys=None):
"""
Quickly fetch the last frame of each episode in a LeRobotDataset.
- include_images=False: Return only scalar/vector fields from parquet (faster, no video decoding).
- include_images=True : Additionally decode the corresponding image/video frame for the last frame.
- keys: Limit the set of columns to retrieve (default: all non-image/video fields + timestamp, etc.).
Returns: list[dict], where each element contains the last frame info of one episode.
"""
# 1) Compute the global index of the last row for each episode.
# ds.episode_data_index['to'] is the exclusive end index, so last frame = to - 1.
end_idxs = (ds.episode_data_index["to"] - 1).tolist()
# 2) Determine which columns to load.
# By default, exclude video/image columns to avoid triggering slow video decoding.
if keys is None:
non_media_keys = [k for k, ft in ds.features.items() if ft["dtype"] not in ("image", "video")]
keys = list(set(non_media_keys + ["timestamp", "episode_index", "task_index"]))
# 3) Select all last-frame rows at once (does not call __getitem__, so no video decoding is triggered).
last_rows = ds.hf_dataset.select(end_idxs)
# 4) Build a dictionary of tensors for each requested key.
out = []
col = {k: last_rows[k] for k in keys}
# Convert lists of tensors into stacked tensors for easier indexing.
for k, v in col.items():
# datasets.arrow_dataset.Column is the HuggingFace internal type for columns.
if isinstance(v, datasets.arrow_dataset.Column) and len(v) > 0 and hasattr(v[0], "shape"):
col[k] = torch.stack(v[:])
    # Iterate through each episode's last frame and build a dict with its values.
for i, ep_end in enumerate(end_idxs):
item = {}
for k in keys:
val = col[k][i]
# Unpack 0-dimensional tensors into Python scalars.
if torch.is_tensor(val) and val.ndim == 0:
val = val.item()
item[k] = val
# Map task_index back to the human-readable task string.
if "task_index" in item:
item["task"] = ds.meta.tasks[int(item["task_index"])]
out.append(item)
# 5) Optionally decode the actual image/video frame for each last timestamp.
if include_images and len(ds.meta.video_keys) > 0:
for i, ep_end in enumerate(end_idxs):
ep_idx = int(out[i]["episode_index"])
ts = float(out[i]["timestamp"])
# Prepare a query dictionary: one timestamp per camera key.
query_ts = {k: [ts] for k in ds.meta.video_keys}
# Decode video frames at the specified timestamps for this episode.
frames = ds._query_videos(query_ts, ep_idx)
# Attach the decoded frame tensors to the output dictionary.
for k, v in frames.items():
out[i][k] = v
return out
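Step 1 above relies on `episode_data_index["to"]` being an exclusive end index, so the last frame of each episode sits at `to - 1`. A sketch of that half-open-range arithmetic with illustrative boundaries:

```python
# Half-open [from, to) episode ranges (illustrative values, not from a real dataset).
episode_data_index = {"from": [0, 120, 250], "to": [120, 250, 400]}
end_idxs = [t - 1 for t in episode_data_index["to"]]
print(end_idxs)  # [119, 249, 399]
```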
if __name__ == "__main__":
# Initialize your dataset (replace with your repo ID or local path).
ds = LeRobotDataset(repo_id="arx_lift2/pick_parcel_20250915")
# Retrieve metadata only (timestamps, states, actions, tasks) without decoding video.
last_infos = get_last_frames(ds, include_images=False)
# Stack all 'observation.state' vectors into a single tensor for further processing.
states = torch.stack([info['observation.state'] for info in last_infos])
# Extract the left-arm joint states (first 7 values of each state vector).
left_arm_states = states[:, 0:7]
mean_state = torch.mean(left_arm_states, dim=0)
std_state = torch.std(left_arm_states, dim=0)
# Print the collected metadata for verification.
print(last_infos)
# --- Run check_final per episode using the last <=50 states ---
EP_ARM_DOFS = 6 # number of left-arm joints we use in check_final
GRIPPER_COL_FULL = -1 # gripper is the last element in the full state vector
    STABILITY_WINDOW = 120  # window of most recent frames passed to check_final
# Determine which episodes to iterate
episode_indices = ds.episodes if ds.episodes is not None else sorted(ds.meta.episodes.keys())
episode_flags = {}
num_true, num_false = 0, 0
for ep_idx in episode_indices:
# Global index range [from_idx, to_idx) for this episode
from_idx = int(ds.episode_data_index["from"][ep_idx])
to_idx = int(ds.episode_data_index["to"][ep_idx])
if to_idx - from_idx <= 0:
episode_flags[ep_idx] = False
num_false += 1
continue
# Take the last <= STABILITY_WINDOW frames from this episode
idxs = list(range(max(from_idx, to_idx - STABILITY_WINDOW), to_idx))
rows = ds.hf_dataset.select(idxs)
# Collect full "observation.state" (shape ~ [W, S])
s_col = rows["observation.state"]
if isinstance(s_col, datasets.arrow_dataset.Column):
S = torch.stack(s_col[:]) # Column -> list[tensor] -> stack
else:
S = torch.stack(s_col) # already a list[tensor]
        # Build the 7D small state per frame: the first 6 joints + the gripper
        # (assumes the gripper signal sits at column EP_ARM_DOFS of the full state vector)
small_states = torch.cat([S[:, :EP_ARM_DOFS], S[:, EP_ARM_DOFS:EP_ARM_DOFS+1]], dim=1)
# Run your stopping logic
ok = check_final(
small_states,
arm_dofs=EP_ARM_DOFS,
gripper_index=-1,
stability_window=STABILITY_WINDOW,
)
episode_flags[ep_idx] = bool(ok)
num_true += int(ok)
num_false += int(not ok)
# Summary
total_eps = len(episode_indices)
print(f"[check_final] passed: {num_true} / {total_eps} ({(num_true/max(total_eps,1)):.1%})")
# List some failed episodes for quick inspection
failed_eps = [e for e, passed in episode_flags.items() if not passed]
print("Failed episode indices (first 20):", failed_eps[:20])

View File

@@ -0,0 +1,88 @@
import os
import cv2
from pathlib import Path
from tqdm import tqdm
def extract_last_frame_from_videos(root_dir, output_dir, xx_last_frame=1):
    """
    Walk root_dir, find all head-camera video files (paths containing
    "observation/head"), extract the last frame of each, and save it to output_dir.
    """
    # Find all mp4 video files
video_files = []
for root, dirs, files in os.walk(root_dir):
for file in files:
if file.endswith('.mp4') and 'observation/head' in root:
video_files.append(os.path.join(root, file))
    print(f"Found {len(video_files)} video files")
    # Process each video file
for video_path in tqdm(video_files):
try:
# 提取set名称和episode名称
path_parts = Path(video_path).parts
set_name = None
episode_name = None
for part in path_parts:
if part.startswith('set'):
set_name = part
if part.startswith('000'):
episode_name = part.replace('.mp4', '')
            if not set_name or not episode_name:
                print(f"Could not extract set/episode info from path: {video_path}")
                continue
            # Build the output filename
output_filename = f"{set_name}_{episode_name}.jpg"
output_path = os.path.join(output_dir, output_filename)
            # Open the video file
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
                print(f"Could not open video: {video_path}")
continue
            # Get the total frame count
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
if total_frames == 0:
                print(f"Video has no frames: {video_path}")
cap.release()
continue
            # Seek to the frame xx_last_frame from the end
cap.set(cv2.CAP_PROP_POS_FRAMES, total_frames - xx_last_frame)
ret, frame = cap.read()
            if ret:
                # Save the extracted frame
                cv2.imwrite(output_path, frame)
                print(f"Saved:\n {output_path}")
            else:
                print(f"Could not read the last frame: {video_path}")
            # Release resources
cap.release()
except Exception as e:
            print(f"Error while processing {video_path}: {e}")
if __name__ == "__main__":
    # Root directory to traverse
    root_directory = "/home/caijunhao/h-ceph/InternData-A1-raw/arx_lift2/Pick_the_industrial_components_from_the_conveyor"  # adjust to your own directory path
output_path = 'data/Pick_the_industrial_components_from_the_conveyor/'
os.makedirs(output_path, exist_ok=True)
sub_list = os.listdir(root_directory)
exclude_list = []
# exclude_list = [f"{i}" for i in range(16)] + [f"{i}" for i in range(26, 29)]
xx_last_frame = 1
for sub in tqdm(sub_list):
if sub.split('-')[1].split('_')[0] in exclude_list:
continue
# print("os.path.join([root_directory, sub])\n", os.path.join(root_directory, sub))
extract_last_frame_from_videos(os.path.join(root_directory, sub), output_path, xx_last_frame=xx_last_frame)
    print("Done!")

View File

@@ -0,0 +1,670 @@
# source /fs-computility/efm/liyang/miniconda3/etc/profile.d/conda.sh
# conda activate act
import argparse
import json
import logging
import os
import gc
import shutil
from concurrent.futures import ALL_COMPLETED, ProcessPoolExecutor, ThreadPoolExecutor, as_completed, wait
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple
import torchvision
import cv2
import h5py
import lmdb
import numpy as np
import pickle
import torch
from PIL import Image
from scipy.spatial.transform import Rotation
from tqdm import tqdm
import pdb
import imageio # imageio-ffmpeg
from lerobot.common.datasets.compute_stats import auto_downsample_height_width, get_feature_stats, sample_indices
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import check_timestamps_sync, get_episode_data_index, validate_episode_buffer
import time
# import ray
# from ray.runtime_env import RuntimeEnv
"""
Convert raw episodes into a LeRobot dataset, storing both camera images and
robot state as a combined observation.

Observation: images (cameras), states (robot state).
Actions: joint, gripper, ee_pose.
"""
FEATURES = {
"images.rgb.head": {
"dtype": "video",
"shape": (368, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_left": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
"images.rgb.hand_right": {
"dtype": "video",
"shape": (480, 640, 3),
"names": ["height", "width", "channel"],
},
# "states.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "states.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "states.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "states.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0",],
# },
"observation.state": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
"action": {
"dtype": "float32",
"shape": (14,),
"names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5", "left_gripper_0",
"right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5","right_gripper_0"],
},
# "actions.left_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["left_joint_0", "left_joint_1", "left_joint_2", "left_joint_3", "left_joint_4", "left_joint_5",],
# },
# "actions.left_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["left_gripper_0",],
# },
# "actions.right_joint.position": {
# "dtype": "float32",
# "shape": (6,),
# "names": ["right_joint_0", "right_joint_1", "right_joint_2", "right_joint_3", "right_joint_4", "right_joint_5",],
# },
# "actions.right_gripper.position": {
# "dtype": "float32",
# "shape": (1,),
# "names": ["right_gripper_0", ],
# },
}
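The FEATURES schema above declares a dtype and shape per key. A minimal, hypothetical helper (not part of this pipeline; `FEATURES_SKETCH` and `check_frame` are illustrative names) showing how a frame could be validated against such a schema before it is buffered:

```python
import numpy as np

# FEATURES_SKETCH and check_frame are illustrative, not part of the pipeline above.
FEATURES_SKETCH = {
    "observation.state": {"dtype": "float32", "shape": (14,)},
    "action": {"dtype": "float32", "shape": (14,)},
}

def check_frame(frame, features):
    """Raise if any array in the frame does not match its declared shape."""
    for key, spec in features.items():
        arr = np.asarray(frame[key])
        if arr.shape != tuple(spec["shape"]):
            raise ValueError(f"{key}: expected shape {spec['shape']}, got {arr.shape}")
    return True

frame = {
    "observation.state": np.zeros(14, dtype=np.float32),
    "action": np.zeros(14, dtype=np.float32),
}
print(check_frame(frame, FEATURES_SKETCH))
```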
import numpy as np
def filter_forbidden_frames(state_dict, position_threshold=0.001, velocity_threshold=0.005):
"""
Filter out forbidden (static) frames using a position-change threshold.
Args:
- state_dict: state array of shape (n, 14)
- position_threshold: threshold on frame-to-frame position change
- velocity_threshold: threshold on velocity change (currently unused)
Returns:
- valid_mask: boolean array; True marks a valid frame
"""
# Use all 14 joint/gripper columns.
qpos_data = state_dict[:, :14]
n_frames = len(state_dict)
valid_mask = np.ones(n_frames, dtype=bool)
# Frame-to-frame differences (a proxy for velocity).
if n_frames > 1:
diff_sum = np.sum(np.abs(np.diff(qpos_data, axis=0)), axis=1)
# A frame is valid if there is motion between it and the next one; the last frame is always kept.
valid_mask[:-1] = diff_sum > position_threshold
return valid_mask
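The intent of the filter above can be sketched standalone. This is a minimal re-implementation of the motion-threshold idea (assumed behavior: a frame is kept when the summed joint-position change to the next frame exceeds the threshold, and the last frame is always kept):

```python
import numpy as np

def motion_mask(qpos, position_threshold=0.001):
    """Standalone sketch of the motion-threshold filter (illustrative only)."""
    valid = np.ones(len(qpos), dtype=bool)
    if len(qpos) > 1:
        # Summed absolute joint change between consecutive frames.
        diff_sum = np.sum(np.abs(np.diff(qpos, axis=0)), axis=1)
        valid[:-1] = diff_sum > position_threshold
    return valid

states = np.zeros((5, 14), dtype=np.float32)
states[2:] += 0.01  # motion happens between frames 1 and 2
mask = motion_mask(states)
print(mask.tolist())  # [False, True, False, False, True]
```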
def statistical_filter(state_dict, std_multiplier=2.0):
"""
Detect anomalous (forbidden) frames with a statistical outlier test.
"""
# Exclude the gripper columns (indices 6 and 13).
qpos_columns = [i for i in range(14) if i not in [6, 13]]
qpos_data = state_dict[:, qpos_columns]
# Per-column mean and standard deviation.
means = np.mean(qpos_data, axis=0)
stds = np.std(qpos_data, axis=0)
# Build the validity mask.
valid_mask = np.ones(len(state_dict), dtype=bool)
for i in range(len(state_dict)):
# Check whether every joint position stays within a reasonable range.
deviations = np.abs(qpos_data[i] - means)
if np.any(deviations > std_multiplier * stds):
valid_mask[i] = False  # anomalous frame
return valid_mask
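A standalone sketch mirroring the z-score test above: frames whose joint positions deviate from the per-column mean by more than `std_multiplier` standard deviations are flagged invalid.

```python
import numpy as np

def zscore_mask(qpos, std_multiplier=2.0):
    """Illustrative re-implementation of the statistical outlier filter."""
    means = qpos.mean(axis=0)
    stds = qpos.std(axis=0)
    valid = np.ones(len(qpos), dtype=bool)
    for i in range(len(qpos)):
        if np.any(np.abs(qpos[i] - means) > std_multiplier * stds):
            valid[i] = False  # anomalous frame
    return valid

data = np.zeros((10, 2), dtype=np.float32)
data[9] = 5.0  # one outlier frame
mask = zscore_mask(data)
print(int(mask.sum()))  # 9 valid frames, the outlier is dropped
```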
class ARXDataset(LeRobotDataset):
def __init__(
self,
repo_id: str,
root: str | Path | None = None,
episodes: list[int] | None = None,
image_transforms: Callable | None = None,
delta_timestamps: dict[str, list[float]] | None = None,
tolerance_s: float = 1e-4,
download_videos: bool = True,
local_files_only: bool = False,
video_backend: str | None = None,
):
super().__init__(
repo_id=repo_id,
root=root,
episodes=episodes,
image_transforms=image_transforms,
delta_timestamps=delta_timestamps,
tolerance_s=tolerance_s,
download_videos=download_videos,
local_files_only=local_files_only,
video_backend=video_backend,
)
def save_episode(self, episode_data: dict | None = None, videos: dict | None = None) -> None:
if not episode_data:
episode_buffer = self.episode_buffer
validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)
episode_length = episode_buffer.pop("size")
tasks = episode_buffer.pop("task")
episode_tasks = list(set(tasks))
episode_index = episode_buffer["episode_index"]
episode_buffer["index"] = np.arange(self.meta.total_frames, self.meta.total_frames + episode_length)
episode_buffer["episode_index"] = np.full((episode_length,), episode_index)
for task in episode_tasks:
task_index = self.meta.get_task_index(task)
if task_index is None:
self.meta.add_task(task)
episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])
for key, ft in self.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["video"]:
continue
episode_buffer[key] = np.stack(episode_buffer[key]).squeeze()
for key in self.meta.video_keys:
video_path = self.root / self.meta.get_video_file_path(episode_index, key)
episode_buffer[key] = str(video_path) # PosixPath -> str
video_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copyfile(videos[key], video_path)
ep_stats = compute_episode_stats(episode_buffer, self.features)
self._save_episode_table(episode_buffer, episode_index)
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
check_timestamps_sync(
episode_buffer["timestamp"],
episode_buffer["episode_index"],
ep_data_index_np,
self.fps,
self.tolerance_s,
)
if not episode_data:
self.episode_buffer = self.create_episode_buffer()
def add_frame(self, frame: dict) -> None:
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
for key in frame:
if key == "task":
self.episode_buffer["task"].append(frame["task"])
continue
if key not in self.features:
raise ValueError(f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'.")
# import pdb
# pdb.set_trace()
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
# def crop_resize_no_padding(image, target_size=(480, 640)):
# """
# Crop and scale to target size (no padding)
# :param image: input image (NumPy array)
# :param target_size: target size (height, width)
# :return: processed image
# """
# h, w = image.shape[:2]
# target_h, target_w = target_size
# target_ratio = target_w / target_h # Target aspect ratio (e.g. 640/480=1.333)
# # the original image aspect ratio and cropping direction
# if w / h > target_ratio: # Original image is wider → crop width
# crop_w = int(h * target_ratio) # Calculate crop width based on target aspect ratio
# crop_h = h
# start_x = (w - crop_w) // 2 # Horizontal center starting point
# start_y = 0
# else: # Original image is higher → crop height
# crop_h = int(w / target_ratio) # Calculate clipping height according to target aspect ratio
# crop_w = w
# start_x = 0
# start_y = (h - crop_h) // 2 # Vertical center starting point
# # Perform centered cropping (to prevent out-of-bounds)
# start_x, start_y = max(0, start_x), max(0, start_y)
# end_x, end_y = min(w, start_x + crop_w), min(h, start_y + crop_h)
# cropped = image[start_y:end_y, start_x:end_x]
# # Resize to target size (bilinear interpolation)
# resized = cv2.resize(cropped, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
# return resized
def load_lmdb_data(episode_path: Path, sava_path: Path, fps_factor: int, target_fps: int) -> Optional[Dict]:
def load_image(txn, key):
raw = txn.get(key)
data = pickle.loads(raw)
image = cv2.imdecode(data, cv2.IMREAD_COLOR)
# Convert to RGB if necessary
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# image = crop_resize_no_padding(image, target_size=(480, 640))
return image
try:
env = lmdb.open(
str(episode_path / "lmdb"),
readonly=True,
lock=False,
max_readers=128,
readahead=False
)
with env.begin(write=False) as txn:
keys = [k for k, _ in txn.cursor()]
image_keys = sorted([k for k in keys if b'head' in k])
if not image_keys:
return None
all_qpos = pickle.loads(txn.get(b'/observations/qpos'))
if np.isscalar(all_qpos):
total_steps = len(image_keys)
all_qpos = [all_qpos] * total_steps
else:
total_steps = len(all_qpos)
all_qpos = np.stack(all_qpos)
state_action_dict = {}
state_action_dict["states.left_joint.position"] = all_qpos[:, :6]
state_action_dict["states.left_gripper.position"] = all_qpos[:, 6][:, None] # np.expand_dims(all_qpos[:, 6], axis=1)
state_action_dict["states.right_joint.position"] = all_qpos[:, 7:13]
state_action_dict["states.right_gripper.position"] = all_qpos[:, 13][:, None] # np.expand_dims(all_qpos[:, 13], axis=1)
# state_keys = list(state_action_dict.keys())
# for k in state_keys:
# state_action_dict[k.replace("states", "actions")] = np.concatenate([state_action_dict[k][1:, :], state_action_dict[k][-1, :][None,:]], axis=0)
# action_dict = {}
# action_dict["actions.left_joint.position"] = np.concatenate([state_dict["states.left_joint.position"][1:, :], state_dict["states.left_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_joint.position"] = state_dict["states.right_joint.position"][1:, :]
# action_dict["actions.right_gripper.position"] = state_dict["states.right_gripper.position"][1:, :]
action_dict = {}
action_dict["action"] = np.concatenate([all_qpos[1:,], all_qpos[-1,].reshape(-1, 14)], axis=0)
state_dict = {}
state_dict["observation.state"] = all_qpos
mask1 = filter_forbidden_frames(state_dict["observation.state"])
# state_dict["observation.state"] = state_dict["observation.state"][mask1]
# action_dict["actions.left_gripper.position"] = state_dict["states.left_gripper.position"][1:, :]
# action_dict["actions.right_arm.position"] = np.concatenate([state_action_dict["states.right_joint.position"][1:, :], state_action_dict["states.right_joint.position"][-1, :][None,:]], axis=0)
# action_dict["actions.left_arm.position"] = state_dict["states.right_gripper.position"][1:, :]
assert total_steps == len(image_keys), "qpos length mismatch"
selected_steps = [step for step in range(total_steps) if step % fps_factor == 0 and mask1[step]]
frames = []
image_observations = {
"images.rgb.head": [],
"images.rgb.hand_left": [],
"images.rgb.hand_right": []
}
start_time = time.time()
for step_index, step in enumerate(selected_steps):
step_str = f"{step:04d}"
head_key = f"observation/head/color_image/{step_str}".encode()
left_key = f"observation/left_wrist/color_image/{step_str}".encode()
right_key = f"observation/right_wrist/color_image/{step_str}".encode()
if not (head_key in keys and left_key in keys and right_key in keys):
continue
# state = all_qpos[step]
# if step_index < len(selected_steps) - 1:
# action = all_qpos[selected_steps[step_index + 1]]
# else:
# action = state
data_dict = {}
# for key, value in state_action_dict.items():
# data_dict[key] = value[step]
data_dict['action'] = action_dict["action"][step]
data_dict["task"] = " ".join(episode_path.parent.parent.name.split("_"))
data_dict['observation.state'] = state_dict["observation.state"][step]
# frames.append({
# "observation.states.joint.position": state,
# "actions.joint.position": action,
# "task": task_name,
# })
frames.append(data_dict)
image_observations["images.rgb.head"].append(load_image(txn, head_key))
image_observations["images.rgb.hand_left"].append(load_image(txn, left_key))
image_observations["images.rgb.hand_right"].append(load_image(txn, right_key))
end_time = time.time()
elapsed_time = end_time - start_time
print(f"loaded image_observations of {episode_path} in {elapsed_time:.2f} s")
env.close()
if not frames:
return None
os.makedirs(sava_path, exist_ok=True)
os.makedirs(sava_path/episode_path.name, exist_ok=True)
imageio.mimsave(sava_path/episode_path.name/'head.mp4', image_observations["images.rgb.head"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_left.mp4', image_observations["images.rgb.hand_left"], fps=target_fps)
imageio.mimsave(sava_path/episode_path.name/'hand_right.mp4', image_observations["images.rgb.hand_right"], fps=target_fps)
print(f"imageio.mimsave time taken of {episode_path}")
return {
"frames": frames,
"videos": {
"images.rgb.head": sava_path/episode_path.name/"head.mp4",
"images.rgb.hand_left": sava_path/episode_path.name/"hand_left.mp4",
"images.rgb.hand_right": sava_path/episode_path.name/"hand_right.mp4",
},
}
except Exception as e:
logging.error(f"Failed to load LMDB data: {e}")
return None
def get_all_tasks(src_path: Path, output_path: Path) -> Tuple[Path, Path]:
src_dirs = sorted(list(src_path.glob("*"))) # "set*-*_collector*_datatime" as the conversion unit
save_dirs = [output_path/_dir.parent.name/_dir.name for _dir in src_dirs]
tasks_tuples = zip(src_dirs, save_dirs)
for task in tasks_tuples:
yield task
def compute_episode_stats(episode_data: Dict[str, List[str] | np.ndarray], features: Dict) -> Dict:
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue
elif features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)
axes_to_reduce = (0, 2, 3) # keep channel dim
keepdims = True
else:
ep_ft_array = data # data is already a np.ndarray
axes_to_reduce = 0 # compute stats over the first axis
keepdims = data.ndim == 1 # keep as np.array
ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
}
return ep_stats
def sample_images(input):
if type(input) is str:
video_path = input
reader = torchvision.io.VideoReader(video_path, stream="video")
frames = [frame["data"] for frame in reader]
frames_array = torch.stack(frames).numpy()  # Shape: [T, C, H, W]
elif type(input) is np.ndarray:
frames_array = input[:, None, :, :]  # Shape: [T, C, H, W]
else:
raise TypeError(f"Unsupported input type for sample_images: {type(input)}")
sampled_indices = sample_indices(len(frames_array))
images = None
for i, idx in enumerate(sampled_indices):
img = frames_array[idx]
img = auto_downsample_height_width(img)
if images is None:
images = np.empty((len(sampled_indices), *img.shape), dtype=np.uint8)
images[i] = img
return images
def load_local_dataset(episode_path: str, save_path:str, origin_fps=30, target_fps=30):
fps_factor = origin_fps // target_fps
# print(f"fps downsample factor: {fps_factor}")
# logging.info(f"fps downsample factor: {fps_factor}")
# for format_str in [f"{episode_id:07d}", f"{episode_id:06d}", str(episode_id)]:
# episode_path = Path(src_path) / format_str
# save_path = Path(save_path) / format_str
# if episode_path.exists():
# break
# else:
# logging.warning(f"Episode directory not found for ID {episode_id}")
# return None, None
episode_path = Path(episode_path)
if not episode_path.exists():
logging.warning(f"{episode_path} does not exist")
return None, None
if not (episode_path / "lmdb/data.mdb").exists():
logging.warning(f"LMDB data not found for episode {episode_path}")
return None, None
raw_dataset = load_lmdb_data(episode_path, save_path, fps_factor, target_fps)
if raw_dataset is None:
return None, None
frames = raw_dataset["frames"] # states, actions, task
videos = raw_dataset["videos"] # image paths
## check the frames
for camera_name, video_path in videos.items():
if not os.path.exists(video_path):
logging.error(f"Video file {video_path} does not exist.")
print(f"Camera {camera_name} Video file {video_path} does not exist.")
return None, None
return frames, videos
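The fps downsampling in `load_local_dataset` reduces to a modulo selection. For example, with `origin_fps=30` and `target_fps=10`, `fps_factor=3` keeps every third step:

```python
# origin_fps must be an integer multiple of target_fps (asserted in __main__ below).
origin_fps, target_fps = 30, 10
fps_factor = origin_fps // target_fps
total_steps = 10
selected_steps = [step for step in range(total_steps) if step % fps_factor == 0]
print(selected_steps)  # [0, 3, 6, 9]
```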
def save_as_lerobot_dataset(task: tuple[Path, Path], repo_id, num_threads, debug, origin_fps=30, target_fps=30, robot_type="piper", delete_downsampled_videos=True):
src_path, save_path = task
print(f"**Processing collected** {src_path}")
print(f"**saving to** {save_path}")
if save_path.exists():
# print(f"Output directory {save_path} already exists. Deleting it.")
# logging.warning(f"Output directory {save_path} already exists. Deleting it.")
# shutil.rmtree(save_path)
print(f"Output directory {save_path} already exists.")
return
dataset = ARXDataset.create(
repo_id=f"{repo_id}",
root=save_path,
fps=target_fps,
robot_type=robot_type,
features=FEATURES,
)
all_episode_paths = sorted([f.as_posix() for f in src_path.glob(f"*") if f.is_dir()])
# all_subdir_eids = [int(Path(path).name) for path in all_subdir]
if debug:
for i in range(1):
# pdb.set_trace()
frames, videos = load_local_dataset(episode_path=all_episode_paths[i], save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
if delete_downsampled_videos:
for _, video_path in videos.items():
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
# os.remove(video_path)
# print(f"Successfully deleted: {parent_dir}")
print(f"Successfully deleted: {video_path}")
except Exception as e:
pass # Handle the case where the directory might not exist or is already deleted
else:
for batch_index in range(len(all_episode_paths)//num_threads+1):
batch_episode_paths = all_episode_paths[batch_index*num_threads:(batch_index+1)*num_threads]
if len(batch_episode_paths) == 0:
continue
with ThreadPoolExecutor(max_workers=num_threads) as executor:
# Map each future back to its episode path so failures can be reported accurately.
future_to_path = {}
for episode_path in batch_episode_paths:
print("starting to process episode: ", episode_path)
future = executor.submit(load_local_dataset, episode_path=episode_path, save_path=save_path, origin_fps=origin_fps, target_fps=target_fps)
future_to_path[future] = episode_path
for future in as_completed(future_to_path):
frames, videos = future.result()
if frames is None or videos is None:
print(f"Skipping episode {future_to_path[future]} due to missing data.")
continue
for frame_data in frames:
dataset.add_frame(frame_data)
dataset.save_episode(videos=videos)
gc.collect()
print(f"finishing processed {videos}")
if delete_downsampled_videos:
for _, video_path in videos.items():
# Get the parent directory of the video
parent_dir = os.path.dirname(video_path)
try:
shutil.rmtree(parent_dir)
print(f"Successfully deleted: {parent_dir}")
except Exception:
pass  # The directory may not exist or may already have been deleted.
def main(src_path, save_path, repo_id, num_threads=60, debug=False, origin_fps=30, target_fps=30):
logging.info("Scanning for episodes...")
tasks = get_all_tasks(src_path, save_path)
# import pdb
# pdb.set_trace()
if debug:
task = next(tasks)
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
else:
for task in tasks:
save_as_lerobot_dataset(task, repo_id, num_threads=num_threads, debug=debug, origin_fps=origin_fps, target_fps=target_fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert collected data from Piper to Lerobot format.")
parser.add_argument(
"--src_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/",
help="Path to the input file containing collected data in Piper format.",
#help="/fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Make_a_beef_sandwich",
)
parser.add_argument(
"--save_path",
type=str,
# required=False,
default="/fs-computility/efm/shared/datasets/myData-A1/real/lerobot_v2_1/agilex_split_aloha/",
help="Path to the output file where the converted Lerobot format will be saved.",
#help="Path to the output file where the converted Lerobot format will be saved.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Run in debug mode with limited episodes",
)
parser.add_argument(
"--num-threads",
type=int,
default=50,
help="Number of threads per process",
)
# parser.add_argument(
# "--task_name",
# type=str,
# required=True,
# default="Pick_up_the_marker_and_put_it_into_the_pen_holder",
# help="Name of the task to be processed. Default is 'Pick_up_the_marker_and_put_it_into_the_pen_holder'.",
# )
parser.add_argument(
"--repo_id",
type=str,
required=True,
# default="SplitAloha_20250714",
help="identifier for the dataset repository.",
)
parser.add_argument(
"--origin_fps",
type=int,
default=30,
help="Frames per second for the obervation video. Default is 30.",
)
parser.add_argument(
"--target_fps",
type=int,
default=30,
help="Frames per second for the downsample video. Default is 30.",
)
args = parser.parse_args()
assert int(args.origin_fps) % int(args.target_fps) == 0, "origin_fps must be an integer multiple of target_fps"
start_time = time.time()
main(
src_path=Path(args.src_path),
save_path=Path(args.save_path),
repo_id=args.repo_id,
num_threads=args.num_threads,
debug=args.debug,
origin_fps=args.origin_fps,
target_fps=args.target_fps
)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Total time taken: {elapsed_time:.2f} seconds")
# --target_fps 10
# --src_path /fs-computility/efm/shared/datasets/myData-A1/real/raw_data/agilex_split_aloha/Put_the_bananas_in_the_basket
# --save_path /mnt/shared-storage-user/internvla/Users/liyang/data/processed_data/arx_lift2

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff


@@ -0,0 +1,587 @@
#!/usr/bin/env python3
"""
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
Usage:
# Just inspect keys:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
# Convert to PyTorch:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
Example:
# pi0_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
# pi0_aloha_sim
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
# pi05_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""
import json
import os
import pathlib
import shutil
from typing import Literal
from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
import tyro
import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config
def slice_paligemma_state_dict(state_dict, config):
"""Convert PaliGemma JAX parameters to PyTorch format."""
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# patch embeddings
jax_key = f"img/embedding/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
jax_key = f"img/embedding/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# positional embeddings
jax_key = f"img/pos_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
)
encoderblock_attention_0_key_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
)
encoderblock_attention_0_value_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
)
encoderblock_attention_0_value_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
)
encoderblock_attention_0_query_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
)
encoderblock_attention_0_query_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
)
encoderblock_attention_0_out_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
)
encoderblock_attention_0_out_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
)
for i in range(config.vision_config.num_hidden_layers):
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
] = encoderblock_layernorm0_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
] = encoderblock_layernorm0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
] = encoderblock_layernorm1_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
] = encoderblock_layernorm1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
] = encoderblock_mlp_dense0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
] = encoderblock_mlp_dense1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# multimodal projector
jax_key = f"img/head/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/head/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# text decoder (gemma)
jax_key = f"llm/embedder/input_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# pop the einsum attention + mlp representations
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.transpose(2, 0, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
llm_mlp_linear[i].transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
] = llm_post_attention_layernorm[i]
jax_key = f"llm/final_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
expert_dict = {}
final_state_dict = {}
# Expert-related keys to extract (including pi05 Dense layer parameters)
expert_keys = [
f"llm/final_norm_1/scale{suffix}",
f"llm/final_norm_1/Dense_0/bias{suffix}",
f"llm/final_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
]
for key, value in state_dict.items():
if key not in expert_keys:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
"""Convert Gemma JAX parameters to PyTorch format."""
# Add missing attributes to config if they don't exist
if not hasattr(config, "vocab_size"):
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
if not hasattr(config, "hidden_size"):
config.hidden_size = config.width
if not hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = config.depth
if not hasattr(config, "num_attention_heads"):
config.num_attention_heads = config.num_heads
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
llm_input_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
)
llm_post_attention_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
)
else:
# Regular pi0 with standard RMSNorm
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
.transpose(1, 0)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
i
].transpose()
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization - use Dense layer parameters directly
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
llm_input_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
llm_post_attention_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
llm_input_layernorm_kernel[i].transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
llm_post_attention_layernorm_kernel[i].transpose()
)
else:
# Regular pi0 with standard RMSNorm
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
llm_post_attention_layernorm[i]
)
# Handle final norm layer
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization - use Dense layer parameters directly
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
else:
# Regular pi0 with standard RMSNorm
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
f"llm/final_norm_{num_expert}/scale{suffix}"
)
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
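# NOTE: The repeated `.transpose()` calls in both slicing functions come down to one
# convention mismatch: a JAX/Flax Dense kernel is stored as (in_features, out_features),
# while a PyTorch nn.Linear weight is (out_features, in_features) and the forward pass
# computes x @ W.T. A minimal numpy sketch of that equivalence (the names below are
# illustrative only, not part of the converter):

```python
import numpy as np

# JAX-style Dense kernel: shape (in_features, out_features).
in_features, out_features = 4, 3
jax_kernel = np.arange(in_features * out_features, dtype=np.float32).reshape(in_features, out_features)

# PyTorch-style Linear weight: shape (out_features, in_features).
# Converting a JAX kernel to a PyTorch weight is therefore a single transpose.
torch_weight = jax_kernel.T

x = np.random.randn(2, in_features).astype(np.float32)
y_jax = x @ jax_kernel        # JAX-style forward pass
y_torch = x @ torch_weight.T  # PyTorch-style forward pass (x @ W.T)
assert np.allclose(y_jax, y_torch)
```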
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
"""Load and process params by restoring via JAX model loader first.
This respects dtype conversions that occur during model restore.
"""
# Use repository restore utility to load a pure dict of params (value suffix removed)
params = openpi.models.model.restore_params(
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
)
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
def load_jax_model_and_print_keys(checkpoint_dir: str):
"""
Load JAX model from checkpoint and print all parameter keys.
Args:
checkpoint_dir: Path to the checkpoint directory
"""
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
# Initialize checkpointer
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
print(utils.array_tree_to_info(metadata))
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
"""
Convert PI0 JAX checkpoint to PyTorch format.
Args:
checkpoint_dir: Path to the JAX checkpoint
precision: Model precision (float32, bfloat16, float16)
output_path: Path to save the converted PyTorch model
model_config: Model config
"""
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
print(f"Model config: {model_config}")
# Break down the Orbax checkpoint by restoring via JAX so dtype conversions are respected
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
# Process projection params
if model_config.pi05:
keys = [
"action_in_proj",
"action_out_proj",
"time_mlp_in",
"time_mlp_out",
]
else:
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
pytorch_weight_key = f"{key}.weight"
pytorch_bias_key = f"{key}.bias"
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
# Create configs based on checkpoint path
# All models use the same PaliGemma config structure
class PaliGemmaConfig:
def __init__(self):
self.vision_config = type(
"obj",
(object,),
{
"hidden_size": 1152,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"intermediate_size": 4304,
"patch_size": 14,
"projection_dim": 2048,
},
)()
self.text_config = type(
"obj",
(object,),
{
"hidden_size": 2048,
"num_hidden_layers": 18,
"num_attention_heads": 8,
"head_dim": 256,
"intermediate_size": 16384,
},
)()
paligemma_config = PaliGemmaConfig()
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
# Process PaliGemma weights
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
# Process Gemma weights from expert_params
gemma_params = slice_gemma_state_dict(
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
)
# Instantiate model
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
# Combine all parameters (no prefix needed for our model structure)
all_params = {**paligemma_params, **gemma_params, **projection_params}
# Load state dict
pi0_model.load_state_dict(all_params, strict=False)
if precision == "float32":
pi0_model = pi0_model.to(torch.float32)
elif precision == "bfloat16":
pi0_model = pi0_model.to(torch.bfloat16)
elif precision == "float16":
pi0_model = pi0_model.to(torch.float16)
else:
raise ValueError(f"Invalid precision: {precision}")
# Save the converted model using safetensors
os.makedirs(output_path, exist_ok=True)
# Save model weights as SafeTensors using save_model to handle tied weights
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
# Copy assets folder if it exists
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
if assets_source.exists():
assets_dest = pathlib.Path(output_path) / "assets"
if assets_dest.exists():
shutil.rmtree(assets_dest)
shutil.copytree(assets_source, assets_dest)
# Save config as JSON for reference
config_dict = {
"action_dim": model_config.action_dim,
"action_horizon": model_config.action_horizon,
"paligemma_variant": model_config.paligemma_variant,
"action_expert_variant": model_config.action_expert_variant,
"precision": precision,
}
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
print("Model conversion completed successfully!")
print(f"Model saved to {output_path}")
def main(
checkpoint_dir: str,
config_name: str,
output_path: str | None = None,
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
*,
inspect_only: bool = False,
):
"""Load JAX model and optionally convert to PyTorch.
Args:
checkpoint_dir: Path to the JAX checkpoint directory
config_name: Name of the training config whose model config is used for conversion
output_path: Path to save converted PyTorch model (required for conversion)
precision: Precision for model conversion
inspect_only: Only inspect parameter keys, don't convert
"""
model_config = _config.get_config(config_name).model
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
raise ValueError(f"Config {config_name} is not a Pi0Config")
if inspect_only:
load_jax_model_and_print_keys(checkpoint_dir)
else:
if not output_path:
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
return
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
if __name__ == "__main__":
tyro.cli(main)


@@ -0,0 +1,84 @@
# DROID Policies in openpi
We offer instructions for:
- [Running inference for our best $\pi_{0.5}$-DROID policy](./README.md#running-droid-inference)
- [Running inference for other pre-trained DROID policies ($\pi_0$, $\pi_0$-FAST, ...)](./README.md#running-roboarena-baseline-policies)
- [Pre-training *generalist* policies on the *full* DROID dataset](./README_train.md#training-on-droid)
- [Fine-tuning expert $\pi_{0.5}$ on your custom DROID dataset](./README_train.md#fine-tuning-on-custom-droid-datasets)
## Running DROID Inference
This example shows how to run the fine-tuned $\pi_{0.5}$-DROID model on the [DROID robot platform](https://github.com/droid-dataset/droid). Based on the [public RoboArena benchmark](https://robo-arena.github.io/leaderboard), this is currently our strongest generalist DROID policy.
### Step 1: Start a policy server
Since the DROID control laptop does not have a powerful GPU, we will start a remote policy server on a different machine with a more powerful GPU and then query it from the DROID control laptop during inference.
1. On a machine with a powerful GPU (~NVIDIA 4090), clone and install the `openpi` repository following the instructions in the [README](https://github.com/Physical-Intelligence/openpi).
2. Start the OpenPI server via the following command:
```bash
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi05_droid --policy.dir=gs://openpi-assets/checkpoints/pi05_droid
```
You can also run the equivalent command below:
```bash
uv run scripts/serve_policy.py --env=DROID
```
### Step 2: Run the DROID robot
1. Make sure you have the most recent version of the DROID package installed on both the DROID control laptop and the NUC.
2. On the control laptop, activate your DROID conda environment.
3. Clone the openpi repo and install the openpi client, which we will use to connect to the policy server (this has very few dependencies and should be very fast to install): with the DROID conda environment activated, run `cd $OPENPI_ROOT/packages/openpi-client && pip install -e .`.
4. Install `tyro`, which we will use for command line parsing: `pip install tyro`.
5. Copy the `main.py` file from this directory to the `$DROID_ROOT/scripts` directory.
6. Replace the camera IDs in the `main.py` file with the IDs of your cameras (you can find the camera IDs by running `ZED_Explorer` in the command line, which will open a tool that shows you all connected cameras and their IDs -- you can also use it to make sure that the cameras are well-positioned to see the scene you want the robot to interact with).
7. Run the `main.py` file. Make sure to point the host IP and port at the policy server. (To verify that the server machine is reachable from the DROID laptop, you can run `ping <server_ip>` from the laptop.) Also make sure to specify which external camera to feed to the policy (we only input one external camera); choose from `left` or `right`.
```bash
python3 scripts/main.py --remote_host=<server_ip> --remote_port=<server_port> --external_camera="left"
```
The script will ask you to enter a free-form language instruction for the robot to follow. Make sure to point the cameras at the scene you want the robot to interact with. You _do not_ need to carefully control camera angle, object positions, etc. The policy is fairly robust in our experience. Happy prompting!
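Under the hood, each control step boils down to building an observation dict and sending it to the remote policy server. The sketch below is a hypothetical illustration of that step: the key names follow the DROID RLDS naming used elsewhere in this repo and the `WebsocketClientPolicy` call is an assumption based on the `openpi-client` package, so check `main.py` for the exact schema your server expects.

```python
import numpy as np

def build_droid_observation(external_rgb, wrist_rgb, joint_position, gripper_position, prompt):
    # Key names are assumptions for illustration; see examples/droid/main.py for the real schema.
    return {
        "observation/exterior_image_1_left": external_rgb,
        "observation/wrist_image_left": wrist_rgb,
        "observation/joint_position": joint_position,
        "observation/gripper_position": gripper_position,
        "prompt": prompt,
    }

obs = build_droid_observation(
    external_rgb=np.zeros((224, 224, 3), dtype=np.uint8),
    wrist_rgb=np.zeros((224, 224, 3), dtype=np.uint8),
    joint_position=np.zeros(7, dtype=np.float32),
    gripper_position=np.zeros(1, dtype=np.float32),
    prompt="pick up the spoon",
)

# Querying the server would then look roughly like this (requires a running policy server):
#   from openpi_client import websocket_client_policy
#   policy = websocket_client_policy.WebsocketClientPolicy(host="<server_ip>", port=<server_port>)
#   action_chunk = policy.infer(obs)["actions"]
```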
## Troubleshooting
| Issue | Solution |
|-------|----------|
| Cannot reach policy server | Make sure the server is running and the IP and port are correct. You can check that the server machine is reachable by running `ping <server_ip>` from the DROID laptop. |
| Cannot find cameras | Make sure the camera IDs are correct and that the cameras are connected to the DROID laptop. Sometimes replugging the cameras can help. You can check all connected cameras by running `ZED_Explorer` in the command line. |
| Policy inference is slow / inconsistent | Try using a wired internet connection for the DROID laptop to reduce latency (0.5 - 1 sec latency per chunk is normal). |
| Policy does not perform the task well | In our experiments, the policy could perform simple tabletop manipulation tasks (pick-and-place) across a wide range of environments, camera positions, and lighting conditions. If the policy does not perform the task well, you can try modifying the scene or object placement to make the task easier. Also make sure that the camera view you are passing to the policy can see all relevant objects in the scene (the policy is only conditioned on a single external camera plus the wrist camera; make sure you are feeding the desired camera to the policy). Use `ZED_Explorer` to check that the camera view you are passing to the policy can see all relevant objects. Finally, the policy is far from perfect and will fail on more complex manipulation tasks, but it usually makes a decent effort. :) |
## Running Other Policies
We provide configs for running the baseline DROID policies from the [RoboArena](https://robo-arena.github.io/) paper. Simply run the commands below to start inference servers for the respective policies. Then follow the instructions above to run evaluation on the DROID robot.
```
# Train from pi0-FAST, using FAST tokenizer
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_fast_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_fast_droid
# Train from pi0, using flow matching
uv run scripts/serve_policy.py policy:checkpoint --policy.config=pi0_droid --policy.dir=gs://openpi-assets/checkpoints/pi0_droid
# Trained from PaliGemma, using RT-2 / OpenVLA style binning tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_binning_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_binning_droid
# Trained from PaliGemma, using FAST tokenizer (using universal FAST+ tokenizer).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_droid
# Trained from PaliGemma, using FAST tokenizer (tokenizer trained on DROID dataset).
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_fast_specialist_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_fast_specialist_droid
# Trained from PaliGemma, using FSQ tokenizer.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_vq_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_vq_droid
# pi0-style diffusion / flow VLA, trained on DROID from PaliGemma.
uv run scripts/serve_policy.py policy:checkpoint --policy.config=paligemma_diffusion_droid --policy.dir=gs://openpi-assets/checkpoints/roboarena/paligemma_diffusion_droid
```
You can find the inference configs in [roboarena_config.py](../../src/openpi/training/misc/roboarena_config.py).


@@ -0,0 +1,106 @@
# Training on DROID
Here we describe how to fine-tune the pi0.5 model on the *full* DROID dataset. This is an approximate open-source reproduction of the pi05-DROID training pipeline (with small differences in data loading and the action space used). For a tutorial on how to fine-tune your model on a smaller, custom dataset collected on the DROID platform, see below.
In contrast to the rest of openpi, which uses LeRobot for data loading, we need to use RLDS as the data format for full DROID training, since at the moment LeRobot isn't scalable enough for larger datasets like DROID (though the LeRobot team is working on improving this). Below, we provide instructions for updating your openpi environment for RLDS data loading and for downloading the DROID dataset.
## Install
We need a few additional dependencies for RLDS data loading. Run:
```bash
uv sync --group rlds
```
## Download DROID dataset
You can download the DROID dataset with the following command (after installing the `gsutil` google cloud CLI):
```
gsutil -m cp -r gs://gresearch/robotics/droid/1.0.1 <your_download_path>/droid/1.0.1
```
Note that downloading version 1.0.1 is important (not v1.0.0): it contains the complete set of language annotations (~75k episodes) while v1.0.0 only has annotations for 30k episodes. If for some reason you would like to use another version, modify the line `version="1.0.1"` in the `DroidRldsDataset` object [here](src/openpi/training/droid_rlds_dataset.py).
You will need 1.8TB of disk storage to download the DROID RLDS dataset.
## Run
First, change the `rlds_data_dir` path in your `TrainConfig` to the directory that you downloaded the `droid` dataset into (see [src/openpi/training/config.py](src/openpi/training/config.py)).
Then, compute normalization statistics (this will take ~10 minutes):
```bash
uv run --group rlds scripts/compute_norm_stats.py --config-name pi05_full_droid_finetune --max-frames 10_000_000
```
Run training:
```bash
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 uv run --group rlds scripts/train.py pi05_full_droid_finetune --exp-name=my_experiment --overwrite
```
**Note**: The original pi0.5-DROID model was trained with joint velocity actions.
Joint velocity actions are not compatible with simulated evaluation environments (much harder to simulate).
Thus, we do not recommend training with joint velocity actions and instead use joint position actions here.
## Compute Requirements
Our DROID training config requires approximately 2 days on 8x H100 GPUs for convergence (100k iterations, batch size 256, approximately 1 epoch).
If you start from PaliGemma instead of pi0 initialization, plan for ~5 days on 8x H100s (240k iterations, i.e. 3 epochs).
We have experimented with LoRA for cheaper finetuning, but haven't found the policies to perform well so far.
## Data Filtering
Like any diverse real-robot dataset, the DROID dataset isn't perfectly "clean", and we have found data filtering to significantly improve policy performance. Concretely, the DROID dataset contains many *idle* timesteps in which the robot does not move (in part due to the VR teleoperation interface used during data collection; we will not go into detail here). Filtering these idle transitions appropriately improves the resulting policies.
By default, our openpi training recipe implements the same idle filter used to train all pi-DROID models. We implement it by pre-computing which dataset indices to sample during training. You can check [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) for how we compute these indices. Roughly speaking, we filter any time steps for which the next chunk of actions would be largely idle. During training, our code automatically pulls our pre-computed list of indices from cloud storage and applies them. If you want to modify the idle filter / create your custom sampling logic, you can modify our script to generate a new index list and provide it via the `filter_dict_path="<path_to_filter_dict>"` argument in [src/openpi/training/config.py](src/openpi/training/config.py).
**Note**: our list of filtering indices is only valid for the `droid/1.0.1` dataset mentioned in the download section above, and will not provide valid filtering for any other version of the DROID dataset, so make sure you download the dataset above! If you have a custom DROID version, you can rerun the [compute_droid_nonidle_ranges.py](examples/droid/compute_droid_nonidle_ranges.py) script to generate a new list of sampling indices.
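Concretely, a filter dict of the shape produced by `compute_droid_nonidle_ranges.py` maps an episode key to a list of `(start, end)` keep-ranges; expanding those half-open ranges yields the flat set of frame indices the loader may sample. A minimal sketch (the key format follows the script above; the loader internals here are illustrative assumptions):

```python
# Toy filter dict: "<recording_folderpath>--<file_path>" -> list of half-open keep-ranges.
keep_ranges_map = {
    "recordings/ep0--trajectory.h5": [(5, 40), (60, 90)],
}

def sampleable_indices(ranges):
    """Expand half-open (start, end) keep-ranges into the flat list of sampleable frame indices."""
    idx = []
    for start, end in ranges:
        idx.extend(range(start, end))
    return idx

indices = sampleable_indices(keep_ranges_map["recordings/ep0--trajectory.h5"])
assert len(indices) == (40 - 5) + (90 - 60)
assert 5 in indices and 40 not in indices  # range ends are exclusive
assert 0 not in indices                    # filtered (idle) frames are never sampled
```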
## RoboArena
Consider submitting your DROID policies to the [RoboArena benchmark](https://robo-arena.github.io/), which allows you to evaluate your policies on diverse tasks & scenes, **in the real world**! :)
If you have questions about RoboArena, please email [karl.pertsch@gmail.com](mailto:karl.pertsch@gmail.com).
# Fine-Tuning on Custom DROID Datasets
Here we describe how to fine-tune a model on a custom (smaller) dataset collected on the DROID platform. As for other datasets, we will first convert the custom DROID dataset to LeRobot format and then fine-tune a model (pi05-droid) on it.
Note: We use LeRobot here since we assume the custom DROID fine-tuning dataset is relatively small (tens of hours or less). For larger datasets (like the full DROID dataset), we recommend using RLDS for its better efficiency (see the example above).
## Step 1: Converting your custom DROID dataset to LeRobot
We will use a small subset of the real DROID dataset for this example. This is a subset of just 30 demonstrations -- we assume that you will use your own dataset instead, but here is the command to download our subset (1.6GB):
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/IRIS/success/2023-12-04 <your_target_path>
```
We will also download the language annotations for the DROID dataset so we can pair our demonstrations with language instructions. Again, for your own data you can manually enter your language instructions and don't need to download our annotations. To download the DROID language annotations (12MB), run:
```
gsutil -m cp -r gs://gresearch/robotics/droid_raw/1.0.1/aggregated-annotations-030724.json <your_target_dir>
```
For your own dataset, make sure that each episode's directory contains a folder called `recordings/MP4` -- if not, you need to first run the MP4 video extraction (from SVO files) using the script [here](https://github.com/droid-dataset/droid/blob/main/scripts/convert/svo_to_mp4.py).
Now, we will use the `convert_droid_data_to_lerobot.py` script to create a LeRobot version of this dataset (takes <5 min for the 30 demonstrations):
```
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir <your_target_path>
```
## Step 2: Run fine-tuning with your custom dataset
Now we can run fine-tuning with our converted custom dataset. We provide an example config for fine-tuning `pi05_droid` on the custom dataset we created.
You can easily modify the config to work with other base models, or to use your own custom DROID dataset, in `config.py` (search for `pi05_droid_finetune`).
To launch training:
```
uv run scripts/train.py pi05_droid_finetune --exp-name=my_experiment --overwrite
```
Once trained, you can follow the instructions in [`examples/droid/README.md`](examples/droid/README.md) to serve the policy and run it on the robot.


@@ -0,0 +1,103 @@
"""
Iterates through the DROID dataset and creates a json mapping from episode unique IDs to ranges of time steps
that should be sampled during training (all others are filtered out).
Filtering logic:
We look for ranges of consecutive steps that contain at most min_idle_len consecutive idle frames
(default: 7 -- since most DROID action-chunking policies execute the first 8 actions generated in each chunk, filtering
this way means the policy will not get stuck outputting stationary actions). Additionally, we only keep non-idle
ranges of length at least min_non_idle_len (default: 16 frames = ~1 second), while also removing the last
filter_last_n_in_ranges frames from the end of each range (as those correspond to action chunks with many idle actions).
This leaves us with trajectory segments consisting of contiguous, significant movement. Training on this filtered set
yields policies that output fewer stationary actions (i.e., get "stuck" in states less).
"""
import json
import os
from pathlib import Path
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from tqdm import tqdm
os.environ["CUDA_VISIBLE_DEVICES"] = "" # Set to the GPU you want to use, or leave empty for CPU
builder = tfds.builder_from_directory(
# path to the `droid` directory (not its parent)
builder_dir="<path_to_droid_dataset_tfds_files>",
)
ds = builder.as_dataset(split="train", shuffle_files=False)
ds = ds.apply(tf.data.experimental.ignore_errors())  # skip corrupted episodes instead of raising
keep_ranges_path = "<path_to_where_to_save_the_json>"
min_idle_len = 7 # If more than this number of consecutive idle frames, filter all of them out
min_non_idle_len = 16 # If fewer than this number of consecutive non-idle frames, filter all of them out
filter_last_n_in_ranges = 10 # When using a filter dict, remove this many frames from the end of each range
keep_ranges_map = {}
if Path(keep_ranges_path).exists():
with Path(keep_ranges_path).open("r") as f:
keep_ranges_map = json.load(f)
print(f"Resuming from {len(keep_ranges_map)} episodes already processed")
for ep_idx, ep in enumerate(tqdm(ds)):
recording_folderpath = ep["episode_metadata"]["recording_folderpath"].numpy().decode()
file_path = ep["episode_metadata"]["file_path"].numpy().decode()
key = f"{recording_folderpath}--{file_path}"
if key in keep_ranges_map:
continue
joint_velocities = [step["action_dict"]["joint_velocity"].numpy() for step in ep["steps"]]
joint_velocities = np.array(joint_velocities)
is_idle_array = np.hstack(
[np.array([False]), np.all(np.abs(joint_velocities[1:] - joint_velocities[:-1]) < 1e-3, axis=1)]
)
# Find what steps go from idle to non-idle and vice-versa
is_idle_padded = np.concatenate(
[[False], is_idle_array, [False]]
)  # Pad with False so idle runs touching the first or last step still produce transitions
is_idle_diff = np.diff(is_idle_padded.astype(int))
is_idle_true_starts = np.where(is_idle_diff == 1)[0]  # +1 transitions --> non-idle to idle (an idle run starts)
is_idle_true_ends = np.where(is_idle_diff == -1)[0]  # -1 transitions --> idle to non-idle (the idle run ends)
# Find which steps correspond to idle segments of length at least min_idle_len
true_segment_masks = (is_idle_true_ends - is_idle_true_starts) >= min_idle_len
is_idle_true_starts = is_idle_true_starts[true_segment_masks]
is_idle_true_ends = is_idle_true_ends[true_segment_masks]
keep_mask = np.ones(len(joint_velocities), dtype=bool)
for start, end in zip(is_idle_true_starts, is_idle_true_ends, strict=True):
keep_mask[start:end] = False
# Get all non-idle ranges of length at least min_non_idle_len
# Same logic as above, but applied to keep_mask, letting us drop contiguous ranges shorter than min_non_idle_len
keep_padded = np.concatenate([[False], keep_mask, [False]])
keep_diff = np.diff(keep_padded.astype(int))
keep_true_starts = np.where(keep_diff == 1)[0] # +1 transitions --> going from filter out to keep
keep_true_ends = np.where(keep_diff == -1)[0] # -1 transitions --> going from keep to filter out
# Find which steps correspond to non-idle segments of length at least min_non_idle_len
true_segment_masks = (keep_true_ends - keep_true_starts) >= min_non_idle_len
keep_true_starts = keep_true_starts[true_segment_masks]
keep_true_ends = keep_true_ends[true_segment_masks]
# Add mapping from episode unique ID key to list of non-idle ranges to keep
keep_ranges_map[key] = []
for start, end in zip(keep_true_starts, keep_true_ends, strict=True):
keep_ranges_map[key].append((int(start), int(end) - filter_last_n_in_ranges))
if ep_idx % 1000 == 0:
with Path(keep_ranges_path).open("w") as f:
json.dump(keep_ranges_map, f)
print("Done!")
with Path(keep_ranges_path).open("w") as f:
json.dump(keep_ranges_map, f)


@@ -0,0 +1,477 @@
"""
Minimal example script for converting a dataset collected on the DROID platform to LeRobot format.
Usage:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/droid/convert_droid_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
The resulting dataset will get saved to the $LEROBOT_HOME directory.
"""
from collections import defaultdict
import copy
import glob
import json
from pathlib import Path
import shutil
import cv2
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
from PIL import Image
from tqdm import tqdm
import tyro
REPO_NAME = "your_hf_username/my_droid_dataset" # Name of the output dataset, also used for the Hugging Face Hub
def resize_image(image, size):
image = Image.fromarray(image)
return np.array(image.resize(size, resample=Image.BICUBIC))
def main(data_dir: str, *, push_to_hub: bool = False):
# Clean up any existing dataset in the output directory
output_path = HF_LEROBOT_HOME / REPO_NAME
if output_path.exists():
shutil.rmtree(output_path)
data_dir = Path(data_dir)
# Create LeRobot dataset, define features to store
# We will follow the DROID data naming conventions here.
# LeRobot assumes that dtype of image data is `image`
dataset = LeRobotDataset.create(
repo_id=REPO_NAME,
robot_type="panda",
fps=15, # DROID data is typically recorded at 15fps
features={
# We call this "left" since we will only use the left stereo camera (following DROID RLDS convention)
"exterior_image_1_left": {
"dtype": "image",
"shape": (180, 320, 3), # This is the resolution used in the DROID RLDS dataset
"names": ["height", "width", "channel"],
},
"exterior_image_2_left": {
"dtype": "image",
"shape": (180, 320, 3),
"names": ["height", "width", "channel"],
},
"wrist_image_left": {
"dtype": "image",
"shape": (180, 320, 3),
"names": ["height", "width", "channel"],
},
"joint_position": {
"dtype": "float32",
"shape": (7,),
"names": ["joint_position"],
},
"gripper_position": {
"dtype": "float32",
"shape": (1,),
"names": ["gripper_position"],
},
"actions": {
"dtype": "float32",
"shape": (8,), # We will use joint *velocity* actions here (7D) + gripper position (1D)
"names": ["actions"],
},
},
image_writer_threads=10,
image_writer_processes=5,
)
# Load language annotations
# Note: we load the DROID language annotations for this example, but you can manually define them for your own data
with (data_dir / "aggregated-annotations-030724.json").open() as f:
language_annotations = json.load(f)
# Loop over raw DROID fine-tuning datasets and write episodes to the LeRobot dataset
# We assume the following directory structure:
# RAW_DROID_PATH/
# - <...>/
# - recordings/
# - MP4/
# - <camera_id>.mp4 # single-view video of left stereo pair camera
# - trajectory.h5
# - <...>/
episode_paths = list(data_dir.glob("**/trajectory.h5"))
print(f"Found {len(episode_paths)} episodes for conversion")
# We will loop over each dataset_name and write episodes to the LeRobot dataset
for episode_path in tqdm(episode_paths, desc="Converting episodes"):
# Load raw data
recording_folderpath = episode_path.parent / "recordings" / "MP4"
trajectory = load_trajectory(str(episode_path), recording_folderpath=str(recording_folderpath))
# To load the language instruction, we need to parse out the episode_id from the metadata file
# Again, you can modify this step for your own data, to load your own language instructions
metadata_filepath = next(iter(episode_path.parent.glob("metadata_*.json")))
episode_id = metadata_filepath.name.split(".")[0].split("_")[-1]
language_instruction = language_annotations.get(episode_id, {"language_instruction1": "Do something"})[
"language_instruction1"
]
print(f"Converting episode with language instruction: {language_instruction}")
# Write to LeRobot dataset
for step in trajectory:
camera_type_dict = step["observation"]["camera_type"]
wrist_ids = [k for k, v in camera_type_dict.items() if v == 0]
exterior_ids = [k for k, v in camera_type_dict.items() if v != 0]
dataset.add_frame(
{
# Note: need to flip BGR --> RGB for loaded images
"exterior_image_1_left": resize_image(
step["observation"]["image"][exterior_ids[0]][..., ::-1], (320, 180)
),
"exterior_image_2_left": resize_image(
step["observation"]["image"][exterior_ids[1]][..., ::-1], (320, 180)
),
"wrist_image_left": resize_image(step["observation"]["image"][wrist_ids[0]][..., ::-1], (320, 180)),
"joint_position": np.asarray(
step["observation"]["robot_state"]["joint_positions"], dtype=np.float32
),
"gripper_position": np.asarray(
step["observation"]["robot_state"]["gripper_position"][None], dtype=np.float32
),
# Important: we use joint velocity actions here since pi05-droid was pre-trained on joint velocity actions
"actions": np.concatenate(
[step["action"]["joint_velocity"], step["action"]["gripper_position"][None]], dtype=np.float32
),
"task": language_instruction,
}
)
dataset.save_episode()
# Optionally push to the Hugging Face Hub
if push_to_hub:
dataset.push_to_hub(
tags=["droid", "panda"],
private=False,
push_videos=True,
license="apache-2.0",
)
##########################################################################################################
################ The rest of this file are functions to parse the raw DROID data #########################
################ You don't need to worry about understanding this part #########################
################ It was copied from here: https://github.com/JonathanYang0127/r2d2_rlds_dataset_builder/blob/parallel_convert/r2_d2/r2_d2.py
##########################################################################################################
camera_type_dict = {
"hand_camera_id": 0,
"varied_camera_1_id": 1,
"varied_camera_2_id": 1,
}
camera_type_to_string_dict = {
0: "hand_camera",
1: "varied_camera",
2: "fixed_camera",
}
def get_camera_type(cam_id):
if cam_id not in camera_type_dict:
return None
type_int = camera_type_dict[cam_id]
return camera_type_to_string_dict[type_int]
class MP4Reader:
def __init__(self, filepath, serial_number):
# Save Parameters #
self.serial_number = serial_number
self._index = 0
# Open Video Reader #
self._mp4_reader = cv2.VideoCapture(filepath)
if not self._mp4_reader.isOpened():
raise RuntimeError("Corrupted MP4 File")
def set_reading_parameters(
self,
image=True, # noqa: FBT002
concatenate_images=False, # noqa: FBT002
resolution=(0, 0),
resize_func=None,
):
# Save Parameters #
self.image = image
self.concatenate_images = concatenate_images
self.resolution = resolution
self.resize_func = resize_func if resize_func is not None else cv2.resize
self.skip_reading = not image
if self.skip_reading:
return
def get_frame_resolution(self):
width = self._mp4_reader.get(cv2.CAP_PROP_FRAME_WIDTH)
height = self._mp4_reader.get(cv2.CAP_PROP_FRAME_HEIGHT)
return (width, height)
def get_frame_count(self):
if self.skip_reading:
return 0
return int(self._mp4_reader.get(cv2.CAP_PROP_FRAME_COUNT))
def set_frame_index(self, index):
if self.skip_reading:
return
if index < self._index:
self._mp4_reader.set(cv2.CAP_PROP_POS_FRAMES, index - 1)
self._index = index
while self._index < index:
self.read_camera(ignore_data=True)
def _process_frame(self, frame):
frame = copy.deepcopy(frame)
if self.resolution == (0, 0):
return frame
return self.resize_func(frame, self.resolution)
def read_camera(self, ignore_data=False, correct_timestamp=None): # noqa: FBT002
# Skip if Read Unnecessary #
if self.skip_reading:
return {}
# Read Camera #
success, frame = self._mp4_reader.read()
self._index += 1
if not success:
return None
if ignore_data:
return None
# Return Data #
data_dict = {}
if self.concatenate_images or "stereo" not in self.serial_number:
data_dict["image"] = {self.serial_number: self._process_frame(frame)}
else:
single_width = frame.shape[1] // 2
data_dict["image"] = {
self.serial_number + "_left": self._process_frame(frame[:, :single_width, :]),
self.serial_number + "_right": self._process_frame(frame[:, single_width:, :]),
}
return data_dict
def disable_camera(self):
if hasattr(self, "_mp4_reader"):
self._mp4_reader.release()
class RecordedMultiCameraWrapper:
def __init__(self, recording_folderpath, camera_kwargs={}): # noqa: B006
# Save Camera Info #
self.camera_kwargs = camera_kwargs
# Open Camera Readers #
mp4_filepaths = glob.glob(recording_folderpath + "/*.mp4")
all_filepaths = mp4_filepaths
self.camera_dict = {}
for f in all_filepaths:
serial_number = f.split("/")[-1][:-4]
cam_type = get_camera_type(serial_number)
camera_kwargs.get(cam_type, {})
if f.endswith(".mp4"):
Reader = MP4Reader # noqa: N806
else:
raise ValueError
self.camera_dict[serial_number] = Reader(f, serial_number)
def read_cameras(self, index=None, camera_type_dict={}, timestamp_dict={}): # noqa: B006
full_obs_dict = defaultdict(dict)
# Read Cameras In Randomized Order #
all_cam_ids = list(self.camera_dict.keys())
# random.shuffle(all_cam_ids)
for cam_id in all_cam_ids:
if "stereo" in cam_id:
continue
try:
cam_type = camera_type_dict[cam_id]
except KeyError:
print(f"{self.camera_dict} -- {camera_type_dict}")
raise ValueError(f"Camera {cam_id} not found in camera_type_dict") # noqa: B904
curr_cam_kwargs = self.camera_kwargs.get(cam_type, {})
self.camera_dict[cam_id].set_reading_parameters(**curr_cam_kwargs)
timestamp = timestamp_dict.get(cam_id + "_frame_received", None)
if index is not None:
self.camera_dict[cam_id].set_frame_index(index)
data_dict = self.camera_dict[cam_id].read_camera(correct_timestamp=timestamp)
# Process Returned Data #
if data_dict is None:
return None
for key in data_dict:
full_obs_dict[key].update(data_dict[key])
return full_obs_dict
def get_hdf5_length(hdf5_file, keys_to_ignore=[]): # noqa: B006
length = None
for key in hdf5_file:
if key in keys_to_ignore:
continue
curr_data = hdf5_file[key]
if isinstance(curr_data, h5py.Group):
curr_length = get_hdf5_length(curr_data, keys_to_ignore=keys_to_ignore)
elif isinstance(curr_data, h5py.Dataset):
curr_length = len(curr_data)
else:
raise ValueError
if length is None:
length = curr_length
assert curr_length == length
return length
def load_hdf5_to_dict(hdf5_file, index, keys_to_ignore=[]): # noqa: B006
data_dict = {}
for key in hdf5_file:
if key in keys_to_ignore:
continue
curr_data = hdf5_file[key]
if isinstance(curr_data, h5py.Group):
data_dict[key] = load_hdf5_to_dict(curr_data, index, keys_to_ignore=keys_to_ignore)
elif isinstance(curr_data, h5py.Dataset):
data_dict[key] = curr_data[index]
else:
raise ValueError
return data_dict
class TrajectoryReader:
def __init__(self, filepath, read_images=True): # noqa: FBT002
self._hdf5_file = h5py.File(filepath, "r")
is_video_folder = "observations/videos" in self._hdf5_file
self._read_images = read_images and is_video_folder
self._length = get_hdf5_length(self._hdf5_file)
self._video_readers = {}
self._index = 0
def length(self):
return self._length
def read_timestep(self, index=None, keys_to_ignore=[]): # noqa: B006
# Make Sure We Read Within Range #
if index is None:
index = self._index
else:
assert not self._read_images
self._index = index
assert index < self._length
# Load Low Dimensional Data #
keys_to_ignore = [*keys_to_ignore.copy(), "videos"]
timestep = load_hdf5_to_dict(self._hdf5_file, self._index, keys_to_ignore=keys_to_ignore)
# Increment Read Index #
self._index += 1
# Return Timestep #
return timestep
def close(self):
self._hdf5_file.close()
def load_trajectory(
filepath=None,
read_cameras=True, # noqa: FBT002
recording_folderpath=None,
camera_kwargs={}, # noqa: B006
remove_skipped_steps=False, # noqa: FBT002
num_samples_per_traj=None,
num_samples_per_traj_coeff=1.5,
):
read_recording_folderpath = read_cameras and (recording_folderpath is not None)
traj_reader = TrajectoryReader(filepath)
if read_recording_folderpath:
camera_reader = RecordedMultiCameraWrapper(recording_folderpath, camera_kwargs)
horizon = traj_reader.length()
timestep_list = []
# Choose Timesteps To Save #
if num_samples_per_traj:
num_to_save = num_samples_per_traj
if remove_skipped_steps:
num_to_save = int(num_to_save * num_samples_per_traj_coeff)
max_size = min(num_to_save, horizon)
indices_to_save = np.sort(np.random.choice(horizon, size=max_size, replace=False))
else:
indices_to_save = np.arange(horizon)
# Iterate Over Trajectory #
for i in indices_to_save:
# Get HDF5 Data #
timestep = traj_reader.read_timestep(index=i)
# If Applicable, Get Recorded Data #
if read_recording_folderpath:
timestamp_dict = timestep["observation"]["timestamp"]["cameras"]
camera_type_dict = {
k: camera_type_to_string_dict[v] for k, v in timestep["observation"]["camera_type"].items()
}
camera_obs = camera_reader.read_cameras(
index=i, camera_type_dict=camera_type_dict, timestamp_dict=timestamp_dict
)
camera_failed = camera_obs is None
# Add Data To Timestep If Successful #
if camera_failed:
break
timestep["observation"].update(camera_obs)
# Filter Steps #
step_skipped = not timestep["observation"]["controller_info"].get("movement_enabled", True)
delete_skipped_step = step_skipped and remove_skipped_steps
# Save Filtered Timesteps #
if delete_skipped_step:
del timestep
else:
timestep_list.append(timestep)
# Remove Extra Transitions #
timestep_list = np.array(timestep_list)
if (num_samples_per_traj is not None) and (len(timestep_list) > num_samples_per_traj):
ind_to_keep = np.random.choice(len(timestep_list), size=num_samples_per_traj, replace=False)
timestep_list = timestep_list[ind_to_keep]
# Close Readers #
traj_reader.close()
# Return Data #
return timestep_list
if __name__ == "__main__":
tyro.cli(main)
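The conversion loop above flips each frame from OpenCV's BGR channel order to RGB with `[..., ::-1]` before resizing with PIL bicubic interpolation, as in `resize_image`. A minimal standalone sketch of that pipeline (`bgr_to_rgb_resized` is an illustrative name, not from the codebase):

```python
import numpy as np
from PIL import Image

def bgr_to_rgb_resized(frame_bgr, size):
    """Flip a BGR frame to RGB and resize to `size` = (width, height)."""
    # Reversing the channel axis swaps BGR -> RGB; ascontiguousarray makes
    # the negatively-strided view safe to hand to PIL.
    rgb = np.ascontiguousarray(frame_bgr[..., ::-1])
    return np.array(Image.fromarray(rgb).resize(size, resample=Image.BICUBIC))

# A solid-blue BGR frame (channel 0 = 255) should come out as RGB with
# the *last* channel set to 255.
frame = np.zeros((4, 6, 3), dtype=np.uint8)
frame[..., 0] = 255
out = bgr_to_rgb_resized(frame, (320, 180))
print(out.shape)           # (180, 320, 3) -- note PIL size is (w, h), array is (h, w, c)
print(out[0, 0].tolist())  # [0, 0, 255]
```

The (width, height) vs. (height, width) swap between `Image.resize` and the resulting array shape is why the script passes `(320, 180)` but declares the feature shape as `(180, 320, 3)`.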


@@ -0,0 +1,246 @@
# ruff: noqa
import contextlib
import dataclasses
import datetime
import faulthandler
import os
import signal
import time
from moviepy.editor import ImageSequenceClip
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy
import pandas as pd
from PIL import Image
from droid.robot_env import RobotEnv
import tqdm
import tyro
faulthandler.enable()
# DROID data collection frequency -- we slow down execution to match this frequency
DROID_CONTROL_FREQUENCY = 15
@dataclasses.dataclass
class Args:
# Hardware parameters
left_camera_id: str = "<your_camera_id>" # e.g., "24259877"
right_camera_id: str = "<your_camera_id>" # e.g., "24514023"
wrist_camera_id: str = "<your_camera_id>" # e.g., "13062452"
# Policy parameters
external_camera: str | None = (
None # which external camera should be fed to the policy, choose from ["left", "right"]
)
# Rollout parameters
max_timesteps: int = 600
# How many actions to execute from a predicted action chunk before querying policy server again
# 8 is usually a good default (equals 0.5 seconds of action execution).
open_loop_horizon: int = 8
# Remote server parameters
remote_host: str = "0.0.0.0" # point this to the IP address of the policy server, e.g., "192.168.1.100"
remote_port: int = (
8000 # point this to the port of the policy server, default server port for openpi servers is 8000
)
# We are using Ctrl+C to optionally terminate rollouts early -- however, if we press Ctrl+C while the policy server is
# waiting for a new action chunk, it will raise an exception and the server connection dies.
# This context manager temporarily prevents Ctrl+C and delays handling it until after the server call is complete.
@contextlib.contextmanager
def prevent_keyboard_interrupt():
"""Temporarily prevent keyboard interrupts by delaying them until after the protected code."""
interrupted = False
original_handler = signal.getsignal(signal.SIGINT)
def handler(signum, frame):
nonlocal interrupted
interrupted = True
signal.signal(signal.SIGINT, handler)
try:
yield
finally:
signal.signal(signal.SIGINT, original_handler)
if interrupted:
raise KeyboardInterrupt
def main(args: Args):
# Make sure external camera is specified by user -- we only use one external camera for the policy
assert (
args.external_camera is not None and args.external_camera in ["left", "right"]
), f"Please specify an external camera to use for the policy, choose from ['left', 'right'], but got {args.external_camera}"
# Initialize the Panda environment. Using joint velocity action space and gripper position action space is very important.
env = RobotEnv(action_space="joint_velocity", gripper_action_space="position")
print("Created the droid env!")
# Connect to the policy server
policy_client = websocket_client_policy.WebsocketClientPolicy(args.remote_host, args.remote_port)
df = pd.DataFrame(columns=["success", "duration", "video_filename"])
while True:
instruction = input("Enter instruction: ")
# Rollout parameters
actions_from_chunk_completed = 0
pred_action_chunk = None
# Prepare to save video of rollout
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H:%M:%S")
video = []
bar = tqdm.tqdm(range(args.max_timesteps))
print("Running rollout... press Ctrl+C to stop early.")
for t_step in bar:
start_time = time.time()
try:
# Get the current observation
curr_obs = _extract_observation(
args,
env.get_observation(),
# Save the first observation to disk
save_to_disk=t_step == 0,
)
video.append(curr_obs[f"{args.external_camera}_image"])
# Send websocket request to policy server if it's time to predict a new chunk
if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= args.open_loop_horizon:
actions_from_chunk_completed = 0
# We resize images on the robot laptop to minimize the amount of data sent to the policy server
# and improve latency.
request_data = {
"observation/exterior_image_1_left": image_tools.resize_with_pad(
curr_obs[f"{args.external_camera}_image"], 224, 224
),
"observation/wrist_image_left": image_tools.resize_with_pad(curr_obs["wrist_image"], 224, 224),
"observation/joint_position": curr_obs["joint_position"],
"observation/gripper_position": curr_obs["gripper_position"],
"prompt": instruction,
}
# Wrap the server call in a context manager to prevent Ctrl+C from interrupting it
# Ctrl+C will be handled after the server call is complete
with prevent_keyboard_interrupt():
# this returns action chunk [10, 8] of 10 joint velocity actions (7) + gripper position (1)
pred_action_chunk = policy_client.infer(request_data)["actions"]
assert pred_action_chunk.shape == (10, 8)
# Select current action to execute from chunk
action = pred_action_chunk[actions_from_chunk_completed]
actions_from_chunk_completed += 1
# Binarize gripper action
if action[-1].item() > 0.5:
# action[-1] = 1.0
action = np.concatenate([action[:-1], np.ones((1,))])
else:
# action[-1] = 0.0
action = np.concatenate([action[:-1], np.zeros((1,))])
# clip all dimensions of action to [-1, 1]
action = np.clip(action, -1, 1)
env.step(action)
# Sleep to match DROID data collection frequency
elapsed_time = time.time() - start_time
if elapsed_time < 1 / DROID_CONTROL_FREQUENCY:
time.sleep(1 / DROID_CONTROL_FREQUENCY - elapsed_time)
except KeyboardInterrupt:
break
video = np.stack(video)
save_filename = "video_" + timestamp
ImageSequenceClip(list(video), fps=10).write_videofile(save_filename + ".mp4", codec="libx264")
success: str | float | None = None
while not isinstance(success, float):
response = input(
"Did the rollout succeed? (enter y for 100%, n for 0%, or a numeric value 0-100 based on the evaluation spec): "
)
if response == "y":
success = 1.0
elif response == "n":
success = 0.0
else:
try:
success = float(response) / 100
except ValueError:
print(f"Invalid input: {response}")
continue
if not (0 <= success <= 1):
print(f"Success must be a number in [0, 100] but got: {success * 100}")
success = None
# DataFrame.append was removed in pandas 2.0; use pd.concat instead.
df = pd.concat(
[
df,
pd.DataFrame(
[
{
"success": success,
"duration": t_step,
"video_filename": save_filename,
}
]
),
],
ignore_index=True,
)
if input("Do one more eval? (enter y or n) ").lower() != "y":
break
env.reset()
os.makedirs("results", exist_ok=True)
timestamp = datetime.datetime.now().strftime("%I:%M%p_%B_%d_%Y")
csv_filename = os.path.join("results", f"eval_{timestamp}.csv")
df.to_csv(csv_filename)
print(f"Results saved to {csv_filename}")
def _extract_observation(args: Args, obs_dict, *, save_to_disk=False):
image_observations = obs_dict["image"]
left_image, right_image, wrist_image = None, None, None
for key in image_observations:
# Note the "left" below refers to the left camera in the stereo pair.
# The model is only trained on left stereo cams, so we only feed those.
if args.left_camera_id in key and "left" in key:
left_image = image_observations[key]
elif args.right_camera_id in key and "left" in key:
right_image = image_observations[key]
elif args.wrist_camera_id in key and "left" in key:
wrist_image = image_observations[key]
# Drop the alpha dimension
left_image = left_image[..., :3]
right_image = right_image[..., :3]
wrist_image = wrist_image[..., :3]
# Convert to RGB
left_image = left_image[..., ::-1]
right_image = right_image[..., ::-1]
wrist_image = wrist_image[..., ::-1]
# In addition to image observations, also capture the proprioceptive state
robot_state = obs_dict["robot_state"]
cartesian_position = np.array(robot_state["cartesian_position"])
joint_position = np.array(robot_state["joint_positions"])
gripper_position = np.array([robot_state["gripper_position"]])
# Save the images to disk so that they can be viewed live while the robot is running
# Create one combined image to make live viewing easy
if save_to_disk:
combined_image = np.concatenate([left_image, wrist_image, right_image], axis=1)
combined_image = Image.fromarray(combined_image)
combined_image.save("robot_camera_views.png")
return {
"left_image": left_image,
"right_image": right_image,
"wrist_image": wrist_image,
"cartesian_position": cartesian_position,
"joint_position": joint_position,
"gripper_position": gripper_position,
}
if __name__ == "__main__":
args: Args = tyro.cli(Args)
main(args)
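The rollout loop above only queries the policy server when `actions_from_chunk_completed` is 0 or has reached `open_loop_horizon`, executing the chunk open-loop in between. This counter logic can be sketched in isolation (`run_open_loop` is an illustrative name; the query stands in for `policy_client.infer`):

```python
def run_open_loop(num_steps, open_loop_horizon):
    """Simulate the chunked querying loop: request a new action chunk when the
    counter is 0 or has reached the horizon, otherwise consume the current chunk.
    Returns the number of server queries issued."""
    queries = 0
    actions_from_chunk_completed = 0
    for _ in range(num_steps):
        if actions_from_chunk_completed == 0 or actions_from_chunk_completed >= open_loop_horizon:
            actions_from_chunk_completed = 0
            queries += 1  # stands in for the policy server call
        actions_from_chunk_completed += 1  # one action of the chunk executed
    return queries

print(run_open_loop(600, 8))  # 75 -- one server round-trip per 8 executed actions
```

With the defaults (600 steps, horizon 8, 15 Hz control), each chunk covers about half a second of execution, so a full rollout makes 75 server round-trips.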


@@ -0,0 +1,137 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import dataclasses\n",
"\n",
"import jax\n",
"\n",
"from openpi.models import model as _model\n",
"from openpi.policies import droid_policy\n",
"from openpi.policies import policy_config as _policy_config\n",
"from openpi.shared import download\n",
"from openpi.training import config as _config\n",
"from openpi.training import data_loader as _data_loader"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Policy inference\n",
"\n",
"The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_fast_droid\")\n",
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_fast_droid\")\n",
"\n",
"# Create a trained policy.\n",
"policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
"\n",
"# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
"example = droid_policy.make_droid_example()\n",
"result = policy.infer(example)\n",
"\n",
"# Delete the policy to free up memory.\n",
"del policy\n",
"\n",
"print(\"Actions shape:\", result[\"actions\"].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Working with a live model\n",
"\n",
"\n",
"The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = _config.get_config(\"pi0_aloha_sim\")\n",
"\n",
"checkpoint_dir = download.maybe_download(\"gs://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
"key = jax.random.key(0)\n",
"\n",
"# Create a model from the checkpoint.\n",
"model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
"\n",
"# We can create fake observations and actions to test the model.\n",
"obs, act = config.model.fake_obs(), config.model.fake_act()\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"print(\"Loss shape:\", loss.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we are going to create a data loader and use a real batch of training data to compute the loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Reduce the batch size to reduce memory usage.\n",
"config = dataclasses.replace(config, batch_size=2)\n",
"\n",
"# Load a single batch of data. This is the same data that will be used during training.\n",
"# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
"# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
"loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
"obs, act = next(iter(loader))\n",
"\n",
"# Sample actions from the model.\n",
"loss = model.compute_loss(key, obs, act)\n",
"\n",
"# Delete the model to free up memory.\n",
"del model\n",
"\n",
"print(\"Loss shape:\", loss.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -0,0 +1,59 @@
# Dockerfile for the LIBERO benchmark.
# Build the container:
# docker build . -t libero -f examples/libero/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app -v /tmp/.X11-unix:/tmp/.X11-unix:ro -e DISPLAY=$DISPLAY --gpus all libero /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
RUN apt-get update && \
apt-get install -y \
make \
g++ \
clang \
libosmesa6-dev \
libgl1-mesa-glx \
libglew-dev \
libglfw3-dev \
libgles2-mesa-dev \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/libero/requirements.txt /tmp/requirements.txt
COPY ./third_party/libero/requirements.txt /tmp/requirements-libero.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.8 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/requirements-libero.txt /tmp/openpi-client/pyproject.toml --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
ENV PYTHONPATH=/app:/app/packages/openpi-client/src:/app/third_party/libero
# Create a default config file to avoid an input prompt from LIBERO's init script.
# https://github.com/Lifelong-Robot-Learning/LIBERO/blob/master/libero/libero/__init__.py
ENV LIBERO_CONFIG_PATH=/tmp/libero
RUN mkdir -p /tmp/libero && cat <<'EOF' > /tmp/libero/config.yaml
benchmark_root: /app/third_party/libero/libero/libero
bddl_files: /app/third_party/libero/libero/libero/bddl_files
init_states: /app/third_party/libero/libero/libero/init_files
datasets: /app/third_party/libero/libero/datasets
assets: /app/third_party/libero/libero/libero/assets
EOF
CMD ["/bin/bash", "-c", "source /.venv/bin/activate && python examples/libero/main.py $CLIENT_ARGS"]


@@ -0,0 +1,71 @@
# LIBERO Benchmark
This example runs the LIBERO benchmark: https://github.com/Lifelong-Robot-Learning/LIBERO
Note: When updating requirements.txt in this directory, there is an additional flag `--extra-index-url https://download.pytorch.org/whl/cu113` that must be added to the `uv pip compile` command.
This example requires git submodules to be initialized. Don't forget to run:
```bash
git submodule update --init --recursive
```
## With Docker (recommended)
```bash
# Grant access to the X11 server:
sudo xhost +local:docker
# To run with the default checkpoint and task suite:
SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx SERVER_ARGS="--env LIBERO" docker compose -f examples/libero/compose.yml up --build
```
You can customize the loaded checkpoint by providing additional `SERVER_ARGS` (see `scripts/serve_policy.py`), and the LIBERO task suite by providing additional `CLIENT_ARGS` (see `examples/libero/main.py`).
For example:
```bash
# To load a custom checkpoint (located in the top-level openpi/ directory):
export SERVER_ARGS="--env LIBERO policy:checkpoint --policy.config pi05_libero --policy.dir ./my_custom_checkpoint"
# To run the libero_10 task suite:
export CLIENT_ARGS="--args.task-suite-name libero_10"
```
## Without Docker (not recommended)
Terminal window 1:
```bash
# Create virtual environment
uv venv --python 3.8 examples/libero/.venv
source examples/libero/.venv/bin/activate
uv pip sync examples/libero/requirements.txt third_party/libero/requirements.txt --extra-index-url https://download.pytorch.org/whl/cu113 --index-strategy=unsafe-best-match
uv pip install -e packages/openpi-client
uv pip install -e third_party/libero
export PYTHONPATH=$PYTHONPATH:$PWD/third_party/libero
# Run the simulation
python examples/libero/main.py
# To run with glx for Mujoco instead (use this if you have egl errors):
MUJOCO_GL=glx python examples/libero/main.py
```
Terminal window 2:
```bash
# Run the server
uv run scripts/serve_policy.py --env LIBERO
```
## Results
If you want to reproduce the following numbers, you can evaluate the checkpoint at `gs://openpi-assets/checkpoints/pi05_libero/`. This
checkpoint was trained in openpi with the `pi05_libero` config.
| Model | Libero Spatial | Libero Object | Libero Goal | Libero 10 | Average |
|-------|---------------|---------------|-------------|-----------|---------|
| π0.5 @ 30k (finetuned) | 98.8 | 98.2 | 98.0 | 92.4 | 96.85 |


@@ -0,0 +1,54 @@
# Run with:
# docker compose -f examples/libero/compose.yml up --build
services:
runtime:
image: libero
depends_on:
- openpi_server
build:
context: ../..
dockerfile: examples/libero/Dockerfile
init: true
tty: true
network_mode: host
privileged: true
volumes:
- $PWD:/app
- ../../data:/data
- /tmp/.X11-unix:/tmp/.X11-unix:ro
environment:
- CLIENT_ARGS
- DISPLAY=$DISPLAY
- MUJOCO_GL=${MUJOCO_GL:-egl}
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]


@@ -0,0 +1,104 @@
"""
Minimal example script for converting a dataset to LeRobot format.
We use the Libero dataset (stored in RLDS) for this example, but it can be easily
modified for any other data you have saved in a custom format.
Usage:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data
If you want to push your dataset to the Hugging Face Hub, you can use the following command:
uv run examples/libero/convert_libero_data_to_lerobot.py --data_dir /path/to/your/data --push_to_hub
Note: to run the script, you need to install tensorflow_datasets:
`uv pip install tensorflow tensorflow_datasets`
You can download the raw Libero datasets from https://huggingface.co/datasets/openvla/modified_libero_rlds
The resulting dataset will get saved to the $HF_LEROBOT_HOME directory.
Running this conversion script will take approximately 30 minutes.
"""
import shutil
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import tensorflow_datasets as tfds
import tyro
REPO_NAME = "your_hf_username/libero" # Name of the output dataset, also used for the Hugging Face Hub
RAW_DATASET_NAMES = [
"libero_10_no_noops",
"libero_goal_no_noops",
"libero_object_no_noops",
"libero_spatial_no_noops",
] # For simplicity we will combine multiple Libero datasets into one training dataset
def main(data_dir: str, *, push_to_hub: bool = False):
    # Clean up any existing dataset in the output directory
    output_path = HF_LEROBOT_HOME / REPO_NAME
    if output_path.exists():
        shutil.rmtree(output_path)

    # Create LeRobot dataset, define features to store
    # OpenPi assumes that proprio is stored in `state` and actions in `action`
    # LeRobot assumes that dtype of image data is `image`
    dataset = LeRobotDataset.create(
        repo_id=REPO_NAME,
        robot_type="panda",
        fps=10,
        features={
            "image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "wrist_image": {
                "dtype": "image",
                "shape": (256, 256, 3),
                "names": ["height", "width", "channel"],
            },
            "state": {
                "dtype": "float32",
                "shape": (8,),
                "names": ["state"],
            },
            "actions": {
                "dtype": "float32",
                "shape": (7,),
                "names": ["actions"],
            },
        },
        image_writer_threads=10,
        image_writer_processes=5,
    )

    # Loop over raw Libero datasets and write episodes to the LeRobot dataset
    # You can modify this for your own data format
    for raw_dataset_name in RAW_DATASET_NAMES:
        raw_dataset = tfds.load(raw_dataset_name, data_dir=data_dir, split="train")
        for episode in raw_dataset:
            for step in episode["steps"].as_numpy_iterator():
                dataset.add_frame(
                    {
                        "image": step["observation"]["image"],
                        "wrist_image": step["observation"]["wrist_image"],
                        "state": step["observation"]["state"],
                        "actions": step["action"],
                        "task": step["language_instruction"].decode(),
                    }
                )
            dataset.save_episode()

    # Optionally push to the Hugging Face Hub
    if push_to_hub:
        dataset.push_to_hub(
            tags=["libero", "panda", "rlds"],
            private=False,
            push_videos=True,
            license="apache-2.0",
        )


if __name__ == "__main__":
    tyro.cli(main)

View File

@@ -0,0 +1,219 @@
import collections
import dataclasses
import logging
import math
import pathlib
import imageio
from libero.libero import benchmark
from libero.libero import get_libero_path
from libero.libero.envs import OffScreenRenderEnv
import numpy as np
from openpi_client import image_tools
from openpi_client import websocket_client_policy as _websocket_client_policy
import tqdm
import tyro
LIBERO_DUMMY_ACTION = [0.0] * 6 + [-1.0]
LIBERO_ENV_RESOLUTION = 256 # resolution used to render training data
@dataclasses.dataclass
class Args:
    #################################################################################################################
    # Model server parameters
    #################################################################################################################
    host: str = "0.0.0.0"
    port: int = 8000
    resize_size: int = 224
    replan_steps: int = 5

    #################################################################################################################
    # LIBERO environment-specific parameters
    #################################################################################################################
    task_suite_name: str = (
        "libero_spatial"  # Task suite. Options: libero_spatial, libero_object, libero_goal, libero_10, libero_90
    )
    num_steps_wait: int = 10  # Number of steps to wait for objects to stabilize in sim
    num_trials_per_task: int = 50  # Number of rollouts per task

    #################################################################################################################
    # Utils
    #################################################################################################################
    video_out_path: str = "data/libero/videos"  # Path to save videos
    seed: int = 7  # Random seed (for reproducibility)
def eval_libero(args: Args) -> None:
    # Set random seed
    np.random.seed(args.seed)

    # Initialize LIBERO task suite
    benchmark_dict = benchmark.get_benchmark_dict()
    task_suite = benchmark_dict[args.task_suite_name]()
    num_tasks_in_suite = task_suite.n_tasks
    logging.info(f"Task suite: {args.task_suite_name}")

    pathlib.Path(args.video_out_path).mkdir(parents=True, exist_ok=True)

    if args.task_suite_name == "libero_spatial":
        max_steps = 220  # longest training demo has 193 steps
    elif args.task_suite_name == "libero_object":
        max_steps = 280  # longest training demo has 254 steps
    elif args.task_suite_name == "libero_goal":
        max_steps = 300  # longest training demo has 270 steps
    elif args.task_suite_name == "libero_10":
        max_steps = 520  # longest training demo has 505 steps
    elif args.task_suite_name == "libero_90":
        max_steps = 400  # longest training demo has 373 steps
    else:
        raise ValueError(f"Unknown task suite: {args.task_suite_name}")

    client = _websocket_client_policy.WebsocketClientPolicy(args.host, args.port)

    # Start evaluation
    total_episodes, total_successes = 0, 0
    for task_id in tqdm.tqdm(range(num_tasks_in_suite)):
        # Get task
        task = task_suite.get_task(task_id)

        # Get default LIBERO initial states
        initial_states = task_suite.get_task_init_states(task_id)

        # Initialize LIBERO environment and task description
        env, task_description = _get_libero_env(task, LIBERO_ENV_RESOLUTION, args.seed)

        # Start episodes
        task_episodes, task_successes = 0, 0
        for episode_idx in tqdm.tqdm(range(args.num_trials_per_task)):
            logging.info(f"\nTask: {task_description}")

            # Reset environment
            env.reset()
            action_plan = collections.deque()

            # Set initial states
            obs = env.set_init_state(initial_states[episode_idx])

            # Setup
            t = 0
            replay_images = []

            logging.info(f"Starting episode {task_episodes + 1}...")
            while t < max_steps + args.num_steps_wait:
                try:
                    # IMPORTANT: Do nothing for the first few timesteps because the simulator drops objects
                    # and we need to wait for them to fall
                    if t < args.num_steps_wait:
                        obs, reward, done, info = env.step(LIBERO_DUMMY_ACTION)
                        t += 1
                        continue

                    # Get preprocessed image
                    # IMPORTANT: rotate 180 degrees to match train preprocessing
                    img = np.ascontiguousarray(obs["agentview_image"][::-1, ::-1])
                    wrist_img = np.ascontiguousarray(obs["robot0_eye_in_hand_image"][::-1, ::-1])
                    img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(img, args.resize_size, args.resize_size)
                    )
                    wrist_img = image_tools.convert_to_uint8(
                        image_tools.resize_with_pad(wrist_img, args.resize_size, args.resize_size)
                    )

                    # Save preprocessed image for replay video
                    replay_images.append(img)

                    if not action_plan:
                        # Finished executing previous action chunk -- compute new chunk
                        # Prepare observations dict
                        element = {
                            "observation/image": img,
                            "observation/wrist_image": wrist_img,
                            "observation/state": np.concatenate(
                                (
                                    obs["robot0_eef_pos"],
                                    _quat2axisangle(obs["robot0_eef_quat"]),
                                    obs["robot0_gripper_qpos"],
                                )
                            ),
                            "prompt": str(task_description),
                        }

                        # Query model to get action
                        action_chunk = client.infer(element)["actions"]
                        assert (
                            len(action_chunk) >= args.replan_steps
                        ), f"We want to replan every {args.replan_steps} steps, but policy only predicts {len(action_chunk)} steps."
                        action_plan.extend(action_chunk[: args.replan_steps])

                    action = action_plan.popleft()

                    # Execute action in environment
                    obs, reward, done, info = env.step(action.tolist())
                    if done:
                        task_successes += 1
                        total_successes += 1
                        break
                    t += 1

                except Exception as e:
                    logging.error(f"Caught exception: {e}")
                    break

            task_episodes += 1
            total_episodes += 1

            # Save a replay video of the episode
            suffix = "success" if done else "failure"
            task_segment = task_description.replace(" ", "_")
            imageio.mimwrite(
                pathlib.Path(args.video_out_path) / f"rollout_{task_segment}_{suffix}.mp4",
                [np.asarray(x) for x in replay_images],
                fps=10,
            )

            # Log current results
            logging.info(f"Success: {done}")
            logging.info(f"# episodes completed so far: {total_episodes}")
            logging.info(f"# successes: {total_successes} ({total_successes / total_episodes * 100:.1f}%)")

        # Log per-task results
        logging.info(f"Current task success rate: {float(task_successes) / float(task_episodes)}")
        logging.info(f"Current total success rate: {float(total_successes) / float(total_episodes)}")

    # Log final results
    logging.info(f"Total success rate: {float(total_successes) / float(total_episodes)}")
    logging.info(f"Total episodes: {total_episodes}")
def _get_libero_env(task, resolution, seed):
    """Initializes and returns the LIBERO environment, along with the task description."""
    task_description = task.language
    task_bddl_file = pathlib.Path(get_libero_path("bddl_files")) / task.problem_folder / task.bddl_file
    env_args = {"bddl_file_name": task_bddl_file, "camera_heights": resolution, "camera_widths": resolution}
    env = OffScreenRenderEnv(**env_args)
    env.seed(seed)  # IMPORTANT: seed seems to affect object positions even when using fixed initial state
    return env, task_description
def _quat2axisangle(quat):
    """
    Copied from robosuite: https://github.com/ARISE-Initiative/robosuite/blob/eafb81f54ffc104f905ee48a16bb15f059176ad3/robosuite/utils/transform_utils.py#L490C1-L512C55
    """
    # clip quaternion
    if quat[3] > 1.0:
        quat[3] = 1.0
    elif quat[3] < -1.0:
        quat[3] = -1.0

    den = np.sqrt(1.0 - quat[3] * quat[3])
    if math.isclose(den, 0.0):
        # This is (close to) a zero degree rotation, immediately return
        return np.zeros(3)

    return (quat[:3] * 2.0 * math.acos(quat[3])) / den


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tyro.cli(eval_libero)

View File

@@ -0,0 +1,11 @@
imageio[ffmpeg]
numpy==1.22.4
tqdm
tyro
PyYaml
opencv-python==4.6.0.66
torch==1.11.0+cu113
torchvision==0.12.0+cu113
torchaudio==0.11.0+cu113
robosuite==1.4.1
matplotlib==3.5.3

View File

@@ -0,0 +1,136 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/libero/requirements.in -o examples/libero/requirements.txt --python-version 3.8 --index-strategy=unsafe-best-match
absl-py==2.1.0
# via mujoco
certifi==2024.12.14
# via requests
charset-normalizer==3.4.0
# via requests
cycler==0.12.1
# via matplotlib
docstring-parser==0.16
# via tyro
etils==1.3.0
# via mujoco
eval-type-backport==0.2.0
# via tyro
evdev==1.7.1
# via pynput
fonttools==4.55.3
# via matplotlib
glfw==1.12.0
# via mujoco
idna==3.10
# via requests
imageio==2.35.1
# via -r examples/libero/requirements.in
imageio-ffmpeg==0.5.1
# via imageio
importlib-metadata==8.5.0
# via typeguard
importlib-resources==6.4.5
# via etils
kiwisolver==1.4.7
# via matplotlib
llvmlite==0.36.0
# via numba
markdown-it-py==3.0.0
# via rich
matplotlib==3.5.3
# via -r examples/libero/requirements.in
mdurl==0.1.2
# via markdown-it-py
mujoco==3.2.3
# via robosuite
numba==0.53.1
# via robosuite
numpy==1.22.4
# via
# -r examples/libero/requirements.in
# imageio
# matplotlib
# mujoco
# numba
# opencv-python
# robosuite
# scipy
# torchvision
opencv-python==4.6.0.66
# via
# -r examples/libero/requirements.in
# robosuite
packaging==24.2
# via matplotlib
pillow==10.4.0
# via
# imageio
# matplotlib
# robosuite
# torchvision
psutil==6.1.0
# via imageio
pygments==2.18.0
# via rich
pynput==1.7.7
# via robosuite
pyopengl==3.1.7
# via mujoco
pyparsing==3.1.4
# via matplotlib
python-dateutil==2.9.0.post0
# via matplotlib
python-xlib==0.33
# via pynput
pyyaml==6.0.2
# via -r examples/libero/requirements.in
requests==2.32.3
# via torchvision
rich==13.9.4
# via tyro
robosuite==1.4.1
# via -r examples/libero/requirements.in
scipy==1.10.1
# via robosuite
setuptools==75.3.0
# via
# imageio-ffmpeg
# numba
shtab==1.7.1
# via tyro
six==1.17.0
# via
# pynput
# python-dateutil
# python-xlib
termcolor==2.4.0
# via robosuite
torch==1.11.0+cu113
# via
# -r examples/libero/requirements.in
# torchaudio
# torchvision
torchaudio==0.11.0+cu113
# via -r examples/libero/requirements.in
torchvision==0.12.0+cu113
# via -r examples/libero/requirements.in
tqdm==4.67.1
# via -r examples/libero/requirements.in
typeguard==4.4.0
# via tyro
typing-extensions==4.12.2
# via
# etils
# rich
# torch
# torchvision
# typeguard
# tyro
tyro==0.9.2
# via -r examples/libero/requirements.in
urllib3==2.2.3
# via requests
zipp==3.20.2
# via
# etils
# importlib-metadata
# importlib-resources

View File

@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pathlib\n",
"\n",
"import numpy as np\n",
"\n",
"record_path = pathlib.Path(\"../policy_records\")\n",
"num_steps = len(list(record_path.glob(\"step_*.npy\")))\n",
"\n",
"records = []\n",
"for i in range(num_steps):\n",
" record = np.load(record_path / f\"step_{i}.npy\", allow_pickle=True).item()\n",
" records.append(record)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"length of records\", len(records))\n",
"print(\"keys in records\", records[0].keys())\n",
"\n",
"for k in records[0]:\n",
" print(f\"{k} shape: {records[0][k].shape}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"def get_image(step: int, idx: int = 0):\n",
" img = (255 * records[step][\"inputs/image\"]).astype(np.uint8)\n",
" return img[idx].transpose(1, 2, 0)\n",
"\n",
"\n",
"def show_image(step: int, idx_lst: list[int]):\n",
" imgs = [get_image(step, idx) for idx in idx_lst]\n",
" return Image.fromarray(np.hstack(imgs))\n",
"\n",
"\n",
"for i in range(2):\n",
" display(show_image(i, [0]))"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"\n",
"def get_axis(name, axis):\n",
" return np.array([record[name][axis] for record in records])\n",
"\n",
"\n",
"# qpos is [..., 14] of type float:\n",
"# 0-5: left arm joint angles\n",
"# 6: left arm gripper\n",
"# 7-12: right arm joint angles\n",
"# 13: right arm gripper\n",
"names = [(\"left_joint\", 6), (\"left_gripper\", 1), (\"right_joint\", 6), (\"right_gripper\", 1)]\n",
"\n",
"\n",
"def make_data():\n",
" cur_dim = 0\n",
" in_data = {}\n",
" out_data = {}\n",
" for name, dim_size in names:\n",
" for i in range(dim_size):\n",
" in_data[f\"{name}_{i}\"] = get_axis(\"inputs/qpos\", cur_dim)\n",
" out_data[f\"{name}_{i}\"] = get_axis(\"outputs/qpos\", cur_dim)\n",
" cur_dim += 1\n",
" return pd.DataFrame(in_data), pd.DataFrame(out_data)\n",
"\n",
"\n",
"in_data, out_data = make_data()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for name in in_data.columns:\n",
" data = pd.DataFrame({f\"in_{name}\": in_data[name], f\"out_{name}\": out_data[name]})\n",
" data.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,32 @@
# Dockerfile for the simple client.
# Build the container:
# docker build . -t simple_client -f examples/simple_client/Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app simple_client /bin/bash
FROM python:3.7-slim
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Copy the requirements files so we can install dependencies.
# The rest of the project is mounted as a volume, so we don't need to rebuild on changes.
# This strategy is best for development-style usage.
COPY ./examples/simple_client/requirements.txt /tmp/requirements.txt
COPY ./packages/openpi-client/pyproject.toml /tmp/openpi-client/pyproject.toml
# Install python dependencies.
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN uv pip sync /tmp/requirements.txt /tmp/openpi-client/pyproject.toml
ENV PYTHONPATH=/app:/app/src:/app/packages/openpi-client/src
CMD /bin/bash -c "source /.venv/bin/activate && python examples/simple_client/main.py $SERVER_ARGS"

View File

@@ -0,0 +1,30 @@
# Simple Client
A minimal client that sends observations to the server and prints the inference rate.
You can specify which runtime environment to use via the `--env` flag. You can see the available options by running:
```bash
uv run examples/simple_client/main.py --help
```
## With Docker
```bash
export SERVER_ARGS="--env ALOHA_SIM"
docker compose -f examples/simple_client/compose.yml up --build
```
## Without Docker
Terminal window 1:
```bash
uv run examples/simple_client/main.py --env DROID
```
Terminal window 2:
```bash
uv run scripts/serve_policy.py --env DROID
```
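
The latency statistics the client prints are plain percentiles over the per-call timings. As a rough standalone illustration (the latency values below are made up, not measured):

```python
import numpy as np

# Hypothetical per-call inference latencies in milliseconds.
times_ms = [12.0, 15.5, 11.8, 30.2, 14.1, 13.3, 12.9, 16.7]

# Same style of summary the client reports for each timing key.
stats = {
    "mean": float(np.mean(times_ms)),
    "p50": float(np.quantile(times_ms, 0.50)),
    "p95": float(np.quantile(times_ms, 0.95)),
}
print(stats)
```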

View File

@@ -0,0 +1,42 @@
# Run with:
# docker compose -f examples/simple_client/compose.yml up --build
services:
  runtime:
    image: simple_client
    depends_on:
      - openpi_server
    build:
      context: ../..
      dockerfile: examples/simple_client/Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
    environment:
      - SERVER_ARGS

  openpi_server:
    image: openpi_server
    build:
      context: ../..
      dockerfile: scripts/docker/serve_policy.Dockerfile
    init: true
    tty: true
    network_mode: host
    volumes:
      - $PWD:/app
      - ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
    environment:
      - SERVER_ARGS
      - OPENPI_DATA_HOME=/openpi_assets
      - IS_DOCKER=true
    # Comment out this block if not running on a machine with GPUs.
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

View File

@@ -0,0 +1,187 @@
import dataclasses
import enum
import logging
import pathlib
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import polars as pl
import rich.console
import rich.table
import tqdm
import tyro
logger = logging.getLogger(__name__)
class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"


@dataclasses.dataclass
class Args:
    """Command line arguments."""

    # Host and port to connect to the server.
    host: str = "0.0.0.0"
    # Port to connect to the server. If None, the server will use the default port.
    port: int | None = 8000
    # API key to use for the server.
    api_key: str | None = None

    # Number of steps to run the policy for.
    num_steps: int = 20
    # Path to save the timings to a parquet file. (e.g., timing.parquet)
    timing_file: pathlib.Path | None = None

    # Environment to run the policy in.
    env: EnvMode = EnvMode.ALOHA_SIM
class TimingRecorder:
    """Records timing measurements for different keys."""

    def __init__(self) -> None:
        self._timings: dict[str, list[float]] = {}

    def record(self, key: str, time_ms: float) -> None:
        """Record a timing measurement for the given key."""
        if key not in self._timings:
            self._timings[key] = []
        self._timings[key].append(time_ms)

    def get_stats(self, key: str) -> dict[str, float]:
        """Get statistics for the given key."""
        times = self._timings[key]
        return {
            "mean": float(np.mean(times)),
            "std": float(np.std(times)),
            "p25": float(np.quantile(times, 0.25)),
            "p50": float(np.quantile(times, 0.50)),
            "p75": float(np.quantile(times, 0.75)),
            "p90": float(np.quantile(times, 0.90)),
            "p95": float(np.quantile(times, 0.95)),
            "p99": float(np.quantile(times, 0.99)),
        }

    def print_all_stats(self) -> None:
        """Print statistics for all keys in a concise format."""
        table = rich.table.Table(
            title="[bold blue]Timing Statistics[/bold blue]",
            show_header=True,
            header_style="bold white",
            border_style="blue",
            title_justify="center",
        )

        # Add metric column with custom styling
        table.add_column("Metric", style="cyan", justify="left", no_wrap=True)

        # Add statistical columns with consistent styling
        stat_columns = [
            ("Mean", "yellow", "mean"),
            ("Std", "yellow", "std"),
            ("P25", "magenta", "p25"),
            ("P50", "magenta", "p50"),
            ("P75", "magenta", "p75"),
            ("P90", "magenta", "p90"),
            ("P95", "magenta", "p95"),
            ("P99", "magenta", "p99"),
        ]
        for name, style, _ in stat_columns:
            table.add_column(name, justify="right", style=style, no_wrap=True)

        # Add rows for each metric with formatted values
        for key in sorted(self._timings.keys()):
            stats = self.get_stats(key)
            values = [f"{stats[key]:.1f}" for _, _, key in stat_columns]
            table.add_row(key, *values)

        # Print with custom console settings
        console = rich.console.Console(width=None, highlight=True)
        console.print(table)

    def write_parquet(self, path: pathlib.Path) -> None:
        """Save the timings to a parquet file."""
        logger.info(f"Writing timings to {path}")
        frame = pl.DataFrame(self._timings)
        path.parent.mkdir(parents=True, exist_ok=True)
        frame.write_parquet(path)
def main(args: Args) -> None:
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
        api_key=args.api_key,
    )
    logger.info(f"Server metadata: {policy.get_server_metadata()}")

    # Send a few observations to make sure the model is loaded.
    for _ in range(2):
        policy.infer(obs_fn())

    timing_recorder = TimingRecorder()

    for _ in tqdm.trange(args.num_steps, desc="Running policy"):
        inference_start = time.time()
        action = policy.infer(obs_fn())
        timing_recorder.record("client_infer_ms", 1000 * (time.time() - inference_start))
        for key, value in action.get("server_timing", {}).items():
            timing_recorder.record(f"server_{key}", value)
        for key, value in action.get("policy_timing", {}).items():
            timing_recorder.record(f"policy_{key}", value)

    timing_recorder.print_all_stats()

    if args.timing_file is not None:
        timing_recorder.write_parquet(args.timing_file)
def _random_observation_aloha() -> dict:
    return {
        "state": np.ones((14,)),
        "images": {
            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
        },
        "prompt": "do something",
    }


def _random_observation_droid() -> dict:
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }


def _random_observation_libero() -> dict:
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main(tyro.cli(Args))

View File

@@ -0,0 +1,5 @@
numpy>=1.22.4,<2.0.0
rich
tqdm
tyro
polars

View File

@@ -0,0 +1,30 @@
# This file was autogenerated by uv via the following command:
# uv pip compile examples/simple_client/requirements.in -o examples/simple_client/requirements.txt --python-version 3.11.9
docstring-parser==0.16
# via tyro
markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
numpy==1.26.4
# via -r examples/simple_client/requirements.in
polars==1.30.0
# via -r examples/simple_client/requirements.in
pygments==2.19.1
# via rich
rich==14.0.0
# via
# -r examples/simple_client/requirements.in
# tyro
shtab==1.7.2
# via tyro
tqdm==4.67.1
# via -r examples/simple_client/requirements.in
typeguard==4.4.2
# via tyro
typing-extensions==4.13.2
# via
# typeguard
# tyro
tyro==0.9.22
# via -r examples/simple_client/requirements.in

View File

@@ -0,0 +1,142 @@
# UR5 Example
Below we provide an outline of how to implement the key components mentioned in the "Finetune on your data" section of the [README](../README.md) for finetuning on UR5 datasets.
First, we will define the `UR5Inputs` and `UR5Outputs` classes, which map the UR5 environment to the model's inputs and the model's outputs back to the environment. For a worked reference with comments explaining each line, see the analogous Libero classes in `src/openpi/policies/libero_policy.py`.
```python
@dataclasses.dataclass(frozen=True)
class UR5Inputs(transforms.DataTransformFn):
    # The action dimension expected by the model; used to pad the state below
    # (passed in from the data config, see `LeRobotUR5DataConfig`).
    action_dim: int
    model_type: _model.ModelType = _model.ModelType.PI0

    def __call__(self, data: dict) -> dict:
        # First, concatenate the joints and gripper into the state vector,
        # then pad it to the model's action dimension.
        state = np.concatenate([data["joints"], data["gripper"]])
        state = transforms.pad_to_dim(state, self.action_dim)

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
        # stores as float32 (C,H,W); this gets skipped for policy inference.
        base_image = _parse_image(data["base_rgb"])
        wrist_image = _parse_image(data["wrist_rgb"])

        # Create inputs dict.
        inputs = {
            "state": state,
            "image": {
                "base_0_rgb": base_image,
                "left_wrist_0_rgb": wrist_image,
                # Since there is no right wrist, replace with zeros
                "right_wrist_0_rgb": np.zeros_like(base_image),
            },
            "image_mask": {
                "base_0_rgb": np.True_,
                "left_wrist_0_rgb": np.True_,
                # The "slot" for the right wrist is unused, so its mask is set to
                # False; pi0-FAST expects padding images to be masked with True instead.
                "right_wrist_0_rgb": np.True_ if self.model_type == _model.ModelType.PI0_FAST else np.False_,
            },
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions"] = data["actions"]

        # Pass the prompt (aka language instruction) to the model.
        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class UR5Outputs(transforms.DataTransformFn):
    def __call__(self, data: dict) -> dict:
        # Since the robot has 7 action dimensions (6 DoF + gripper), return the first 7 dims.
        return {"actions": np.asarray(data["actions"][:, :7])}
```
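
To sanity-check the mapping, here is a self-contained sketch of what the input transform produces for a dummy UR5 observation. It mimics the logic above with plain numpy (the openpi-specific base class, image parsing, and state padding are stubbed out, and the 224x224 image shape is an illustrative assumption):

```python
import numpy as np


def ur5_inputs(data: dict) -> dict:
    """Standalone mimic of UR5Inputs.__call__ for a pi0-style model (illustration only)."""
    state = np.concatenate([data["joints"], data["gripper"]])
    base_image = data["base_rgb"]
    wrist_image = data["wrist_rgb"]
    return {
        "state": state,
        "image": {
            "base_0_rgb": base_image,
            "left_wrist_0_rgb": wrist_image,
            # The unused right-wrist slot is filled with zeros...
            "right_wrist_0_rgb": np.zeros_like(base_image),
        },
        "image_mask": {
            "base_0_rgb": np.True_,
            "left_wrist_0_rgb": np.True_,
            # ...and masked out for the pi0 model.
            "right_wrist_0_rgb": np.False_,
        },
        "prompt": data["prompt"],
    }


dummy = {
    "joints": np.zeros(6, dtype=np.float32),
    "gripper": np.zeros(1, dtype=np.float32),
    "base_rgb": np.zeros((224, 224, 3), dtype=np.uint8),
    "wrist_rgb": np.zeros((224, 224, 3), dtype=np.uint8),
    "prompt": "pick up the block",
}
inputs = ur5_inputs(dummy)
print(inputs["state"].shape)  # (7,)
print(sorted(inputs["image"]))
```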
Next, we will define the `UR5DataConfig` class, which defines how to process raw UR5 data from LeRobot dataset for training. For a full example, see the `LeRobotLiberoDataConfig` config in the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py).
```python
@dataclasses.dataclass(frozen=True)
class LeRobotUR5DataConfig(DataConfigFactory):
    @override
    def create(self, assets_dirs: pathlib.Path, model_config: _model.BaseModelConfig) -> DataConfig:
        # Boilerplate for remapping keys from the LeRobot dataset. We assume no renaming is needed here.
        repack_transform = _transforms.Group(
            inputs=[
                _transforms.RepackTransform(
                    {
                        "base_rgb": "image",
                        "wrist_rgb": "wrist_image",
                        "joints": "joints",
                        "gripper": "gripper",
                        "prompt": "prompt",
                    }
                )
            ]
        )

        # These transforms are the ones we wrote earlier.
        data_transforms = _transforms.Group(
            inputs=[UR5Inputs(action_dim=model_config.action_dim, model_type=model_config.model_type)],
            outputs=[UR5Outputs()],
        )

        # Convert absolute actions to delta actions.
        # By convention, we do not convert the gripper action (7th dimension).
        delta_action_mask = _transforms.make_bool_mask(6, -1)
        data_transforms = data_transforms.push(
            inputs=[_transforms.DeltaActions(delta_action_mask)],
            outputs=[_transforms.AbsoluteActions(delta_action_mask)],
        )

        # Model transforms include things like tokenizing the prompt and action targets.
        # You do not need to change anything here for your own dataset.
        model_transforms = ModelTransformFactory()(model_config)

        # We return all data transforms for training and inference. No need to change anything here.
        return dataclasses.replace(
            self.create_base_config(assets_dirs),
            repack_transforms=repack_transform,
            data_transforms=data_transforms,
            model_transforms=model_transforms,
        )
```
Finally, we define the TrainConfig for our UR5 dataset. Here, we define a config for fine-tuning pi0 on our UR5 dataset. See the [training config file](https://github.com/physical-intelligence/openpi/blob/main/src/openpi/training/config.py) for more examples, e.g. for pi0-FAST or for LoRA fine-tuning.
```python
TrainConfig(
    name="pi0_ur5",
    model=pi0.Pi0Config(),
    data=LeRobotUR5DataConfig(
        repo_id="your_username/ur5_dataset",
        # This config lets us reload the UR5 normalization stats from the base model checkpoint.
        # Reloading normalization stats can help transfer pre-trained models to new environments.
        # See the [norm_stats.md](../docs/norm_stats.md) file for more details.
        assets=AssetsConfig(
            assets_dir="gs://openpi-assets/checkpoints/pi0_base/assets",
            asset_id="ur5e",
        ),
        base_config=DataConfig(
            # This flag determines whether we load the prompt (i.e. the task instruction) from the
            # `task` field in the LeRobot dataset. The recommended setting is True.
            prompt_from_task=True,
        ),
    ),
    # Load the pi0 base model checkpoint.
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
    num_train_steps=30_000,
)
```
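
The delta/absolute action round trip used above can be sketched in plain numpy. The helpers below mimic `make_bool_mask`, `DeltaActions`, and `AbsoluteActions` (the real implementations live in `src/openpi/transforms.py` and may differ in detail); the state and action values are made up:

```python
import numpy as np


def make_bool_mask(*dims: int) -> np.ndarray:
    """Mimic of transforms.make_bool_mask: a positive dim adds that many True
    entries, a negative dim adds abs(dim) False entries."""
    return np.concatenate([np.full(abs(d), d > 0) for d in dims])


mask = make_bool_mask(6, -1)  # first 6 dims become deltas, gripper stays absolute
state = np.arange(7, dtype=np.float32)            # current robot state
actions = np.full((3, 7), 2.0, dtype=np.float32)  # absolute action chunk

# DeltaActions: subtract the current state from the masked action dimensions.
delta = actions.copy()
delta[:, mask] -= state[mask]

# AbsoluteActions: invert the conversion at inference time.
absolute = delta.copy()
absolute[:, mask] += state[mask]

print(mask.tolist())  # [True, True, True, True, True, True, False]
```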

View File

@@ -0,0 +1,23 @@
[project]
name = "openpi-client"
version = "0.1.0"
requires-python = ">=3.7"
dependencies = [
"dm-tree>=0.1.8",
"msgpack>=1.0.5",
"numpy>=1.22.4,<2.0.0",
"pillow>=9.0.0",
"tree>=0.2.4",
"websockets>=11.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.uv]
dev-dependencies = ["pytest>=8.3.4"]
[tool.ruff]
line-length = 120
target-version = "py37"

View File

@@ -0,0 +1 @@
__version__ = "0.1.0"

View File

@@ -0,0 +1,50 @@
from typing import Dict, Optional

import numpy as np
import tree
from typing_extensions import override

from openpi_client import base_policy as _base_policy


class ActionChunkBroker(_base_policy.BasePolicy):
    """Wraps a policy to return action chunks one-at-a-time.

    Assumes that the first dimension of all action fields is the chunk size.
    A new inference call to the inner policy is only made when the current
    list of chunks is exhausted.
    """

    def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
        self._policy = policy
        self._action_horizon = action_horizon
        self._cur_step: int = 0
        # Optional[...] instead of `| None` keeps the runtime-evaluated
        # annotation compatible with the package's Python 3.7 floor.
        self._last_results: Optional[Dict[str, np.ndarray]] = None

    @override
    def infer(self, obs: Dict) -> Dict:  # noqa: UP006
        if self._last_results is None:
            self._last_results = self._policy.infer(obs)
            self._cur_step = 0

        def slicer(x):
            # Only slice arrays; pass other values (e.g. metadata) through unchanged.
            if isinstance(x, np.ndarray):
                return x[self._cur_step, ...]
            return x

        results = tree.map_structure(slicer, self._last_results)
        self._cur_step += 1

        if self._cur_step >= self._action_horizon:
            self._last_results = None

        return results

    @override
    def reset(self) -> None:
        self._policy.reset()
        self._last_results = None
        self._cur_step = 0

View File

@@ -0,0 +1,12 @@
import abc
from typing import Dict
class BasePolicy(abc.ABC):
    @abc.abstractmethod
    def infer(self, obs: Dict) -> Dict:
        """Infer actions from observations."""

    def reset(self) -> None:
        """Reset the policy to its initial state."""

View File

@@ -0,0 +1,58 @@
import numpy as np
from PIL import Image
def convert_to_uint8(img: np.ndarray) -> np.ndarray:
"""Converts an image to uint8 if it is a float image.
This is important for reducing the size of the image when sending it over the network.
"""
if np.issubdtype(img.dtype, np.floating):
img = (255 * img).astype(np.uint8)
return img
def resize_with_pad(images: np.ndarray, height: int, width: int, method=Image.BILINEAR) -> np.ndarray:
"""Replicates tf.image.resize_with_pad for multiple images using PIL. Resizes a batch of images to a target height.
Args:
images: A batch of images in [..., height, width, channel] format.
height: The target height of the image.
width: The target width of the image.
method: The interpolation method to use. Default is bilinear.
Returns:
The resized images in [..., height, width, channel].
"""
# If the images are already the correct size, return them as is.
if images.shape[-3:-1] == (height, width):
return images
original_shape = images.shape
images = images.reshape(-1, *original_shape[-3:])
resized = np.stack([_resize_with_pad_pil(Image.fromarray(im), height, width, method=method) for im in images])
return resized.reshape(*original_shape[:-3], *resized.shape[-3:])
def _resize_with_pad_pil(image: Image.Image, height: int, width: int, method: int) -> Image.Image:
"""Replicates tf.image.resize_with_pad for one image using PIL. Resizes an image to a target height and
width without distortion by padding with zeros.
Unlike the jax version, note that PIL uses [width, height, channel] ordering instead of [batch, h, w, c].
"""
cur_width, cur_height = image.size
if cur_width == width and cur_height == height:
return image # No need to resize if the image is already the correct size.
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_image = image.resize((resized_width, resized_height), resample=method)
zero_image = Image.new(resized_image.mode, (width, height), 0)
pad_height = max(0, int((height - resized_height) / 2))
pad_width = max(0, int((width - resized_width) / 2))
zero_image.paste(resized_image, (pad_width, pad_height))
assert zero_image.size == (width, height)
return zero_image
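The ratio and padding arithmetic above can be checked in isolation. A minimal sketch of the same geometry computation (the `padded_geometry` helper is illustrative, not part of the module):

```python
def padded_geometry(cur_w, cur_h, target_w, target_h):
    """Mirror the ratio/pad arithmetic in _resize_with_pad_pil."""
    ratio = max(cur_w / target_w, cur_h / target_h)
    resized_w = int(cur_w / ratio)
    resized_h = int(cur_h / ratio)
    pad_w = max(0, int((target_w - resized_w) / 2))
    pad_h = max(0, int((target_h - resized_h) / 2))
    return resized_w, resized_h, pad_w, pad_h

# A 200x100 image fit into a 50x50 box: scaled by 1/4, then padded top and bottom.
assert padded_geometry(200, 100, 50, 50) == (50, 25, 0, 12)
```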


@@ -0,0 +1,37 @@
import numpy as np
import openpi_client.image_tools as image_tools
def test_resize_with_pad_shapes():
# Test case 1: Resize image with larger dimensions
images = np.zeros((2, 10, 10, 3), dtype=np.uint8) # Input images of shape (batch_size, height, width, channels)
height = 20
width = 20
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (2, height, width, 3)
assert np.all(resized_images == 0)
# Test case 2: Resize image with smaller dimensions
images = np.zeros((3, 30, 30, 3), dtype=np.uint8)
height = 15
width = 15
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (3, height, width, 3)
assert np.all(resized_images == 0)
# Test case 3: Resize image with the same dimensions
images = np.zeros((1, 50, 50, 3), dtype=np.uint8)
height = 50
width = 50
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (1, height, width, 3)
assert np.all(resized_images == 0)
# Test case 4: Resize image with odd-numbered padding
images = np.zeros((1, 256, 320, 3), dtype=np.uint8)
height = 60
width = 80
resized_images = image_tools.resize_with_pad(images, height, width)
assert resized_images.shape == (1, height, width, 3)
assert np.all(resized_images == 0)


@@ -0,0 +1,57 @@
"""Adds NumPy array support to msgpack.
msgpack is good for (de)serializing data over a network for multiple reasons:
- msgpack is secure (as opposed to pickle/dill/etc which allow for arbitrary code execution)
- msgpack is widely used and has good cross-language support
- msgpack does not require a schema (as opposed to protobuf/flatbuffers/etc) which is convenient in dynamically typed
languages like Python and JavaScript
- msgpack is fast and efficient (as opposed to readable formats like JSON/YAML/etc); I found that msgpack was ~4x faster
than pickle for serializing large arrays using the below strategy
The code below is adapted from https://github.com/lebedov/msgpack-numpy. The reason not to use that library directly is
that it falls back to pickle for object arrays.
"""
import functools
import msgpack
import numpy as np
def pack_array(obj):
if (isinstance(obj, (np.ndarray, np.generic))) and obj.dtype.kind in ("V", "O", "c"):
raise ValueError(f"Unsupported dtype: {obj.dtype}")
if isinstance(obj, np.ndarray):
return {
b"__ndarray__": True,
b"data": obj.tobytes(),
b"dtype": obj.dtype.str,
b"shape": obj.shape,
}
if isinstance(obj, np.generic):
return {
b"__npgeneric__": True,
b"data": obj.item(),
b"dtype": obj.dtype.str,
}
return obj
def unpack_array(obj):
if b"__ndarray__" in obj:
return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])
if b"__npgeneric__" in obj:
return np.dtype(obj[b"dtype"]).type(obj[b"data"])
return obj
Packer = functools.partial(msgpack.Packer, default=pack_array)
packb = functools.partial(msgpack.packb, default=pack_array)
Unpacker = functools.partial(msgpack.Unpacker, object_hook=unpack_array)
unpackb = functools.partial(msgpack.unpackb, object_hook=unpack_array)
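The encoding strategy above can be illustrated without msgpack itself: arrays travel as raw bytes plus dtype/shape metadata. A minimal round-trip sketch (the `encode`/`decode` names are illustrative, not part of the module):

```python
import numpy as np

def encode(arr: np.ndarray) -> dict:
    # Same wire format that pack_array produces for ndarrays.
    return {b"__ndarray__": True, b"data": arr.tobytes(),
            b"dtype": arr.dtype.str, b"shape": arr.shape}

def decode(obj: dict) -> np.ndarray:
    # Same reconstruction that unpack_array performs.
    return np.ndarray(buffer=obj[b"data"], dtype=np.dtype(obj[b"dtype"]), shape=obj[b"shape"])

original = np.arange(12, dtype=np.float32).reshape(3, 4)
restored = decode(encode(original))
assert np.array_equal(original, restored)
assert restored.dtype == np.float32
```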


@@ -0,0 +1,45 @@
import numpy as np
import pytest
import tree
from openpi_client import msgpack_numpy
def _check(expected, actual):
if isinstance(expected, np.ndarray):
assert expected.shape == actual.shape
assert expected.dtype == actual.dtype
assert np.array_equal(expected, actual, equal_nan=expected.dtype.kind == "f")
else:
assert expected == actual
@pytest.mark.parametrize(
"data",
[
1, # int
1.0, # float
"hello", # string
np.bool_(True), # boolean scalar
np.array([1, 2, 3])[0], # int scalar
np.str_("asdf"), # string scalar
[1, 2, 3], # list
{"key": "value"}, # dict
{"key": [1, 2, 3]}, # nested dict
np.array(1.0), # 0D array
np.array([1, 2, 3], dtype=np.int32), # 1D integer array
np.array(["asdf", "qwer"]), # string array
np.array([True, False]), # boolean array
np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32), # 2D float array
np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.int16), # 3D integer array
np.array([np.nan, np.inf, -np.inf]), # special float values
{"arr": np.array([1, 2, 3]), "nested": {"arr": np.array([4, 5, 6])}}, # nested dict with arrays
[np.array([1, 2]), np.array([3, 4])], # list of arrays
np.zeros((3, 4, 5), dtype=np.float32), # 3D zeros
np.ones((2, 3), dtype=np.float64), # 2D ones with double precision
],
)
def test_pack_unpack(data):
packed = msgpack_numpy.packb(data)
unpacked = msgpack_numpy.unpackb(packed)
tree.map_structure(_check, data, unpacked)


@@ -0,0 +1,17 @@
import abc
class Agent(abc.ABC):
"""An Agent is the thing with agency, i.e. the entity that makes decisions.
Agents receive observations about the state of the world, and return actions
to take in response.
"""
@abc.abstractmethod
def get_action(self, observation: dict) -> dict:
"""Query the agent for the next action."""
@abc.abstractmethod
def reset(self) -> None:
"""Reset the agent to its initial state."""


@@ -0,0 +1,18 @@
from typing_extensions import override
from openpi_client import base_policy as _base_policy
from openpi_client.runtime import agent as _agent
class PolicyAgent(_agent.Agent):
"""An agent that uses a policy to determine actions."""
def __init__(self, policy: _base_policy.BasePolicy) -> None:
self._policy = policy
@override
def get_action(self, observation: dict) -> dict:
return self._policy.infer(observation)
def reset(self) -> None:
self._policy.reset()


@@ -0,0 +1,32 @@
import abc
class Environment(abc.ABC):
"""An Environment represents the robot and the environment it inhabits.
The primary contract of environments is that they can be queried for observations
about their state, and have actions applied to them to change that state.
"""
@abc.abstractmethod
def reset(self) -> None:
"""Reset the environment to its initial state.
This will be called once before starting each episode.
"""
@abc.abstractmethod
def is_episode_complete(self) -> bool:
"""Allow the environment to signal that the episode is complete.
This will be called after each step. It should return `True` if the episode is
complete (either successfully or unsuccessfully), and `False` otherwise.
"""
@abc.abstractmethod
def get_observation(self) -> dict:
"""Query the environment for the current state."""
@abc.abstractmethod
def apply_action(self, action: dict) -> None:
"""Take an action in the environment."""


@@ -0,0 +1,92 @@
import logging
import threading
import time
from openpi_client.runtime import agent as _agent
from openpi_client.runtime import environment as _environment
from openpi_client.runtime import subscriber as _subscriber
class Runtime:
"""The core module orchestrating interactions between key components of the system."""
def __init__(
self,
environment: _environment.Environment,
agent: _agent.Agent,
subscribers: list[_subscriber.Subscriber],
max_hz: float = 0,
num_episodes: int = 1,
max_episode_steps: int = 0,
) -> None:
self._environment = environment
self._agent = agent
self._subscribers = subscribers
self._max_hz = max_hz
self._num_episodes = num_episodes
self._max_episode_steps = max_episode_steps
self._in_episode = False
self._episode_steps = 0
def run(self) -> None:
"""Runs the runtime loop continuously until stop() is called or the environment is done."""
for _ in range(self._num_episodes):
self._run_episode()
# Final reset; this is important for real environments to move the robot to its home position.
self._environment.reset()
def run_in_new_thread(self) -> threading.Thread:
"""Runs the runtime loop in a new thread."""
thread = threading.Thread(target=self.run)
thread.start()
return thread
def mark_episode_complete(self) -> None:
"""Marks the end of an episode."""
self._in_episode = False
def _run_episode(self) -> None:
"""Runs a single episode."""
logging.info("Starting episode...")
self._environment.reset()
self._agent.reset()
for subscriber in self._subscribers:
subscriber.on_episode_start()
self._in_episode = True
self._episode_steps = 0
step_time = 1 / self._max_hz if self._max_hz > 0 else 0
last_step_time = time.time()
while self._in_episode:
self._step()
self._episode_steps += 1
# Sleep to maintain the desired frame rate
now = time.time()
dt = now - last_step_time
if dt < step_time:
time.sleep(step_time - dt)
last_step_time = time.time()
else:
last_step_time = now
logging.info("Episode completed.")
for subscriber in self._subscribers:
subscriber.on_episode_end()
def _step(self) -> None:
"""A single step of the runtime loop."""
observation = self._environment.get_observation()
action = self._agent.get_action(observation)
self._environment.apply_action(action)
for subscriber in self._subscribers:
subscriber.on_step(observation, action)
if self._environment.is_episode_complete() or (
self._max_episode_steps > 0 and self._episode_steps >= self._max_episode_steps
):
self.mark_episode_complete()
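The control flow above reduces to an observe-act loop. A minimal sketch with stand-in classes that follow the same Environment/Agent interface (the classes below are illustrative and do not require openpi_client):

```python
class CountdownEnv:
    """Toy Environment: the episode completes after three applied actions."""
    def __init__(self) -> None:
        self._t = 0
    def reset(self) -> None:
        self._t = 0
    def is_episode_complete(self) -> bool:
        return self._t >= 3
    def get_observation(self) -> dict:
        return {"t": self._t}
    def apply_action(self, action: dict) -> None:
        self._t += 1

class EchoAgent:
    """Toy Agent: derives its action from the observation."""
    def get_action(self, observation: dict) -> dict:
        return {"cmd": observation["t"]}
    def reset(self) -> None:
        pass

# The heart of Runtime._run_episode, minus subscribers and rate limiting.
env, agent = CountdownEnv(), EchoAgent()
env.reset()
agent.reset()
steps = 0
while not env.is_episode_complete():
    observation = env.get_observation()
    env.apply_action(agent.get_action(observation))
    steps += 1
assert steps == 3
```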


@@ -0,0 +1,20 @@
import abc
class Subscriber(abc.ABC):
"""Subscribes to events in the runtime.
Subscribers can be used to save data, visualize, etc.
"""
@abc.abstractmethod
def on_episode_start(self) -> None:
"""Called when an episode starts."""
@abc.abstractmethod
def on_step(self, observation: dict, action: dict) -> None:
"""Append a step to the episode."""
@abc.abstractmethod
def on_episode_end(self) -> None:
"""Called when an episode ends."""


@@ -0,0 +1,55 @@
import logging
import time
from typing import Dict, Optional, Tuple
from typing_extensions import override
import websockets.sync.client
from openpi_client import base_policy as _base_policy
from openpi_client import msgpack_numpy
class WebsocketClientPolicy(_base_policy.BasePolicy):
"""Implements the Policy interface by communicating with a server over websocket.
See WebsocketPolicyServer for a corresponding server implementation.
"""
def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
self._uri = f"ws://{host}"
if port is not None:
self._uri += f":{port}"
self._packer = msgpack_numpy.Packer()
self._api_key = api_key
self._ws, self._server_metadata = self._wait_for_server()
def get_server_metadata(self) -> Dict:
return self._server_metadata
def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
logging.info(f"Waiting for server at {self._uri}...")
while True:
try:
headers = {"Authorization": f"Api-Key {self._api_key}"} if self._api_key else None
conn = websockets.sync.client.connect(
self._uri, compression=None, max_size=None, additional_headers=headers
)
metadata = msgpack_numpy.unpackb(conn.recv())
return conn, metadata
except ConnectionRefusedError:
logging.info("Still waiting for server...")
time.sleep(5)
@override
def infer(self, obs: Dict) -> Dict: # noqa: UP006
data = self._packer.pack(obs)
self._ws.send(data)
response = self._ws.recv()
if isinstance(response, str):
# we're expecting bytes; if the server sends a string, it's an error.
raise RuntimeError(f"Error in inference server:\n{response}")
return msgpack_numpy.unpackb(response)
@override
def reset(self) -> None:
pass


@@ -0,0 +1,136 @@
[project]
name = "openpi"
version = "0.1.0"
description = "Physical Intelligence open source repo"
readme = "README.md"
requires-python = ">=3.11"
license = { file = "LICENSE" }
dependencies = [
"augmax>=0.3.4",
"dm-tree>=0.1.8",
"einops>=0.8.0",
"equinox>=0.11.8",
"flatbuffers>=24.3.25",
"flax==0.10.2",
"fsspec[gcs]>=2024.6.0",
"gym-aloha>=0.1.1",
"imageio>=2.36.1",
"jax[cuda12]==0.5.3",
"jaxtyping==0.2.36",
"ml_collections==1.0.0",
"numpy>=1.22.4,<2.0.0",
"numpydantic>=1.6.6",
"opencv-python>=4.10.0.84",
"openpi-client",
"orbax-checkpoint==0.11.13",
"pillow>=11.0.0",
"sentencepiece>=0.2.0",
"torch==2.7.1",
"tqdm-loggable>=0.2",
"typing-extensions>=4.12.2",
"tyro>=0.9.5",
"wandb>=0.19.1",
"filelock>=3.16.1",
"beartype==0.19.0",
"treescope>=0.1.7",
"transformers==4.53.2",
"rich>=14.0.0",
"polars>=1.30.0",
]
[project.urls]
Repository = "https://github.com/Physical-Intelligence/openpi"
[dependency-groups]
dev = [
"pytest>=8.3.4",
"ruff>=0.8.6",
"pre-commit>=4.0.1",
"ipykernel>=6.29.5",
"ipywidgets>=8.1.5",
"matplotlib>=3.10.0",
"pynvml>=12.0.0",
]
rlds = [
"dlimp",
"tensorflow-cpu==2.15.0",
"tensorflow-datasets==4.9.9",
]
[tool.uv]
override-dependencies = ["ml-dtypes==0.4.1", "tensorstore==0.1.74"]
[tool.uv.sources]
openpi-client = { workspace = true }
lerobot = { git = "https://github.com/huggingface/lerobot", rev = "0cf864870cf29f4738d3ade893e6fd13fbd7cdb5" }
dlimp = { git = "https://github.com/kvablack/dlimp", rev = "ad72ce3a9b414db2185bc0b38461d4101a65477a" }
[tool.uv.workspace]
members = ["packages/*"]
[tool.ruff]
line-length = 120
target-version = "py311"
extend-exclude = ["docker", "third_party", "src/openpi/models_pytorch/transformers_replace/*"]
[tool.ruff.lint]
# https://docs.astral.sh/ruff/rules/
select = [
"B",
"C4",
"DTZ",
"E4",
"E7",
"E9",
"F",
"FBT",
"FURB",
"I",
"ICN",
"ISC",
"LOG",
"N",
"PD",
"PERF",
"PIE",
"PLC",
"PLE",
"PLR1",
"PLR5",
"PLW",
"PT",
"Q",
"RET",
"RUF",
"SIM",
"SLF",
"T10",
"T20",
"UP",
"W",
]
ignore = [
"F722", # Conflicts with array typing.
"T201", # We use print statements.
"PD008", # Lots of false positives.
"ISC001", # Disabling to support ruff format.
"LOG015", # Use logger.info.
]
unfixable = [
"B905", # Fix defaults to strict=False, which is not what we want.
]
[tool.ruff.lint.isort]
force-single-line = true
force-sort-within-sections = true
single-line-exclusions = ["collections.abc", "typing", "typing_extensions"]
known-third-party = ["wandb"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.pytest.ini_options]
markers = ["manual: should be run manually."]
testpaths = ["src", "scripts", "packages"]


@@ -0,0 +1,36 @@
augmax>=0.3.4
dm-tree>=0.1.8
einops>=0.8.0
equinox>=0.11.8
flatbuffers>=24.3.25
flax==0.10.2
fsspec[gcs]>=2024.6.0
gym-aloha>=0.1.1
imageio>=2.36.1
jax[cuda12]==0.5.3
jaxtyping==0.2.36
ml_collections==1.0.0
numpy>=1.22.4,<2.0.0
numpydantic>=1.6.6
opencv-python>=4.10.0.84
orbax-checkpoint==0.11.13
pillow>=11.0.0
sentencepiece>=0.2.0
tqdm-loggable>=0.2
typing-extensions>=4.12.2
tyro>=0.9.5
wandb>=0.19.1
filelock>=3.16.1
beartype==0.19.0
treescope>=0.1.7
transformers==4.53.2
rich>=14.0.0
polars>=1.30.0
ml-dtypes==0.5.3
tensorstore==0.1.74
# tensorflow==2.20.0
tensorflow-datasets==4.9.9
lmdb==1.7.3
pytest==8.4.1
nvidia-cudnn-cu12==9.10.2.21
# dlimp


@@ -0,0 +1,218 @@
"""Compute normalization statistics for real-world tasks.
This script is used to compute the normalization statistics for a given real-world task. It
will compute the mean and standard deviation of the data in the dataset and save them
to the output directory.
"""
import os
import glob
import numpy as np
import tqdm
import tyro
import json
import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.mixture_dataset as _mixture_dataset
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms
### training config ###
import openpi.training.weight_loaders as weight_loaders
import openpi.models.pi0_config as pi0_config
from openpi.training.config import MultiLeRobotReala2dDataConfig, MultiLeRobotRealArxLift2DataConfig, MultiDataConfig, DataConfig, TrainConfig
from pdb import set_trace
class RemoveStrings(transforms.DataTransformFn):
def __call__(self, x: dict) -> dict:
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
def create_torch_dataloader(
data_config: _config.DataConfig,
action_horizon: int,
batch_size: int,
model_config: _model.BaseModelConfig,
num_workers: int,
max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
dataset = _mixture_dataset.TransformedDataset(
dataset,
[
*data_config[0].repack_transforms.inputs,
*data_config[0].data_transforms.inputs,
RemoveStrings(),
],
)
if max_frames is not None and max_frames < len(dataset):
num_batches = max_frames // batch_size
shuffle = True
else:
num_batches = len(dataset) // batch_size
shuffle = False
data_loader = _data_loader.TorchDataLoader(
dataset,
local_batch_size=batch_size,
num_workers=num_workers,
shuffle=shuffle,
num_batches=num_batches,
)
return data_loader, num_batches
def main(dataset_path, robot_name, task_name, save_path):
if robot_name == "lift2" or robot_name == "split_aloha" or robot_name == "acone":
config = TrainConfig(
name="lift2",
model=pi0_config.Pi0Config(),
data=[
MultiLeRobotRealArxLift2DataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=False,
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"left_joint": "states.left_joint.position",
"right_joint": "states.right_joint.position",
"left_gripper": "states.left_gripper.position",
"right_gripper": "states.right_gripper.position"
},
"action_dict": {
"left_joint": "actions.left_joint.position",
"right_joint": "actions.right_joint.position",
"left_gripper": "actions.left_gripper.position",
"right_gripper": "actions.right_gripper.position"
},
"prompt": "task"
}
)
]
)
),
],
# pretrain model path
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
elif robot_name == "genie1":
config = TrainConfig(
name="genie1",
model=pi0_config.Pi0Config(),
data=[
MultiLeRobotReala2dDataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=False,
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"joint": "observation.states.joint.position",
"gripper": "observation.states.effector.position",
},
"action_dict": {
"joint": "actions.joint.position",
"gripper": "actions.effector.position",
},
"prompt": "task"
}
)
]
)
),
],
# pretrain model path
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
else:
raise ValueError(f"Unsupported robot_name: {robot_name!r}")
data_config = config.data[0].create(config.model)
print("done")
output_path = os.path.join(save_path, robot_name, task_name)
stats_json_path = os.path.join(output_path, "norm_stats.json")
if os.path.isfile(stats_json_path):
with open(stats_json_path, 'r', encoding='utf-8') as f:
json.load(f)
return True
data_loader, num_batches = create_torch_dataloader(
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
)
keys = ["state", "actions"]
stats = {key: normalize.RunningStats() for key in keys}
step_id = 0
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
step_id += 1
for key in keys:
stats[key].update(np.asarray(batch[key]))
norm_stats = {key: s.get_statistics() for key, s in stats.items()}
print(f"Writing stats to: {output_path}")
normalize.save(output_path, norm_stats)
def check_lerobot_repo(repo_dir: str):
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
print(repo_dir, "true")
return True
else:
print(repo_dir, "false")
return False
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--task_path", type=str, default="data/InternData-A1/real/genie1/Put_the_pen_from_the_table_into_the_pen_holder/*")
parser.add_argument("--robot_name", type=str, default="genie1")
parser.add_argument("--save_path", type=str, default="stats/real")
args, unknown = parser.parse_known_args()
dataset_path=args.task_path
save_path = args.save_path
parts = dataset_path.split("/")
robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
if robot_idx is None:
raise ValueError(
f"Cannot find robot name in path. Expected {args.robot_name}, "
f"but got path: {dataset_path}"
)
if robot_idx + 1 >= len(parts):
raise ValueError(
f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {local_path}"
)
robot_name = parts[robot_idx]
task_name = parts[robot_idx + 1]
try:
main(dataset_path, robot_name, task_name, save_path)
except Exception as e:
print(f"Failed to compute stats for {dataset_path}: {e}")
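The statistics loop above relies on `normalize.RunningStats` to accumulate the mean and standard deviation incrementally over batches. As a rough sketch of what such an accumulator does (the class below is hypothetical; openpi's actual implementation may differ):

```python
import numpy as np

class RunningStatsSketch:
    """Hypothetical stand-in for openpi.shared.normalize.RunningStats:
    accumulates count, sum, and sum of squares per feature dimension."""
    def __init__(self) -> None:
        self.n = 0
        self._sum = None
        self._sumsq = None

    def update(self, batch: np.ndarray) -> None:
        flat = batch.reshape(-1, batch.shape[-1]).astype(np.float64)
        if self._sum is None:
            self._sum = flat.sum(axis=0)
            self._sumsq = (flat ** 2).sum(axis=0)
        else:
            self._sum += flat.sum(axis=0)
            self._sumsq += (flat ** 2).sum(axis=0)
        self.n += flat.shape[0]

    def mean(self) -> np.ndarray:
        return self._sum / self.n

    def std(self) -> np.ndarray:
        # Population standard deviation: E[x^2] - E[x]^2.
        return np.sqrt(self._sumsq / self.n - self.mean() ** 2)

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 7))
rs = RunningStatsSketch()
for chunk in np.array_split(data, 10):  # feed the data in batches, as the loop above does
    rs.update(chunk)
assert np.allclose(rs.mean(), data.mean(axis=0))
assert np.allclose(rs.std(), data.std(axis=0))
```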


@@ -0,0 +1,314 @@
"""Compute normalization statistics for interndata-a1 sim tasks.
This script is used to compute the normalization statistics for interndata-a1 sim tasks. It
will compute the mean and standard deviation of the data in the dataset and save them
to the output directory.
"""
import os
import glob
import numpy as np
import tqdm
import tyro
import json
import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.mixture_dataset as _mixture_dataset
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms
### training config ###
import openpi.training.weight_loaders as weight_loaders
import openpi.models.pi0_config as pi0_config
from openpi.training.config import MultiSimGenieDataConfig, MultiSimSplitAlohaDataConfig, MultiSimFrankaDataConfig, MultiDataConfig, DataConfig, TrainConfig
from pdb import set_trace
class RemoveStrings(transforms.DataTransformFn):
def __call__(self, x: dict) -> dict:
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
def create_torch_dataloader(
data_config: _config.DataConfig,
action_horizon: int,
batch_size: int,
model_config: _model.BaseModelConfig,
num_workers: int,
max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
dataset = _mixture_dataset.TransformedDataset(
dataset,
[
*data_config[0].repack_transforms.inputs,
*data_config[0].data_transforms.inputs,
RemoveStrings(),
],
)
if max_frames is not None and max_frames < len(dataset):
num_batches = max_frames // batch_size
shuffle = True
else:
num_batches = len(dataset) // batch_size
shuffle = False
data_loader = _data_loader.TorchDataLoader(
dataset,
local_batch_size=batch_size,
num_workers=num_workers,
shuffle=shuffle,
num_batches=num_batches,
)
return data_loader, num_batches
def main(dataset_path, task_category, robot_name, task_name, collect_name, save_path):
if robot_name == "lift2" or robot_name == "split_aloha":
config = TrainConfig(
name="lift2",
model=pi0_config.Pi0Config(),
data=[
MultiSimSplitAlohaDataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=True,
gripper_aug_config={
"gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
"gripper_dim": -1,
"gripper_threshold_method": "std_multiplier",
"gripper_threshold_multiplier": 1.0,
"gripper_min_threshold": 0.001,
"gripper_max_threshold": 1.0,
},
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"left_joint": "states.left_joint.position",
"right_joint": "states.right_joint.position",
"left_gripper": "states.left_gripper.position",
"right_gripper": "states.right_gripper.position"
},
"action_dict": {
"left_joint": "actions.left_joint.position",
"right_joint": "actions.right_joint.position",
"left_gripper": "actions.left_gripper.position",
"right_gripper": "actions.right_gripper.position",
"left_gripper_openness": "master_actions.left_gripper.openness",
"right_gripper_openness": "master_actions.right_gripper.openness"
},
"prompt": "task"
}
)
]
)
),
],
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
elif robot_name == "genie1":
config = TrainConfig(
name="genie1",
model=pi0_config.Pi0Config(),
data=[
MultiSimGenieDataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=True,
gripper_aug_config={
"gripper_action_keys": ["master_actions.left_gripper.openness", "master_actions.right_gripper.openness"],
"gripper_dim": -1,
"gripper_threshold_method": "std_multiplier",
"gripper_threshold_multiplier": 1.0,
"gripper_min_threshold": 0.001,
"gripper_max_threshold": 1.0,
},
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"left_joint": "states.left_joint.position",
"right_joint": "states.right_joint.position",
"left_gripper": "states.left_gripper.position",
"right_gripper": "states.right_gripper.position"
},
"action_dict": {
"left_joint": "actions.left_joint.position",
"right_joint": "actions.right_joint.position",
"left_gripper": "actions.left_gripper.position",
"right_gripper": "actions.right_gripper.position",
"left_gripper_openness": "master_actions.left_gripper.openness",
"right_gripper_openness": "master_actions.right_gripper.openness"
},
"prompt": "task"
}
)
]
)
),
],
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
elif "franka" in robot_name:
config = TrainConfig(
name="franka",
model=pi0_config.Pi0Config(),
data=[
MultiSimFrankaDataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=True,
gripper_aug_config={
"gripper_action_keys": ["actions.gripper.openness"],
"gripper_dim": -1,
"gripper_threshold_method": "std_multiplier",
"gripper_threshold_multiplier": 1.0,
"gripper_min_threshold": 0.001,
"gripper_max_threshold": 1.0,
},
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"joint_position": "states.joint.position",
"gripper_pose": "states.gripper.pose",
"gripper_position": "states.gripper.position",
},
"action_dict": {
"gripper_pose": "actions.gripper.pose",
"gripper_position": "actions.gripper.position",
"gripper_openness": "actions.gripper.openness",
},
"prompt": "task"
}
)
]
)
),
],
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
else:
raise ValueError(f"Unsupported robot_name: {robot_name!r}")
data_config = config.data[0].create(config.model)
print("done")
output_path = os.path.join(save_path, task_category, robot_name, task_name, collect_name)
stats_json_path = os.path.join(output_path, "norm_stats.json")
if os.path.isfile(stats_json_path):
with open(stats_json_path, 'r', encoding='utf-8') as f:
json.load(f)
return True
data_loader, num_batches = create_torch_dataloader(
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
)
keys = ["state", "actions"]
stats = {key: normalize.RunningStats() for key in keys}
step_id = 0
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
step_id += 1
for key in keys:
stats[key].update(np.asarray(batch[key]))
if step_id > 10000:
break
norm_stats = {key: s.get_statistics() for key, s in stats.items()}
print(f"Writing stats to: {output_path}")
normalize.save(output_path, norm_stats)
def check_lerobot_repo(repo_dir: str):
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
print(repo_dir, "true")
return True
else:
print(repo_dir, "false")
return False
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--root_data_dir", type=str, default="data/InternData-A1/sim")
parser.add_argument("--task_category", type=str, default="pick_and_place_tasks")
parser.add_argument("--save_path", type=str, default="stats/sim")
parser.add_argument("--start_ratio", type=float, default=0.0)
parser.add_argument("--end_ratio", type=float, default=1)
args, unknown = parser.parse_known_args()
root_data_dir = os.path.join(args.root_data_dir, args.task_category)
dataset_paths = glob.glob(os.path.join(root_data_dir, "*", "*"))
dataset_paths.sort()
valid_paths = [
p for p in dataset_paths
if check_lerobot_repo(p)
]
start_idx = int(len(valid_paths) * args.start_ratio)
end_idx = int(len(valid_paths) * args.end_ratio) + 1
valid_paths = valid_paths[start_idx:end_idx]
for dataset_path in tqdm.tqdm(valid_paths):
task_category = dataset_path.split('/')[-3]
robot_name = dataset_path.split('/')[-2]
task_name = dataset_path.split('/')[-1]
collect_name = ""
try:
main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
except Exception as e:
print(f"Failed to compute stats for {dataset_path}: {e}")
dataset_paths_w_subtask = glob.glob(os.path.join(root_data_dir, "*", "*","*"))
dataset_paths_w_subtask.sort()
valid_paths_w_subtask = [
p for p in dataset_paths_w_subtask
if check_lerobot_repo(p)
]
start_idx = int(len(valid_paths_w_subtask) * args.start_ratio)
end_idx = int(len(valid_paths_w_subtask) * args.end_ratio) + 1
valid_paths_w_subtask = valid_paths_w_subtask[start_idx:end_idx]
for dataset_path in tqdm.tqdm(valid_paths_w_subtask):
task_category = dataset_path.split('/')[-4]
robot_name = dataset_path.split('/')[-3]
task_name = dataset_path.split('/')[-2]
collect_name = dataset_path.split('/')[-1]
try:
main(dataset_path, task_category, robot_name, task_name, collect_name, args.save_path)
except Exception as e:
print(f"Failed to compute stats for {dataset_path}: {e}")


@@ -0,0 +1,181 @@
"""Compute normalization statistics for real-world tasks.
This script is used to compute the normalization statistics for a given real-world task. It
will compute the mean and standard deviation of the data in the dataset and save them
to the output directory.
"""
import os
import glob
import numpy as np
import tqdm
import tyro
import json
import openpi.models.model as _model
import openpi.shared.normalize as normalize
import openpi.training.config as _config
import openpi.training.mixture_dataset as _mixture_dataset
import openpi.training.data_loader as _data_loader
import openpi.transforms as transforms
### training config ###
import openpi.training.weight_loaders as weight_loaders
import openpi.models.pi0_config as pi0_config
from openpi.training.config import MultiSim2RealSplitAlohaDataConfig, MultiDataConfig, DataConfig, TrainConfig
from pdb import set_trace
class RemoveStrings(transforms.DataTransformFn):
def __call__(self, x: dict) -> dict:
return {k: v for k, v in x.items() if not np.issubdtype(np.asarray(v).dtype, np.str_)}
def create_torch_dataloader(
data_config: _config.DataConfig,
action_horizon: int,
batch_size: int,
model_config: _model.BaseModelConfig,
num_workers: int,
max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
dataset = _mixture_dataset.create_mixture_dataset_calculate_norm_stats(data_config, action_horizon, model_config)
dataset = _mixture_dataset.TransformedDataset(
dataset,
[
*data_config[0].repack_transforms.inputs,
*data_config[0].data_transforms.inputs,
RemoveStrings(),
],
)
if max_frames is not None and max_frames < len(dataset):
num_batches = max_frames // batch_size
shuffle = True
else:
num_batches = len(dataset) // batch_size
shuffle = False
data_loader = _data_loader.TorchDataLoader(
dataset,
local_batch_size=batch_size,
num_workers=num_workers,
shuffle=shuffle,
num_batches=num_batches,
)
return data_loader, num_batches
def main(dataset_path, robot_name, task_name, save_path):
if robot_name == "lift2":
config = TrainConfig(
name="lift2",
model=pi0_config.Pi0Config(),
data=[
MultiSim2RealSplitAlohaDataConfig(
repo_dir=dataset_path,
task_id=None,
use_gripper_aug=False,
stats_dir='',
base_config=MultiDataConfig(
prompt_from_task=True,
),
asset_id=robot_name,
robot_name=robot_name,
repack_transforms=transforms.Group(
inputs=[
transforms.RepackTransform(
{
"state_dict": {
"left_joint": "states.left_joint.position",
"right_joint": "states.right_joint.position",
"left_gripper": "states.left_gripper.position",
"right_gripper": "states.right_gripper.position"
},
"action_dict": {
"left_joint": "actions.left_joint.position",
"right_joint": "actions.right_joint.position",
"left_gripper": "actions.left_gripper.position",
"right_gripper": "actions.right_gripper.position",
"left_gripper_openness": "master_actions.left_gripper.openness",
"right_gripper_openness": "master_actions.right_gripper.openness"
},
"prompt": "task"
}
)
]
)
),
],
# pretrain model path
weight_loader=weight_loaders.CheckpointWeightLoader("checkpoints/jax/pi0_base/params"),
pytorch_weight_path="checkpoints/pytorch/pi0_base",
num_train_steps=30_000,
num_workers=4,
fsdp_devices=4,
batch_size=8,
)
    else:
        raise ValueError(f"Unsupported robot_name: {robot_name!r}; only 'lift2' is configured.")
    data_config = config.data[0].create(config.model)
output_path = os.path.join(save_path, robot_name, task_name)
stats_json_path = os.path.join(output_path, "norm_stats.json")
if os.path.isfile(stats_json_path):
with open(stats_json_path, 'r', encoding='utf-8') as f:
json.load(f)
return True
data_loader, num_batches = create_torch_dataloader(
data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames=None
)
keys = ["state", "actions"]
stats = {key: normalize.RunningStats() for key in keys}
step_id = 0
for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
step_id += 1
for key in keys:
stats[key].update(np.asarray(batch[key]))
if step_id > 10000:
break
    norm_stats = {key: s.get_statistics() for key, s in stats.items()}
print(f"Writing stats to: {output_path}")
normalize.save(output_path, norm_stats)
def check_lerobot_repo(repo_dir: str):
if os.path.isdir(os.path.join(repo_dir, "data")) and os.path.isdir(os.path.join(repo_dir, "meta")) and os.path.isdir(os.path.join(repo_dir, "videos")):
print(repo_dir, "true")
return True
else:
print(repo_dir, "false")
return False
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--task_path", type=str, default="data/InternData-A1/sim/long_horizon_tasks/lift2/sort_the_rubbish/*")
parser.add_argument("--robot_name", type=str, default="lift2")
parser.add_argument("--save_path", type=str, default="stats/sim2real")
args, unknown = parser.parse_known_args()
    dataset_path = args.task_path
save_path = args.save_path
parts = dataset_path.split("/")
robot_idx = next((i for i, p in enumerate(parts) if p == args.robot_name), None)
if robot_idx is None:
raise ValueError(
f"Cannot find robot name in path. Expected {args.robot_name}, "
f"but got path: {dataset_path}"
)
if robot_idx + 1 >= len(parts):
raise ValueError(
            f"Path ends at robot name '{parts[robot_idx]}', cannot determine task_name: {dataset_path}"
)
robot_name = parts[robot_idx]
task_name = parts[robot_idx + 1]
    try:
        main(dataset_path, robot_name, task_name, save_path)
    except Exception as exc:
        print(f"Failed to compute norm stats for {dataset_path}: {exc}")
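The loop above streams batches through `normalize.RunningStats`. A standalone NumPy sketch of the accumulation it performs (the `update`/`get_statistics` interface mirrors the openpi class, but this is an illustrative approximation, not the library implementation):

```python
import numpy as np


class RunningStatsSketch:
    """Accumulate count, sum, and sum of squares to recover mean/std in one pass."""

    def __init__(self):
        self.count = 0
        self.total = None
        self.total_sq = None

    def update(self, batch: np.ndarray) -> None:
        # Flatten everything except the last (feature) dimension.
        flat = np.asarray(batch, dtype=np.float64).reshape(-1, batch.shape[-1])
        if self.total is None:
            self.total = np.zeros(flat.shape[-1])
            self.total_sq = np.zeros(flat.shape[-1])
        self.count += flat.shape[0]
        self.total += flat.sum(axis=0)
        self.total_sq += (flat**2).sum(axis=0)

    def get_statistics(self) -> dict:
        mean = self.total / self.count
        # var = E[x^2] - mean^2, clamped at zero against rounding error.
        var = self.total_sq / self.count - mean**2
        return {"mean": mean, "std": np.sqrt(np.maximum(var, 0.0))}
```

Because only per-feature sums are kept, the memory cost is independent of how many batches are streamed, which is what makes capping the loop at a fixed number of steps a pure time/accuracy trade-off.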

View File

@@ -0,0 +1,29 @@
# Run with:
# docker compose -f scripts/docker/compose.yml up --build
services:
openpi_server:
image: openpi_server
build:
context: ../..
dockerfile: scripts/docker/serve_policy.Dockerfile
init: true
tty: true
network_mode: host
# Populate configured openpi data home to /openpi_assets inside the container.
volumes:
- $PWD:/app
- ${OPENPI_DATA_HOME:-~/.cache/openpi}:/openpi_assets
environment:
- SERVER_ARGS
- OPENPI_DATA_HOME=/openpi_assets
- IS_DOCKER=true
# Comment out this block if not running on a machine with GPUs.
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]

View File

@@ -0,0 +1,37 @@
#!/bin/bash
# Add Docker's official GPG key:
sudo apt-get update
sudo apt-get install -y ca-certificates curl
sudo install -m 0755 -d /etc/apt/keyrings
sudo curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
sudo chmod a+r /etc/apt/keyrings/docker.asc
# Add the repository to Apt sources:
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" |
sudo tee /etc/apt/sources.list.d/docker.list >/dev/null
sudo apt-get update
sudo apt-get install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
# Add current user to the 'docker' group, which allows them to use docker commands (docker build, docker run, etc).
# See https://docs.docker.com/engine/install/linux-postinstall/
username=$(whoami)
sudo usermod -aG docker $username
# Configure docker to start automatically on system boot.
sudo systemctl enable docker.service
sudo systemctl enable containerd.service
# https://forums.docker.com/t/docker-credential-desktop-exe-executable-file-not-found-in-path-using-wsl2/100225/5
if [ -f ~/.docker/config.json ]; then
sed -i 's/credsStore/credStore/g' ~/.docker/config.json
fi
echo ""
echo "********************************************************************"
echo "**** Restart to allow Docker permission changes to take effect. ****"
echo "********************************************************************"
echo ""

View File

@@ -0,0 +1,17 @@
#!/bin/bash
# Installs the NVIDIA Container Toolkit, which allows Docker containers to access NVIDIA GPUs.
# NVIDIA's official documentation: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html
curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg &&
curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list |
sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' |
sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list
# NVIDIA's documentation omits 'sudo' in the following command, but it is required.
sudo sed -i -e '/experimental/ s/^#//g' /etc/apt/sources.list.d/nvidia-container-toolkit.list
sudo apt-get update
sudo apt-get install -y nvidia-container-toolkit
sudo nvidia-ctk runtime configure --runtime=docker
sudo systemctl restart docker

View File

@@ -0,0 +1,38 @@
# Dockerfile for serving a PI policy.
# Based on UV's instructions: https://docs.astral.sh/uv/guides/integration/docker/#developing-in-a-container
# Build the container:
# docker build . -t openpi_server -f scripts/docker/serve_policy.Dockerfile
# Run the container:
# docker run --rm -it --network=host -v .:/app --gpus=all openpi_server /bin/bash
FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04@sha256:2d913b09e6be8387e1a10976933642c73c840c0b735f0bf3c28d97fc9bc422e0
COPY --from=ghcr.io/astral-sh/uv:0.5.1 /uv /uvx /bin/
WORKDIR /app
# Needed because LeRobot uses git-lfs.
RUN apt-get update && apt-get install -y git git-lfs linux-headers-generic build-essential clang
# Copy from the cache instead of linking since it's a mounted volume
ENV UV_LINK_MODE=copy
# Write the virtual environment outside of the project directory so it doesn't
# leak out of the container when we mount the application code.
ENV UV_PROJECT_ENVIRONMENT=/.venv
# Install the project's dependencies using the lockfile and settings
RUN uv venv --python 3.11.9 $UV_PROJECT_ENVIRONMENT
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=uv.lock,target=uv.lock \
--mount=type=bind,source=pyproject.toml,target=pyproject.toml \
--mount=type=bind,source=packages/openpi-client/pyproject.toml,target=packages/openpi-client/pyproject.toml \
--mount=type=bind,source=packages/openpi-client/src,target=packages/openpi-client/src \
GIT_LFS_SKIP_SMUDGE=1 uv sync --frozen --no-install-project --no-dev
# Copy transformers_replace files while preserving directory structure
COPY src/openpi/models_pytorch/transformers_replace/ /tmp/transformers_replace/
RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | xargs dirname | xargs -I{} cp -r /tmp/transformers_replace/* {} && rm -rf /tmp/transformers_replace
CMD /bin/bash -c "uv run scripts/serve_policy.py $SERVER_ARGS"

View File

@@ -0,0 +1,27 @@
import os
from pathlib import Path
def download_from_gcs(gcs_uri: str, local_path: str):
local_path = Path(local_path)
local_path.parent.mkdir(parents=True, exist_ok=True)
if os.system("which gsutil > /dev/null 2>&1") == 0:
cmd = f"gsutil cp {gcs_uri} {local_path}"
else:
gcs_http = gcs_uri.replace("gs://", "https://storage.googleapis.com/")
cmd = f"wget -O {local_path} {gcs_http}"
print(f"⬇️ Executing: {cmd}")
ret = os.system(cmd)
if ret == 0:
print("✅ Download complete:", local_path)
else:
raise RuntimeError(f"Download failed: {gcs_uri}")
return local_path
if __name__ == "__main__":
gcs_uri = "gs://vertex-model-garden-paligemma-us/paligemma/pt_224.npz"
save_path = "checkpoints/jax/paligemma/pt_224.npz"
download_from_gcs(gcs_uri, save_path)

View File

@@ -0,0 +1,122 @@
import dataclasses
import enum
import logging
import socket
import tyro
from openpi.policies import policy as _policy
from openpi.policies import policy_config as _policy_config
from openpi.serving import websocket_policy_server
from openpi.training import config as _config
class EnvMode(enum.Enum):
"""Supported environments."""
ALOHA = "aloha"
ALOHA_SIM = "aloha_sim"
DROID = "droid"
LIBERO = "libero"
@dataclasses.dataclass
class Checkpoint:
"""Load a policy from a trained checkpoint."""
# Training config name (e.g., "pi0_aloha_sim").
config: str
# Checkpoint directory (e.g., "checkpoints/pi0_aloha_sim/exp/10000").
dir: str
@dataclasses.dataclass
class Default:
"""Use the default policy for the given environment."""
@dataclasses.dataclass
class Args:
"""Arguments for the serve_policy script."""
# Environment to serve the policy for. This is only used when serving default policies.
env: EnvMode = EnvMode.ALOHA_SIM
# If provided, will be used in case the "prompt" key is not present in the data, or if the model doesn't have a default
# prompt.
default_prompt: str | None = None
# Port to serve the policy on.
port: int = 8000
# Record the policy's behavior for debugging.
record: bool = False
# Specifies how to load the policy. If not provided, the default policy for the environment will be used.
policy: Checkpoint | Default = dataclasses.field(default_factory=Default)
# Default checkpoints that should be used for each environment.
DEFAULT_CHECKPOINT: dict[EnvMode, Checkpoint] = {
EnvMode.ALOHA: Checkpoint(
config="pi05_aloha",
dir="gs://openpi-assets/checkpoints/pi05_base",
),
EnvMode.ALOHA_SIM: Checkpoint(
config="pi0_aloha_sim",
dir="gs://openpi-assets/checkpoints/pi0_aloha_sim",
),
EnvMode.DROID: Checkpoint(
config="pi05_droid",
dir="gs://openpi-assets/checkpoints/pi05_droid",
),
EnvMode.LIBERO: Checkpoint(
config="pi05_libero",
dir="gs://openpi-assets/checkpoints/pi05_libero",
),
}
def create_default_policy(env: EnvMode, *, default_prompt: str | None = None) -> _policy.Policy:
"""Create a default policy for the given environment."""
if checkpoint := DEFAULT_CHECKPOINT.get(env):
return _policy_config.create_trained_policy(
_config.get_config(checkpoint.config), checkpoint.dir, default_prompt=default_prompt
)
raise ValueError(f"Unsupported environment mode: {env}")
def create_policy(args: Args) -> _policy.Policy:
"""Create a policy from the given arguments."""
match args.policy:
case Checkpoint():
return _policy_config.create_trained_policy(
_config.get_config(args.policy.config), args.policy.dir, default_prompt=args.default_prompt
)
case Default():
return create_default_policy(args.env, default_prompt=args.default_prompt)
def main(args: Args) -> None:
policy = create_policy(args)
policy_metadata = policy.metadata
# Record the policy's behavior.
if args.record:
policy = _policy.PolicyRecorder(policy, "policy_records")
hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname)
logging.info("Creating server (host: %s, ip: %s)", hostname, local_ip)
server = websocket_policy_server.WebsocketPolicyServer(
policy=policy,
host="0.0.0.0",
port=args.port,
metadata=policy_metadata,
)
server.serve_forever()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, force=True)
main(tyro.cli(Args))

View File

@@ -0,0 +1,290 @@
import dataclasses
import functools
import logging
import platform
from typing import Any
import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
import psutil
from openpi.shared.online_compute_norm_stats import compute_norm_stats
def init_logging():
"""Custom logging format for better readability."""
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
class CustomFormatter(logging.Formatter):
def format(self, record):
record.levelname = level_mapping.get(record.levelname, record.levelname)
return super().format(record)
formatter = CustomFormatter(
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
datefmt="%H:%M:%S",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
if not enabled:
wandb.init(mode="disabled")
return
ckpt_dir = config.checkpoint_dir
if not ckpt_dir.exists():
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
if resuming:
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
wandb.init(id=run_id, resume="must", project=config.project_name)
else:
wandb.init(
name=config.exp_name,
config=dataclasses.asdict(config),
project=config.project_name,
)
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
if log_code:
wandb.run.log_code(epath.Path(__file__).parent.parent)
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
"""Loads and validates the weights. Returns a loaded subset of the weights."""
loaded_params = loader.load(params_shape)
at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
# Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
return traverse_util.unflatten_dict(
{k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
)
@at.typecheck
def init_train_state(
config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
rng, model_rng = jax.random.split(rng)
# initialize the model (and its parameters).
model = config.model.create(model_rng)
# Merge the partial params into the model.
if partial_params is not None:
graphdef, state = nnx.split(model)
# This will produce an error if the partial params are not a subset of the state.
state.replace_by_pure_dict(partial_params)
model = nnx.merge(graphdef, state)
params = nnx.state(model)
# Convert frozen params to bfloat16.
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
return training_utils.TrainState(
step=0,
params=params,
model_def=nnx.graphdef(model),
tx=tx,
opt_state=tx.init(params.filter(config.trainable_filter)),
ema_decay=config.ema_decay,
ema_params=None if config.ema_decay is None else params,
)
train_state_shape = jax.eval_shape(init, init_rng)
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
if resume:
return train_state_shape, state_sharding
partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
# Initialize the train state and mix in the partial params.
train_state = jax.jit(
init,
donate_argnums=(1,), # donate the partial params buffer.
in_shardings=replicated_sharding,
out_shardings=state_sharding,
)(init_rng, partial_params)
return train_state, state_sharding
@at.typecheck
def train_step(
config: _config.TrainConfig,
rng: at.KeyArrayLike,
state: training_utils.TrainState,
batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
model = nnx.merge(state.model_def, state.params)
model.train()
@at.typecheck
def loss_fn(
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
):
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
return jnp.mean(chunked_loss)
train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
# Filter out frozen params.
diff_state = nnx.DiffState(0, config.trainable_filter)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
params = state.params.filter(config.trainable_filter)
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
new_params = optax.apply_updates(params, updates)
# Update the model in place and return the new full state.
nnx.update(model, new_params)
new_params = nnx.state(model)
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
if state.ema_decay is not None:
new_state = dataclasses.replace(
new_state,
ema_params=jax.tree.map(
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
),
)
# Filter out params that aren't kernels.
kernel_params = nnx.state(
model,
nnx.All(
nnx.Param,
nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
lambda _, x: x.value.ndim > 1,
),
)
info = {
"loss": loss,
"grad_norm": optax.global_norm(grads),
"param_norm": optax.global_norm(kernel_params),
}
return new_state, info
def main(config: _config.TrainConfig):
init_logging()
logging.info(f"Running on: {platform.node()}")
if config.batch_size % jax.device_count() != 0:
raise ValueError(
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
)
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
rng = jax.random.key(config.seed)
train_rng, init_rng = jax.random.split(rng)
mesh = sharding.make_mesh(config.fsdp_devices)
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
config.checkpoint_dir,
keep_period=config.keep_period,
overwrite=config.overwrite,
resume=config.resume,
)
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
if config.online_compute_norm_stats:
global_norm_stats = compute_norm_stats(config.name)
else:
global_norm_stats = None
data_loader = _data_loader.create_data_loader_multi(
config,
sharding=data_sharding,
shuffle=True,
global_norm_stats=global_norm_stats,
)
data_iter = iter(data_loader)
batch = next(data_iter)
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
    logging.info("Data loader RSS memory: %.1f MiB", psutil.Process().memory_info().rss / 1024**2)
# Log images from first batch to sanity check.
images_to_log = [
wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
for i in range(min(5, len(next(iter(batch[0].images.values())))))
]
wandb.log({"camera_views": images_to_log}, step=0)
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
jax.block_until_ready(train_state)
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
if resuming:
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
ptrain_step = jax.jit(
functools.partial(train_step, config),
in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
out_shardings=(train_state_sharding, replicated_sharding),
donate_argnums=(1,),
)
start_step = int(train_state.step)
pbar = tqdm.tqdm(
range(start_step, config.num_train_steps),
initial=start_step,
total=config.num_train_steps,
dynamic_ncols=True,
)
infos = []
for step in pbar:
with sharding.set_mesh(mesh):
train_state, info = ptrain_step(train_rng, train_state, batch)
infos.append(info)
if step % config.log_interval == 0:
stacked_infos = common_utils.stack_forest(infos)
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
pbar.write(f"Step {step}: {info_str}")
wandb.log(reduced_info, step=step)
infos = []
batch = next(data_iter)
if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
logging.info("Waiting for checkpoint manager to finish")
checkpoint_manager.wait_until_finished()
if __name__ == "__main__":
main(_config.cli())
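Stripped of nnx/optax, the parameter update in `train_step` is: apply the optimizer step, then blend the result into the exponential moving average as `ema_new = decay * ema_old + (1 - decay) * param_new`, per leaf. A plain-NumPy sketch of that arithmetic (a hypothetical SGD step stands in for the real optax transform):

```python
import numpy as np


def sgd_step_with_ema(params, grads, ema_params, lr=1e-3, ema_decay=0.99):
    """One gradient step followed by the EMA blend used in train_step."""
    # Optimizer update (plain SGD here; the trainer uses an optax chain).
    new_params = {k: p - lr * grads[k] for k, p in params.items()}
    # EMA blend: heavily weight the old average, lightly the new params.
    new_ema = {
        k: ema_decay * ema_params[k] + (1 - ema_decay) * new_params[k]
        for k in params
    }
    return new_params, new_ema
```

The EMA copy is what gets exported for inference in many setups, since it smooths over step-to-step noise in the raw parameters.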

View File

@@ -0,0 +1,341 @@
"""
Multi-host training entrypoint (JAX).
How to run multi-host (example: 2 nodes):
# node0
export JAX_COORDINATOR_ADDRESS=node0:12345
export JAX_PROCESS_COUNT=2
export JAX_PROCESS_INDEX=0
uv run python scripts/train.py <config_name> --exp_name <exp>
# node1
export JAX_COORDINATOR_ADDRESS=node0:12345
export JAX_PROCESS_COUNT=2
export JAX_PROCESS_INDEX=1
uv run python scripts/train.py <config_name> --exp_name <exp>
Notes:
- Initialize distributed BEFORE any device query.
- Only process_index==0 performs side-effects (wandb, checkpoints, progress bar).
- Total devices across hosts must be divisible by config.fsdp_devices.
"""
import dataclasses
import functools
import logging
import platform
import os
from typing import Any
import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
from pdb import set_trace
def maybe_initialize_distributed() -> bool:
coordinator = os.environ.get("JAX_COORDINATOR_ADDRESS")
process_count = int(os.environ.get("JAX_PROCESS_COUNT", "1"))
process_index = int(os.environ.get("JAX_PROCESS_INDEX", "0"))
if process_count > 1 and coordinator:
jax.distributed.initialize(
coordinator_address=coordinator,
num_processes=process_count,
process_id=process_index,
)
return True
return False
def init_logging():
"""Custom logging format for better readability."""
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
class CustomFormatter(logging.Formatter):
def format(self, record):
record.levelname = level_mapping.get(record.levelname, record.levelname)
return super().format(record)
formatter = CustomFormatter(
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
datefmt="%H:%M:%S",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if not logger.handlers:
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
else:
logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
if not enabled:
wandb.init(mode="disabled")
return
ckpt_dir = config.checkpoint_dir
if not ckpt_dir.exists():
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
if resuming:
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
wandb.init(id=run_id, resume="must", project=config.project_name)
else:
wandb.init(
name=config.exp_name,
config=dataclasses.asdict(config),
project=config.project_name,
)
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
if log_code:
wandb.run.log_code(epath.Path(__file__).parent.parent)
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
"""Loads and validates the weights. Returns a loaded subset of the weights."""
loaded_params = loader.load(params_shape)
at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)
# Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
return traverse_util.unflatten_dict(
{k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
)
@at.typecheck
def init_train_state(
config: _config.TrainConfig, init_rng: at.KeyArrayLike, mesh: jax.sharding.Mesh, *, resume: bool
) -> tuple[training_utils.TrainState, Any]:
tx = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
rng, model_rng = jax.random.split(rng)
# initialize the model (and its parameters).
model = config.model.create(model_rng)
# Merge the partial params into the model.
if partial_params is not None:
graphdef, state = nnx.split(model)
# This will produce an error if the partial params are not a subset of the state.
state.replace_by_pure_dict(partial_params)
model = nnx.merge(graphdef, state)
params = nnx.state(model)
# Convert frozen params to bfloat16.
params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
return training_utils.TrainState(
step=0,
params=params,
model_def=nnx.graphdef(model),
tx=tx,
opt_state=tx.init(params.filter(config.trainable_filter)),
ema_decay=config.ema_decay,
ema_params=None if config.ema_decay is None else params,
)
train_state_shape = jax.eval_shape(init, init_rng)
state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
if resume:
return train_state_shape, state_sharding
partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
# Initialize the train state and mix in the partial params.
train_state = jax.jit(
init,
donate_argnums=(1,), # donate the partial params buffer.
in_shardings=replicated_sharding,
out_shardings=state_sharding,
)(init_rng, partial_params)
return train_state, state_sharding
@at.typecheck
def train_step(
config: _config.TrainConfig,
rng: at.KeyArrayLike,
state: training_utils.TrainState,
batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
model = nnx.merge(state.model_def, state.params)
model.train()
@at.typecheck
def loss_fn(
model: _model.BaseModel, rng: at.KeyArrayLike, observation: _model.Observation, actions: _model.Actions
):
chunked_loss = model.compute_loss(rng, observation, actions, train=True)
return jnp.mean(chunked_loss)
train_rng = jax.random.fold_in(rng, state.step)
observation, actions = batch
# Filter out frozen params.
diff_state = nnx.DiffState(0, config.trainable_filter)
loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(model, train_rng, observation, actions)
params = state.params.filter(config.trainable_filter)
updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
new_params = optax.apply_updates(params, updates)
# Update the model in place and return the new full state.
nnx.update(model, new_params)
new_params = nnx.state(model)
new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
if state.ema_decay is not None:
new_state = dataclasses.replace(
new_state,
ema_params=jax.tree.map(
lambda old, new: state.ema_decay * old + (1 - state.ema_decay) * new, state.ema_params, new_params
),
)
# Filter out params that aren't kernels.
kernel_params = nnx.state(
model,
nnx.All(
nnx.Param,
nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
lambda _, x: x.value.ndim > 1,
),
)
info = {
"loss": loss,
"grad_norm": optax.global_norm(grads),
"param_norm": optax.global_norm(kernel_params),
}
return new_state, info
def main(config: _config.TrainConfig):
init_logging()
logging.info(f"Running on: {platform.node()}")
# Initialize multi-host distributed if environment variables are set
distributed_initialized = maybe_initialize_distributed()
is_main = jax.process_index() == 0
if config.batch_size % jax.device_count() != 0:
raise ValueError(
f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
)
jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))
rng = jax.random.key(config.seed)
train_rng, init_rng = jax.random.split(rng)
mesh = sharding.make_mesh(config.fsdp_devices)
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
config.checkpoint_dir,
keep_period=config.keep_period,
overwrite=config.overwrite,
resume=config.resume,
)
init_wandb(config, resuming=resuming, enabled=(config.wandb_enabled and is_main))
data_loader = _data_loader.create_data_loader_multi(
config,
sharding=data_sharding,
shuffle=True,
)
data_iter = iter(data_loader)
batch = next(data_iter)
logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
# Note: Wandb image logging is disabled in multi-node setup to avoid potential hanging issues
# caused by concurrent access to sharded arrays across processes.
train_state, train_state_sharding = init_train_state(config, init_rng, mesh, resume=resuming)
jax.block_until_ready(train_state)
logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")
if resuming:
train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)
ptrain_step = jax.jit(
functools.partial(train_step, config),
in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
out_shardings=(train_state_sharding, replicated_sharding),
donate_argnums=(1,),
)
start_step = int(train_state.step)
step_iter = range(start_step, config.num_train_steps)
pbar = (
tqdm.tqdm(
step_iter,
initial=start_step,
total=config.num_train_steps,
dynamic_ncols=True,
)
if is_main
else None
)
infos = []
for step in step_iter:
with sharding.set_mesh(mesh):
train_state, info = ptrain_step(train_rng, train_state, batch)
if is_main and pbar is not None:
pbar.update(1)
infos.append(info)
if step % config.log_interval == 0:
stacked_infos = common_utils.stack_forest(infos)
reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
if is_main:
info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
if pbar is not None:
pbar.write(f"Step {step}: {info_str}")
else:
logging.info(f"Step {step}: {info_str}")
if config.wandb_enabled:
wandb.log(reduced_info, step=step)
infos = []
batch = next(data_iter)
if ((step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1):
_checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)
if is_main:
if pbar is not None:
pbar.close()
logging.info("Waiting for checkpoint manager to finish")
checkpoint_manager.wait_until_finished()
if distributed_initialized:
jax.distributed.shutdown()
if __name__ == "__main__":
main(_config.cli())
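The trainer above relies on the data-sharding/replicated-sharding split: batches are partitioned along a data mesh axis while small state is replicated to every device. A minimal standalone sketch of that pattern (axis name `"data"` stands in for `sharding.DATA_AXIS`; shapes are illustrative, not the repo's):

```python
# Sketch of the NamedSharding pattern used by the trainer: shard batches
# along a "data" axis, replicate small state everywhere.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), ("data",))
data_sharding = NamedSharding(mesh, PartitionSpec("data"))  # shard dim 0
replicated_sharding = NamedSharding(mesh, PartitionSpec())  # full copy on every device

batch = jnp.ones((len(jax.devices()) * 2, 4))
sharded_batch = jax.device_put(batch, data_sharding)
step_count = jax.device_put(jnp.zeros(()), replicated_sharding)

print(sharded_batch.sharding, step_count.sharding)
```

These two shardings are exactly what `ptrain_step` declares via `in_shardings`/`out_shardings`, letting `jax.jit` move data without explicit collectives in user code.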

@@ -0,0 +1,632 @@
"""
PyTorch training entrypoint for PI0/PI05 with multi-GPU and multi-node (DDP) support.
This script mirrors the behavior of the JAX trainer (`scripts/train.py`) but runs
entirely in PyTorch using the `PI0Pytorch` model and your existing config/data
pipeline from `src/openpi/training/config.py` and `src/openpi/training/data_loader.py`.
Usage
Single GPU:
python scripts/train_pytorch.py <config_name> --exp_name <run_name> --save_interval <interval>
Example:
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test
python scripts/train_pytorch.py debug --exp_name pytorch_ddp_test --resume # Resume from latest checkpoint
Multi-GPU (single node):
torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> scripts/train_pytorch.py <config_name> --exp_name <run_name>
Example:
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test
torchrun --standalone --nnodes=1 --nproc_per_node=2 scripts/train_pytorch.py pi0_aloha_sim --exp_name pytorch_ddp_test --resume
Multi-Node Training:
torchrun \
--nnodes=<num_nodes> --nproc_per_node=<gpus_per_node> --node_rank=<rank_of_node> \
--master_addr=<master_ip> --master_port=<port> \
scripts/train_pytorch.py <config_name> --exp_name=<run_name> --save_interval <interval>
"""
import dataclasses
import gc
import logging
import os
import platform
import shutil
import time
import jax
import numpy as np
import safetensors.torch
import torch
import torch.distributed as dist
import torch.nn.parallel
import tqdm
import wandb
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
import openpi.shared.normalize as _normalize
import openpi.training.config as _config
import openpi.training.data_loader as _data
def init_logging():
level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}
class CustomFormatter(logging.Formatter):
def format(self, record):
record.levelname = level_mapping.get(record.levelname, record.levelname)
return super().format(record)
formatter = CustomFormatter(
fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
datefmt="%H:%M:%S",
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if not logger.handlers:
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)
else:
logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, *, resuming: bool, enabled: bool = True):
"""Initialize wandb logging."""
if not enabled:
wandb.init(mode="disabled")
return
ckpt_dir = config.checkpoint_dir
if not ckpt_dir.exists():
raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")
if resuming:
run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
wandb.init(id=run_id, resume="must", project=config.project_name)
else:
wandb.init(
name=config.exp_name,
config=dataclasses.asdict(config),
project=config.project_name,
)
(ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)
def setup_ddp():
world_size = int(os.environ.get("WORLD_SIZE", "1"))
use_ddp = world_size > 1
if use_ddp and not torch.distributed.is_initialized():
backend = "nccl" if torch.cuda.is_available() else "gloo"
torch.distributed.init_process_group(backend=backend, init_method="env://")
# Set up debugging environment variables for DDP issues
if os.environ.get("TORCH_DISTRIBUTED_DEBUG") is None:
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
    # Use LOCAL_RANK (per-node rank set by torchrun); RANK is the global rank
    # and would exceed the local device count on multi-node jobs.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
torch.cuda.set_device(device)
return use_ddp, local_rank, device
def cleanup_ddp():
if torch.distributed.is_initialized():
torch.distributed.barrier()
torch.distributed.destroy_process_group()
def set_seed(seed: int, local_rank: int):
torch.manual_seed(seed + local_rank)
np.random.seed(seed + local_rank)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed + local_rank)
def build_datasets(config: _config.TrainConfig):
# Use the unified data loader with PyTorch framework
data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=True)
return data_loader, data_loader.data_config()
def get_model_state_dict(model):
"""Get state dict from model, handling DDP wrapper."""
return (
model.module.state_dict()
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
else model.state_dict()
)
def get_model_parameters(model):
"""Get parameters from model, handling DDP wrapper."""
return (
model.module.parameters()
if isinstance(model, torch.nn.parallel.DistributedDataParallel)
else model.parameters()
)
def save_checkpoint(model, optimizer, global_step, config, is_main, data_config):
"""Save a checkpoint with model state, optimizer state, and metadata."""
if not is_main:
return
# Only save if it's time to save or if it's the final step
if (global_step % config.save_interval == 0 and global_step > 0) or global_step == config.num_train_steps - 1:
# Create temporary directory for atomic checkpoint saving
final_ckpt_dir = config.checkpoint_dir / f"{global_step}"
tmp_ckpt_dir = config.checkpoint_dir / f"tmp_{global_step}"
# Remove any existing temp directory and create new one
if tmp_ckpt_dir.exists():
shutil.rmtree(tmp_ckpt_dir)
tmp_ckpt_dir.mkdir(parents=True, exist_ok=True)
# Save model state using safetensors (handle shared tensors)
model_to_save = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
safetensors.torch.save_model(model_to_save, tmp_ckpt_dir / "model.safetensors")
# Save optimizer state using PyTorch format
torch.save(optimizer.state_dict(), tmp_ckpt_dir / "optimizer.pt")
        # Save training metadata (global step, config snapshot, timestamp)
metadata = {
"global_step": global_step,
"config": dataclasses.asdict(config),
"timestamp": time.time(),
}
torch.save(metadata, tmp_ckpt_dir / "metadata.pt")
# save norm stats
norm_stats = data_config.norm_stats
if norm_stats is not None and data_config.asset_id is not None:
_normalize.save(tmp_ckpt_dir / "assets" / data_config.asset_id, norm_stats)
# Atomically move temp directory to final location
if final_ckpt_dir.exists():
shutil.rmtree(final_ckpt_dir)
tmp_ckpt_dir.rename(final_ckpt_dir)
logging.info(f"Saved checkpoint at step {global_step} -> {final_ckpt_dir}")
# Log checkpoint to wandb
if config.wandb_enabled:
wandb.log({"checkpoint_step": global_step}, step=global_step)
def load_checkpoint(model, optimizer, checkpoint_dir, device):
"""Load the latest checkpoint and return the global step."""
checkpoint_steps = [
int(d.name)
for d in checkpoint_dir.iterdir()
if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
]
if not checkpoint_steps:
raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")
latest_step = max(checkpoint_steps)
ckpt_dir = checkpoint_dir / f"{latest_step}"
# Clear memory before loading checkpoints
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
log_memory_usage(device, latest_step, "before_loading_checkpoint")
try:
# Load model state with error handling
logging.info("Loading model state...")
safetensors_path = ckpt_dir / "model.safetensors"
if safetensors_path.exists():
model_to_load = model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model
safetensors.torch.load_model(model_to_load, safetensors_path, device=str(device))
logging.info("Loaded model state from safetensors format")
else:
raise FileNotFoundError(f"No model checkpoint found at {ckpt_dir}")
torch.cuda.empty_cache()
gc.collect()
log_memory_usage(device, latest_step, "after_loading_model")
# Load optimizer state with error handling
logging.info("Loading optimizer state...")
optimizer_path = ckpt_dir / "optimizer.pt"
if optimizer_path.exists():
optimizer_state_dict = torch.load(optimizer_path, map_location=device, weights_only=False)
logging.info("Loaded optimizer state from pt format")
else:
raise FileNotFoundError(f"No optimizer checkpoint found at {ckpt_dir}")
optimizer.load_state_dict(optimizer_state_dict)
del optimizer_state_dict
torch.cuda.empty_cache()
gc.collect()
log_memory_usage(device, latest_step, "after_loading_optimizer")
# Load metadata
logging.info("Loading metadata...")
metadata = torch.load(ckpt_dir / "metadata.pt", map_location=device, weights_only=False)
global_step = metadata.get("global_step", latest_step)
del metadata
torch.cuda.empty_cache()
gc.collect()
log_memory_usage(device, latest_step, "after_loading_metadata")
logging.info(f"Successfully loaded all checkpoint components from step {latest_step}")
return global_step
except RuntimeError as e:
if "out of memory" in str(e):
# Clear memory and provide detailed error message
torch.cuda.empty_cache()
gc.collect()
logging.error(f"Out of memory error while loading checkpoint: {e!s}")
log_memory_usage(device, latest_step, "after_oom_error")
raise RuntimeError(
"Out of memory while loading checkpoint. Try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True"
) from e
raise
def get_latest_checkpoint_step(checkpoint_dir):
"""Get the latest checkpoint step number from a checkpoint directory."""
checkpoint_steps = [
int(d.name)
for d in checkpoint_dir.iterdir()
if d.is_dir() and d.name.isdigit() and not d.name.startswith("tmp_")
]
return max(checkpoint_steps) if checkpoint_steps else None
def log_memory_usage(device, step, phase="unknown"):
"""Log detailed memory usage information."""
if not torch.cuda.is_available():
return
    memory_allocated = torch.cuda.memory_allocated(device) / 1e9
    memory_reserved = torch.cuda.memory_reserved(device) / 1e9
    # "Free" here means reserved-but-unallocated memory inside the caching
    # allocator, not free device memory.
    memory_free = memory_reserved - memory_allocated
# Get more detailed memory info
memory_stats = torch.cuda.memory_stats(device)
max_memory_allocated = memory_stats.get("allocated_bytes.all.peak", 0) / 1e9
max_memory_reserved = memory_stats.get("reserved_bytes.all.peak", 0) / 1e9
# Get DDP info if available
ddp_info = ""
if dist.is_initialized():
ddp_info = f" | DDP: rank={dist.get_rank()}, world_size={dist.get_world_size()}"
logging.info(
f"Step {step} ({phase}): GPU memory - allocated: {memory_allocated:.2f}GB, reserved: {memory_reserved:.2f}GB, free: {memory_free:.2f}GB, peak_allocated: {max_memory_allocated:.2f}GB, peak_reserved: {max_memory_reserved:.2f}GB{ddp_info}"
)
def train_loop(config: _config.TrainConfig):
use_ddp, local_rank, device = setup_ddp()
is_main = (not use_ddp) or (dist.get_rank() == 0)
set_seed(config.seed, local_rank)
# Initialize checkpoint directory and wandb
resuming = False
if config.resume:
# Find checkpoint directory based on experiment name
exp_checkpoint_dir = config.checkpoint_dir
if exp_checkpoint_dir.exists():
# Use validation to find the latest working checkpoint
latest_step = get_latest_checkpoint_step(exp_checkpoint_dir)
if latest_step is not None:
resuming = True
logging.info(
f"Resuming from experiment checkpoint directory: {exp_checkpoint_dir} at step {latest_step}"
)
else:
raise FileNotFoundError(f"No valid checkpoints found in {exp_checkpoint_dir} for resume")
else:
raise FileNotFoundError(f"Experiment checkpoint directory {exp_checkpoint_dir} does not exist for resume")
elif config.overwrite and config.checkpoint_dir.exists():
shutil.rmtree(config.checkpoint_dir)
logging.info(f"Overwriting checkpoint directory: {config.checkpoint_dir}")
# Create checkpoint directory with experiment name
if not resuming:
# For new runs, create experiment-specific checkpoint directory
exp_checkpoint_dir = config.checkpoint_dir
exp_checkpoint_dir.mkdir(parents=True, exist_ok=True)
logging.info(f"Created experiment checkpoint directory: {exp_checkpoint_dir}")
else:
# For resume, checkpoint_dir is already set to the experiment directory
logging.info(f"Using existing experiment checkpoint directory: {config.checkpoint_dir}")
# Initialize wandb (only on main process)
if is_main:
init_wandb(config, resuming=resuming, enabled=config.wandb_enabled)
# Build data loader using the unified data loader
# Calculate effective batch size per GPU for DDP
# For N GPUs, each GPU should get batch_size/N samples, so total across all GPUs is batch_size
world_size = torch.distributed.get_world_size() if use_ddp else 1
effective_batch_size = config.batch_size // world_size
logging.info(
f"Using batch size per GPU: {effective_batch_size} (total batch size across {world_size} GPUs: {config.batch_size})"
)
# Pass the original batch size to data loader - it will handle DDP splitting internally
loader, data_config = build_datasets(config)
# Log sample images to wandb on first batch
if is_main and config.wandb_enabled and not resuming:
# Create a separate data loader for sample batch to avoid consuming the main loader
sample_data_loader = _data.create_data_loader(config, framework="pytorch", shuffle=False)
sample_batch = next(iter(sample_data_loader))
# Convert observation and actions to torch tensors
observation, actions = sample_batch
sample_batch = observation.to_dict()
sample_batch["actions"] = actions
# Create sample images for wandb
images_to_log = []
# Get batch size from the first image tensor
batch_size = next(iter(sample_batch["image"].values())).shape[0]
for i in range(min(5, batch_size)):
# Concatenate all camera views horizontally for this batch item
# Convert from NCHW to NHWC format for wandb
            img_concatenated = torch.cat([img[i].permute(1, 2, 0) for img in sample_batch["image"].values()], dim=1)
img_concatenated = img_concatenated.cpu().numpy()
images_to_log.append(wandb.Image(img_concatenated))
wandb.log({"camera_views": images_to_log}, step=0)
# Clear sample batch from memory aggressively
del sample_batch, observation, actions, images_to_log, img_concatenated
del sample_data_loader # Also delete the sample data loader
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
logging.info("Cleared sample batch and data loader from memory")
# Build model
if not isinstance(config.model, openpi.models.pi0_config.Pi0Config):
# Convert dataclass to Pi0Config if needed
model_cfg = openpi.models.pi0_config.Pi0Config(
dtype=config.pytorch_training_precision,
action_dim=config.model.action_dim,
action_horizon=config.model.action_horizon,
max_token_len=config.model.max_token_len,
paligemma_variant=getattr(config.model, "paligemma_variant", "gemma_2b"),
action_expert_variant=getattr(config.model, "action_expert_variant", "gemma_300m"),
pi05=getattr(config.model, "pi05", False),
)
else:
model_cfg = config.model
# Update dtype to match pytorch_training_precision
object.__setattr__(model_cfg, "dtype", config.pytorch_training_precision)
model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg).to(device)
if hasattr(model, "gradient_checkpointing_enable"):
enable_gradient_checkpointing = True
model.gradient_checkpointing_enable()
logging.info("Enabled gradient checkpointing for memory optimization")
else:
enable_gradient_checkpointing = False
logging.info("Gradient checkpointing is not supported for this model")
# Log initial memory usage after model creation
if is_main and torch.cuda.is_available():
log_memory_usage(device, 0, "after_model_creation")
# Enable memory optimizations for large-scale training
if world_size >= 8:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation configuration. Note: this only takes effect if the
        # CUDA caching allocator has not been initialized yet; prefer exporting
        # PYTORCH_CUDA_ALLOC_CONF in the launch environment.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
logging.info("Enabled memory optimizations for 8+ GPU training")
if use_ddp:
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[device.index] if device.type == "cuda" else None,
find_unused_parameters=False, # Disable for memory efficiency
gradient_as_bucket_view=True, # Enable for memory efficiency
static_graph=world_size >= 8, # Enable for 8+ GPUs
)
# Load weights from weight_loader if specified (for fine-tuning)
# if config.pytorch_weight_path is not None:
# logging.info(f"Loading weights from: {config.pytorch_weight_path}")
# model_path = os.path.join(config.pytorch_weight_path, "model.safetensors")
# safetensors.torch.load_model(
# (model.module if isinstance(model, torch.nn.parallel.DistributedDataParallel) else model), model_path
# )
# logging.info(f"Loaded PyTorch weights from {config.pytorch_weight_path}")
# Optimizer + learning rate schedule from config
warmup_steps = config.lr_schedule.warmup_steps
peak_lr = config.lr_schedule.peak_lr
decay_steps = config.lr_schedule.decay_steps
end_lr = config.lr_schedule.decay_lr
# Create optimizer with config parameters
optim = torch.optim.AdamW(
model.parameters(),
lr=peak_lr,
betas=(config.optimizer.b1, config.optimizer.b2),
eps=config.optimizer.eps,
weight_decay=config.optimizer.weight_decay,
)
# Load checkpoint if resuming
global_step = 0
if resuming:
global_step = load_checkpoint(model, optim, config.checkpoint_dir, device)
logging.info(f"Resumed training from step {global_step}")
def lr_schedule(step: int):
if step < warmup_steps:
# Match JAX behavior: start from peak_lr / (warmup_steps + 1)
init_lr = peak_lr / (warmup_steps + 1)
return init_lr + (peak_lr - init_lr) * step / warmup_steps
# cosine decay
progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
cos = 0.5 * (1 + np.cos(np.pi * progress))
return end_lr + (peak_lr - end_lr) * cos
model.train()
start_time = time.time()
infos = [] # Collect stats over log interval
if is_main:
logging.info(
f"Running on: {platform.node()} | world_size={torch.distributed.get_world_size() if use_ddp else 1}"
)
logging.info(
f"Training config: batch_size={config.batch_size}, effective_batch_size={effective_batch_size}, num_train_steps={config.num_train_steps}"
)
logging.info(f"Memory optimizations: gradient_checkpointing={enable_gradient_checkpointing}")
logging.info(
f"LR schedule: warmup={warmup_steps}, peak_lr={peak_lr:.2e}, decay_steps={decay_steps}, end_lr={end_lr:.2e}"
)
logging.info(
f"Optimizer: {type(config.optimizer).__name__}, weight_decay={config.optimizer.weight_decay}, clip_norm={config.optimizer.clip_gradient_norm}"
)
logging.info("EMA is not supported for PyTorch training")
logging.info(f"Training precision: {model_cfg.dtype}")
# Training loop - iterate until we reach num_train_steps
pbar = (
tqdm.tqdm(total=config.num_train_steps, initial=global_step, desc="Training", disable=not is_main)
if is_main
else None
)
while global_step < config.num_train_steps:
# Set epoch for distributed training
if use_ddp and hasattr(loader, "set_epoch"):
loader.set_epoch(global_step // len(loader))
for observation, actions in loader:
# Check if we've reached the target number of steps
if global_step >= config.num_train_steps:
break
# The unified data loader returns (observation, actions) tuple
observation = jax.tree.map(lambda x: x.to(device), observation) # noqa: PLW2901
actions = actions.to(torch.float32) # noqa: PLW2901
actions = actions.to(device) # noqa: PLW2901
# Update LR
for pg in optim.param_groups:
pg["lr"] = lr_schedule(global_step)
# Forward pass
losses = model(observation, actions)
# Ensure losses is a tensor and handle different return types
if isinstance(losses, list | tuple):
losses = torch.stack(losses)
elif not isinstance(losses, torch.Tensor):
losses = torch.tensor(losses, device=device, dtype=torch.float32)
loss = losses.mean()
# Backward pass
loss.backward()
# Log memory usage after backward pass
if global_step < 5 and is_main and torch.cuda.is_available():
log_memory_usage(device, global_step, "after_backward")
# Gradient clipping
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optimizer.clip_gradient_norm)
# Optimizer step
optim.step()
optim.zero_grad(set_to_none=True)
# Clear gradients more aggressively
for param in model.parameters():
if param.grad is not None:
param.grad.detach_()
param.grad = None
# Collect stats
if is_main:
infos.append(
{
"loss": loss.item(),
"learning_rate": optim.param_groups[0]["lr"],
"grad_norm": float(grad_norm) if isinstance(grad_norm, torch.Tensor) else grad_norm,
}
)
if is_main and (global_step % config.log_interval == 0):
elapsed = time.time() - start_time
# Average stats over log interval
avg_loss = sum(info["loss"] for info in infos) / len(infos)
avg_lr = sum(info["learning_rate"] for info in infos) / len(infos)
avg_grad_norm = None
if any("grad_norm" in info for info in infos):
vals = [
info["grad_norm"] for info in infos if "grad_norm" in info and info["grad_norm"] is not None
]
if len(vals) > 0:
avg_grad_norm = sum(vals) / len(vals)
logging.info(
f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} grad_norm={avg_grad_norm:.2f} time={elapsed:.1f}s"
if avg_grad_norm is not None
else f"step={global_step} loss={avg_loss:.4f} lr={avg_lr:.2e} time={elapsed:.1f}s"
)
# Log to wandb
if config.wandb_enabled and len(infos) > 0:
log_payload = {
"loss": avg_loss,
"learning_rate": avg_lr,
"step": global_step,
"time_per_step": elapsed / config.log_interval,
}
if avg_grad_norm is not None:
log_payload["grad_norm"] = avg_grad_norm
wandb.log(log_payload, step=global_step)
start_time = time.time()
infos = [] # Reset stats collection
global_step += 1
# Save checkpoint using the new mechanism
save_checkpoint(model, optim, global_step, config, is_main, data_config)
# Update progress bar
if pbar is not None:
pbar.update(1)
pbar.set_postfix(
{"loss": f"{loss.item():.4f}", "lr": f"{optim.param_groups[0]['lr']:.2e}", "step": global_step}
)
# Close progress bar
if pbar is not None:
pbar.close()
# Finish wandb run
if is_main and config.wandb_enabled:
wandb.finish()
cleanup_ddp()
def main():
init_logging()
config = _config.cli()
train_loop(config)
if __name__ == "__main__":
main()
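The learning-rate schedule inside `train_loop` is linear warmup into cosine decay, starting from `peak_lr / (warmup_steps + 1)` to match the JAX trainer. A standalone sketch with illustrative hyperparameters (not the repo defaults) that checks the boundary behavior:

```python
# Warmup + cosine-decay schedule, mirroring train_loop's lr_schedule.
import numpy as np

warmup_steps, peak_lr, decay_steps, end_lr = 1_000, 2.5e-5, 30_000, 2.5e-6

def lr_schedule(step: int) -> float:
    if step < warmup_steps:
        # Linear warmup starting from peak_lr / (warmup_steps + 1).
        init_lr = peak_lr / (warmup_steps + 1)
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    # Cosine decay from peak_lr down to end_lr over the remaining steps.
    progress = min(1.0, (step - warmup_steps) / max(1, decay_steps - warmup_steps))
    return end_lr + (peak_lr - end_lr) * 0.5 * (1 + np.cos(np.pi * progress))

assert abs(lr_schedule(warmup_steps) - peak_lr) < 1e-12  # cosine starts at the peak
assert abs(lr_schedule(decay_steps) - end_lr) < 1e-12    # and lands on end_lr
```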

@@ -0,0 +1,30 @@
import dataclasses
import os
import pathlib
import pytest
os.environ["JAX_PLATFORMS"] = "cpu"
from openpi.training import config as _config
from . import train
@pytest.mark.parametrize("config_name", ["debug"])
def test_train(tmp_path: pathlib.Path, config_name: str):
config = dataclasses.replace(
_config._CONFIGS_DICT[config_name], # noqa: SLF001
batch_size=2,
checkpoint_base_dir=str(tmp_path / "checkpoint"),
exp_name="test",
overwrite=False,
resume=False,
num_train_steps=2,
log_interval=1,
)
train.main(config)
# test resuming
config = dataclasses.replace(config, resume=True, num_train_steps=4)
train.main(config)

@@ -0,0 +1,209 @@
#!/usr/bin/env bash
set -ex
cd YOUR_PATH/openpi
export USE_TF=0
export USE_TORCH=0
export USE_JAX=1
export IMAGEIO_FFMPEG_EXE=ffmpeg
# JAX GPU memory fraction
export XLA_PYTHON_CLIENT_MEM_FRACTION="${XLA_PYTHON_CLIENT_MEM_FRACTION:-0.9}"
# ============================================================================
# NCCL Configuration
# ============================================================================
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_TIMEOUT=3600
export NCCL_DEBUG="${NCCL_DEBUG:-WARN}"
# ============================================================================
# Platform-Injected Configuration
# ============================================================================
# The platform automatically injects these when DISTRIBUTED_JOB=true:
# - NCCL_IB_HCA, NCCL_IB_GID_INDEX, NCCL_SOCKET_IFNAME
# - NODE_RANK, NODE_COUNT, MASTER_ADDR, PROC_PER_NODE
# - CUDA_VISIBLE_DEVICES
# We trust and use these platform configurations directly.
# ============================================================================
echo ""
echo "=========================================="
echo "Platform Configuration"
echo "=========================================="
echo "NODE_RANK: ${NODE_RANK:-<not set>}"
echo "NODE_COUNT: ${NODE_COUNT:-<not set>}"
echo "MASTER_ADDR: ${MASTER_ADDR:-<not set>}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
echo "=========================================="
echo ""
# ============================================================================
# NCCL Transport Configuration
# ============================================================================
# Use platform-injected configuration if available, otherwise fallback
# ============================================================================
if [ -n "${NCCL_IB_HCA:-}" ]; then
# Platform has configured InfiniBand
echo "[NCCL] ✓ Using platform-injected InfiniBand configuration"
# Only set NCCL_NET if not already set
if [ -z "${NCCL_NET:-}" ]; then
export NCCL_NET="IB"
fi
# Set IB timeout if not already set
if [ -z "${NCCL_IB_TIMEOUT:-}" ]; then
export NCCL_IB_TIMEOUT=23
fi
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
echo "[NCCL] NCCL_IB_HCA: ${NCCL_IB_HCA}"
echo "[NCCL] NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX}"
echo "[NCCL] NCCL_IB_TIMEOUT: ${NCCL_IB_TIMEOUT}"
elif [ -n "${NCCL_SOCKET_IFNAME:-}" ]; then
# Platform has configured Socket
echo "[NCCL] ✓ Using platform-injected Socket configuration"
if [ -z "${NCCL_NET:-}" ]; then
export NCCL_NET="Socket"
fi
echo "[NCCL] NCCL_NET: ${NCCL_NET}"
echo "[NCCL] NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME}"
else
# No platform injection - use OPENPI_NCCL_NET preference
echo "[NCCL] ⚠️ No platform-injected NCCL configuration"
if [ "${OPENPI_NCCL_NET:-IB}" = "IB" ]; then
echo "[NCCL] ✗ InfiniBand requested but not configured by platform"
echo "[NCCL] ✗ Falling back to Socket transport"
export NCCL_NET="Socket"
export NCCL_IB_DISABLE=1
else
export NCCL_NET="Socket"
export NCCL_IB_DISABLE=1
echo "[NCCL] Using Socket transport"
fi
fi
echo ""
# ============================================================================
# JAX Distributed Configuration
# ============================================================================
# Map platform variables to JAX variables
# ============================================================================
echo "=========================================="
echo "JAX Distributed Configuration"
echo "=========================================="
JAX_COORDINATOR_PORT="${JAX_COORDINATOR_PORT:-12345}"
# Set JAX coordinator address
if [ -z "${JAX_COORDINATOR_ADDRESS:-}" ] && [ -n "${MASTER_ADDR:-}" ]; then
export JAX_COORDINATOR_ADDRESS="${MASTER_ADDR}:${JAX_COORDINATOR_PORT}"
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS} (from MASTER_ADDR)"
elif [ -n "${JAX_COORDINATOR_ADDRESS:-}" ]; then
echo "[JAX] ✓ Coordinator: ${JAX_COORDINATOR_ADDRESS}"
else
echo "[JAX] ✗ WARNING: No coordinator address set!"
fi
# Set JAX process count
if [ -z "${JAX_PROCESS_COUNT:-}" ] && [ -n "${NODE_COUNT:-}" ]; then
export JAX_PROCESS_COUNT="${NODE_COUNT}"
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT} (from NODE_COUNT)"
elif [ -n "${JAX_PROCESS_COUNT:-}" ]; then
echo "[JAX] ✓ Process count: ${JAX_PROCESS_COUNT}"
fi
# Set JAX process index
if [ -z "${JAX_PROCESS_INDEX:-}" ] && [ -n "${NODE_RANK:-}" ]; then
export JAX_PROCESS_INDEX="${NODE_RANK}"
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX} (from NODE_RANK)"
elif [ -n "${JAX_PROCESS_INDEX:-}" ]; then
echo "[JAX] ✓ Process index: ${JAX_PROCESS_INDEX}"
fi
echo "=========================================="
echo ""
# ============================================================================
# Python Environment
# ============================================================================
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH}
conda activate pi0
# ============================================================================
# Configuration Summary
# ============================================================================
echo "=========================================="
echo "Configuration Summary"
echo "=========================================="
echo "NCCL_NET: ${NCCL_NET:-<not set>}"
echo "NCCL_IB_HCA: ${NCCL_IB_HCA:-<not set>}"
echo "NCCL_IB_GID_INDEX: ${NCCL_IB_GID_INDEX:-<not set>}"
echo "NCCL_SOCKET_IFNAME: ${NCCL_SOCKET_IFNAME:-<not set>}"
echo "JAX_COORDINATOR: ${JAX_COORDINATOR_ADDRESS:-<not set>}"
echo "JAX_PROCESS_COUNT: ${JAX_PROCESS_COUNT:-<not set>}"
echo "JAX_PROCESS_INDEX: ${JAX_PROCESS_INDEX:-<not set>}"
echo "=========================================="
echo ""
# ============================================================================
# Display Host Information
# ============================================================================
python - <<'EOF'
import socket
import os
import jax
hostname = socket.gethostname()
devices = jax.local_devices()
device_count = len(devices)
device_ids = [d.id for d in devices]
print(f"[JAX] host={hostname}, devices={device_count}xgpu, ids={device_ids}")
print(f"[JAX] JAX_COORDINATOR_ADDRESS={os.environ.get('JAX_COORDINATOR_ADDRESS', '<not set>')}")
print(f"[JAX] JAX_PROCESS_COUNT={os.environ.get('JAX_PROCESS_COUNT', '<not set>')}")
print(f"[JAX] JAX_PROCESS_INDEX={os.environ.get('JAX_PROCESS_INDEX', '<not set>')}")
EOF
# ============================================================================
# Launch Training
# ============================================================================
# Determine experiment name based on transport
if [ "${OPENPI_DEBUG_SINGLE_GPU:-0}" = "1" ]; then
EXP_NAME="${EXP_NAME:-dev_jax_single_gpu}"
echo "[DEBUG] Running in single-GPU mode"
else
EXP_NAME="${EXP_NAME:-dev_jax_multinode_ib}"
fi
echo ""
echo "=========================================="
echo "Starting Training"
echo "=========================================="
echo "Experiment: $EXP_NAME"
echo "=========================================="
echo ""
ulimit -n 1000000
python scripts/train_jax_multinode.py \
pretrain-interndata-a1 \
--exp-name="${EXP_NAME}" \
--num_workers=12 \
--fsdp_devices=8 \
--batch_size=512 \
--num_train_steps=2000000 \
--save_interval=5000
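The script above maps the platform's `MASTER_ADDR` / `NODE_COUNT` / `NODE_RANK` into `JAX_COORDINATOR_ADDRESS` / `JAX_PROCESS_COUNT` / `JAX_PROCESS_INDEX`. A hedged sketch of how a training entrypoint could consume those variables (an assumed pattern, not necessarily what `train_jax_multinode.py` does):

```python
# Consume the JAX_* variables exported by the launch script and join the
# distributed runtime before any cross-host collective or device query.
import os
import jax

coord = os.environ.get("JAX_COORDINATOR_ADDRESS")
nproc = os.environ.get("JAX_PROCESS_COUNT")
pidx = os.environ.get("JAX_PROCESS_INDEX")

if coord and nproc and pidx:
    # Every process connects to the coordinator; process 0's host runs it.
    jax.distributed.initialize(
        coordinator_address=coord,
        num_processes=int(nproc),
        process_id=int(pidx),
    )

print(f"process {jax.process_index()}/{jax.process_count()}, "
      f"local devices: {len(jax.local_devices())}")
```

When the variables are unset (single-node debugging), the guard skips initialization and JAX runs as a single process.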

@@ -0,0 +1,13 @@
#!/usr/bin/env bash
set -ex
export IMAGEIO_FFMPEG_EXE=ffmpeg
export OMP_NUM_THREADS=128
export PYTHONPATH=YOUR_PATH/openpi/src:YOUR_PATH/openpi/packages/openpi-client/src:YOUR_PATH/openpi/third_party/lerobot:${PYTHONPATH}
conda activate pi0
cd YOUR_PATH/openpi
ulimit -n 1000000
config_name=$1
XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python scripts/train.py ${config_name} \
--exp-name=${config_name}

@@ -0,0 +1,17 @@
import os
import pynvml
import pytest
def set_jax_cpu_backend_if_no_gpu() -> None:
try:
pynvml.nvmlInit()
pynvml.nvmlShutdown()
except pynvml.NVMLError:
# No GPU found.
os.environ["JAX_PLATFORMS"] = "cpu"
def pytest_configure(config: pytest.Config) -> None:
set_jax_cpu_backend_if_no_gpu()

@@ -0,0 +1,459 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma adaptation for Pi, taken from big_vision.
We follow this einsum axis naming convention:
B: batch
T: query length
S: k/v length
N: num query heads
K: num k/v heads
G: num query heads per k/v head
H: head dim
D: d_model ("features")
"""
from collections.abc import Sequence
import dataclasses
from typing import Literal, TypeAlias
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import openpi.models.lora as lora
import openpi.shared.array_typing as at
import openpi.training.sharding as sharding
PALIGEMMA_VOCAB_SIZE = 257_152
@dataclasses.dataclass
class Config:
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)
Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]
def get_config(variant: Variant) -> Config:
"""Returns config for specified gemma variant."""
if variant == "dummy":
return Config(
width=64,
depth=4,
mlp_dim=128,
num_heads=8,
num_kv_heads=1,
head_dim=16,
)
if variant == "gemma_300m":
# 311M params
return Config(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
if variant == "gemma_2b":
return Config(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
)
if variant == "gemma_2b_lora":
return Config(
width=2048,
depth=18,
mlp_dim=16_384,
num_heads=8,
num_kv_heads=1,
head_dim=256,
lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
)
if variant == "gemma_300m_lora":
# 311M params
return Config(
width=1024,
depth=18,
mlp_dim=4096,
num_heads=8,
num_kv_heads=1,
head_dim=256,
lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
)
raise ValueError(f"Unknown variant: {variant}")
@at.typecheck
class RMSNorm(nn.Module):
@nn.compact
def __call__(self, x, cond):
dtype = x.dtype # original dtype, could be half-precision
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
if cond is None:
# regular RMSNorm
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
normed_inputs = normed_inputs * (
1 + scale
) # scale by learned parameter in float32 (matches Flax implementation)
return normed_inputs.astype(dtype), None # return in original dtype
# adaptive RMSNorm
modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
normed_inputs = normed_inputs * (1 + scale) + shift # scale and shift in float32
return normed_inputs.astype(dtype), gate
@at.typecheck
class Embedder(nn.Module):
"""Embedder module."""
vocab_size: int
embed_dim: int
def setup(self):
self.input_embedding_table = self.param(
"input_embedding",
nn.initializers.normal(),
(self.vocab_size, self.embed_dim),
)
def encode(self, x):
x = self.input_embedding_table[(x,)]
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
return x
def decode(self, x):
return jnp.dot(x, self.input_embedding_table.T)
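Both Gemma variants tie the input lookup and the output projection to a single table. A minimal NumPy sketch (hypothetical shapes, not part of this module) of `encode`'s sqrt-scaling and the tied `decode`:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4
table = rng.normal(size=(vocab_size, embed_dim)).astype(np.float32)

def encode(tokens):
    # Row lookup plus Gemma's sqrt(embed_dim) scaling, as in Embedder.encode.
    return table[tokens] * np.sqrt(embed_dim)

def decode(hidden):
    # Logits reuse the same (tied) embedding table, transposed.
    return hidden @ table.T

tokens = np.array([1, 3, 7])
assert encode(tokens).shape == (3, embed_dim)
assert decode(encode(tokens)).shape == (3, vocab_size)
```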
@at.typecheck
class Attention(nn.Module):
"""Attention module."""
configs: Sequence[Config]
@nn.compact
def __call__(self, xs, positions, attn_mask, kv_cache):
# all experts must share the same head dim, num heads, and num kv heads for self-attention to work
assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)
dtype = next(x.dtype for x in xs if x is not None) # original dtype, could be half-precision
qkvs = []
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is None:
continue
if config.num_kv_heads == config.num_heads:
qkv_einsum = lora.Einsum(
shape=(3, config.num_heads, config.width, config.head_dim),
name=_name("qkv_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=config.lora_configs.get("attn"),
)
qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
else:
q_einsum = lora.Einsum(
shape=(config.num_heads, config.width, config.head_dim),
name=_name("q_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=config.lora_configs.get("attn"),
)
q = q_einsum("BTD,NDH->BTNH", x)
kv_einsum = lora.Einsum(
shape=(2, config.num_kv_heads, config.width, config.head_dim),
name=_name("kv_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=config.lora_configs.get("attn"),
)
k, v = kv_einsum("BSD,2KDH->2BSKH", x)
qkvs.append((q, k, v))
q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))
q = _apply_rope(q, positions=positions)
q *= self.configs[0].head_dim ** -0.5
k = _apply_rope(k, positions=positions)
# should still be half-precision here (if input was half-precision)
assert q.dtype == k.dtype == v.dtype == dtype
if kv_cache is not None:
cache_k, cache_v = kv_cache
k = jnp.concatenate([cache_k, k], axis=1)
v = jnp.concatenate([cache_v, v], axis=1)
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
raise ValueError(
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
)
# big_neg = jnp.finfo(logits.dtype).min
big_neg = -2.3819763e38 # See gemma/modules.py
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
out = []
start = 0
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is not None:
end = start + x.shape[1]
out_einsum = lora.Einsum(
shape=(config.num_heads, config.head_dim, config.width),
name=_name("attn_vec_einsum", i),
init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
lora_config=config.lora_configs.get("attn"),
)
out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
start = end
else:
out.append(None)
return out, (k, v)
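The einsums above implement grouped-query attention: `N = K*G` query heads are reshaped so that `G` query heads share each of the `K` k/v heads. A NumPy sketch of the masked-softmax core under small hypothetical shapes (the module itself uses `jnp`):

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, S, K, G, H = 1, 3, 3, 2, 4, 5  # N = K * G query heads, K kv heads
q = rng.normal(size=(B, T, K * G, H)).astype(np.float32)
k = rng.normal(size=(B, S, K, H)).astype(np.float32)
v = rng.normal(size=(B, S, K, H)).astype(np.float32)

# Split the N query heads into (K, G) so G query heads share each kv head.
q = q.reshape(B, T, K, G, H)
logits = np.einsum("BTKGH,BSKH->BKGTS", q, k)

# Causal mask, broadcast the same way as attn_mask[:, :, None, :, :].
mask = np.tril(np.ones((T, S), dtype=bool))[None, None, None]
masked = np.where(mask, logits, -2.3819763e38)
probs = np.exp(masked - masked.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
out = np.einsum("BKGTS,BSKH->BTKGH", probs, v).reshape(B, T, K * G, H)
assert out.shape == (B, T, K * G, H)
```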
@at.typecheck
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
w_gating = self.param(
"gating_einsum",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
(2, self.features, self.hidden_dim),
).astype(dtype)
ff_gate = jnp.dot(x, w_gating[0])
gate_value = nn.gelu(ff_gate)
ff1 = jnp.dot(x, w_gating[1])
activations = gate_value * ff1
w_linear = self.param(
"linear",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
(self.hidden_dim, self.features),
).astype(dtype)
outputs = jnp.dot(activations, w_linear)
assert outputs.dtype == dtype
return outputs
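The block above is a GeGLU feed-forward: one gating projection passed through GELU, an elementwise product with a second projection, then an output projection. A NumPy mirror under hypothetical small shapes (the tanh formula matches flax `nn.gelu`'s default approximation):

```python
import numpy as np

def gelu(z):
    # Tanh approximation of GELU, matching nn.gelu(approximate=True).
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def geglu(x, w_gating, w_linear):
    # gelu(x @ W_g0) * (x @ W_g1), then project back with the output weight.
    return (gelu(x @ w_gating[0]) * (x @ w_gating[1])) @ w_linear

rng = np.random.default_rng(0)
features, hidden_dim = 4, 8
x = rng.normal(size=(2, 3, features)).astype(np.float32)
w_gating = rng.normal(size=(2, features, hidden_dim)).astype(np.float32)
w_linear = rng.normal(size=(hidden_dim, features)).astype(np.float32)
assert geglu(x, w_gating, w_linear).shape == (2, 3, features)
```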
@at.typecheck
class Block(nn.Module):
"""Transformer block."""
configs: tuple[Config, ...]
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = ()
@nn.compact
def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True): # noqa: FBT002
xs = sharding.activation_sharding_constraint(xs)
drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x
attn = Attention(configs=self.configs, name="attn")
pre_attn = []
gates = []
for i, x in enumerate(xs):
if x is not None:
x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
pre_attn.append(x)
gates.append(gate if x is not None else None)
pre_attn = sharding.activation_sharding_constraint(pre_attn)
post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
post_attn = sharding.activation_sharding_constraint(post_attn)
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
xs = sharding.activation_sharding_constraint(xs)
out = []
gates = []
for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
if x is not None:
x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i]) # noqa: PLW2901
x = lora.FeedForward( # noqa: PLW2901
features=config.width,
hidden_dim=config.mlp_dim,
name=_name("mlp", i),
lora_config=config.lora_configs.get("ffn"),
)(x)
out.append(x)
gates.append(gate if x is not None else None)
out = sharding.activation_sharding_constraint(out)
out = jax.tree.map(lambda x: drop(x, deterministic), out)
xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
xs = sharding.activation_sharding_constraint(xs)
return xs, kv_cache
KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
@at.typecheck
class Module(nn.Module):
"""Transformer model, supporting a mixture of different weights for different tokens."""
configs: Sequence[Config] # list of configs, one for each expert
embed_dtype: str
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
adarms: bool = False
def setup(self):
# all experts must have the same depth
assert all(config.depth == self.configs[0].depth for config in self.configs)
self.embedder = Embedder(
vocab_size=PALIGEMMA_VOCAB_SIZE,
embed_dim=self.configs[0].width, # embedder for first expert only
name="embedder",
)
block_cls = nn.remat(
Block,
prevent_cse=False,
static_argnums=(6,), # 0=self, 6=deterministic
policy=jax.checkpoint_policies.nothing_saveable,
)
self.layers = nn.scan(
block_cls,
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
in_axes=(
0,
nn.broadcast,
nn.broadcast,
nn.broadcast,
nn.broadcast,
), # 0=kv_cache, 1=positions, 2=mask, 3=adarms_cond, 4=deterministic
length=self.configs[0].depth,
)(
configs=self.configs,
dropout=self.dropout,
dropout_bdims=self.dropout_bdims,
)
self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]
@at.typecheck
def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
return self.embedder.encode(tokens).astype(self.embed_dtype)
@at.typecheck
def __call__(
self,
# list of token arrays, one for each expert, or None if that expert should not be run
embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
positions: at.Int[at.Array, "b t"],
mask: at.Bool[at.Array, "b t s"],
adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
*,
kv_cache: KVCache | None = None,
deterministic: bool = True,
) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
mask = jnp.asarray(mask)[:, None, :, :]
if adarms_cond is None:
adarms_cond = [None] * len(self.configs)
embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)
assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)
return [
f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
], kv_cache
def init(self, use_adarms: Sequence[bool]):
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
self(
[jnp.zeros((1, 1, c.width)) for c in self.configs],
jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
)
def _apply_rope(x, *, positions, max_wavelength=10_000):
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
assert radians.dtype == jnp.float32
# radians.shape = [...,L,1,d=D/2]
sin, cos = jnp.sin(radians), jnp.cos(radians)
x1, x2 = jnp.split(x, 2, axis=-1)
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
assert res.dtype == jnp.float32
# The original bigvision impl allows RoPE to upcast to float32. It is then immediately downcast again to the cache
# dtype when in inference mode (but not in training mode). I don't think any of this was intentional. Based on the
# original DeepMind impl, as well as the widely-used transformers impl, it is ok to always downcast back to bfloat16
# here.
return res.astype(x.dtype)
def _name(name, i):
# we name layers like this because we want the first expert's weights to have no suffix (e.g., "attn"), so that they
# can be loaded seamlessly from the existing PaliGemma checkpoint. subsequent experts will have a suffix (e.g.,
# "attn_1") and their weights will be initialized from scratch. in practice, we only use two experts -- PaliGemma,
# and the action expert.
if i == 0:
return name
return f"{name}_{i}"
def _gated_residual(x, y, gate):
assert (x is None) == (y is None)
if x is None:
return None
if gate is None:
return x + y
return x + y * gate


@@ -0,0 +1,437 @@
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gemma model implementation from big_vision/models/ppp/gemma.py (with small modifications for NNX compatibility)
Used for FAST autoregressive policies.
"""
import dataclasses
from typing import Literal, TypeAlias
import einops
import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import openpi.models.lora as lora
import openpi.shared.array_typing as at
Variant = Literal["gemma_2b", "gemma_2b_lora"]
def get_config(variant):
"""Returns config for specified gemma variant."""
if variant == "gemma_2b":
return ml_collections.ConfigDict(
{
"variant": variant,
"width": 2048,
"depth": 18,
"mlp_dim": 16_384,
"num_heads": 8,
"num_kv_heads": 1,
"head_dim": 256,
"norm_eps": 1e-6,
"vocab_size": 257_152,
"scan": True,
"remat_policy": "nothing_saveable",
}
)
if variant == "gemma_2b_lora":
return ml_collections.ConfigDict(
{
"variant": variant,
"width": 2048,
"depth": 18,
"mlp_dim": 16_384,
"num_heads": 8,
"num_kv_heads": 1,
"head_dim": 256,
"norm_eps": 1e-6,
"vocab_size": 257_152,
"scan": True,
"remat_policy": "nothing_saveable",
"lora_configs": {
"attn": lora.LoRAConfig(rank=16, alpha=16.0),
"ffn": lora.LoRAConfig(rank=16, alpha=16.0),
},
}
)
raise ValueError(f"Unknown variant: {variant}")
@at.typecheck
class Einsum(nn.Module):
shape: tuple[int, ...]
@nn.compact
def __call__(self, eqn, x):
dtype = x.dtype # original dtype, could be half-precision
w = self.param("w", nn.initializers.zeros_init(), self.shape).astype(dtype)
return jnp.einsum(eqn, x, w)
@at.typecheck
class RMSNorm(nn.Module):
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True) # compute variance in float32
normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06))) # compute normalization in float32
normed_inputs = normed_inputs * (
1 + scale
) # scale by learned parameter in float32 (matches Flax implementation)
return normed_inputs.astype(dtype) # return in original dtype
@at.typecheck
class Embedder(nn.Module):
"""Embedder module."""
vocab_size: int
embed_dim: int
def setup(self):
self.input_embedding_table = self.param(
"input_embedding",
nn.initializers.zeros_init(),
(self.vocab_size, self.embed_dim),
)
def encode(self, x):
x = self.input_embedding_table[(x,)]
x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
return x
def decode(self, x):
return jnp.dot(x, self.input_embedding_table.T)
@at.typecheck
class Attention(nn.Module):
"""Attention module."""
num_heads: int
num_kv_heads: int
features: int
head_dim: int
cache_dtype: str | None = None
lora_config: lora.LoRAConfig | None = None
def setup(self):
if self.num_kv_heads == self.num_heads:
self.qkv_einsum = lora.Einsum(
shape=(3, self.num_heads, self.features, self.head_dim),
name="qkv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
else:
self.q_einsum = lora.Einsum(
shape=(self.num_heads, self.features, self.head_dim),
name="q_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
self.kv_einsum = lora.Einsum(
shape=(2, self.num_kv_heads, self.features, self.head_dim),
name="kv_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
lora_config=self.lora_config,
)
self.attn_vec_einsum = lora.Einsum(
shape=(self.num_heads, self.head_dim, self.features),
name="attn_vec_einsum",
init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
lora_config=self.lora_config,
)
def _init_cache(self, k, v, cache_size):
"""Initialize KV cache"""
prefill_len = k.shape[1]
pad_width = ((0, 0), (0, cache_size - prefill_len), (0, 0), (0, 0))
cache_dtype = self.cache_dtype or k.dtype
k_cache = jnp.pad(k.astype(cache_dtype), pad_width)
v_cache = jnp.pad(v.astype(cache_dtype), pad_width)
idx = jnp.zeros((k.shape[0],), dtype=jnp.int32) + prefill_len
return idx, k_cache, v_cache
def _update_cache(self, k, v, idx, k_cache, v_cache):
"""Update KV cache with new values"""
assert k.shape[1] == 1, "Only supports kv-cache updates of length 1"
indices = (0, idx[0], 0, 0)
cache_dtype = self.cache_dtype or k.dtype
k_new = jax.lax.dynamic_update_slice(k_cache, k.astype(cache_dtype), indices)
v_new = jax.lax.dynamic_update_slice(v_cache, v.astype(cache_dtype), indices)
idx_new = idx + 1
return idx_new, k_new, v_new
@nn.compact
def __call__(self, x, positions, attn_mask, kv_cache, decode, deterministic=True): # noqa: FBT002
dtype = x.dtype # original dtype, could be half-precision
if self.num_kv_heads == self.num_heads:
q, k, v = self.qkv_einsum("BSD,3KDH->3BSKH", x)
else:
q = self.q_einsum("BTD,NDH->BTNH", x)
k, v = self.kv_einsum("BSD,2KDH->2BSKH", x)
q = _apply_rope(q, positions=positions) # promotes to float32
q *= self.head_dim**-0.5
k = _apply_rope(k, positions=positions) # promotes to float32
if kv_cache is None:
idx, k_cache, v_cache = self._init_cache(k, v, attn_mask.shape[-1])
else:
idx, k_cache, v_cache = kv_cache
idx, k_cache, v_cache = self._update_cache(k, v, idx, k_cache, v_cache)
k, v = k_cache, v_cache
kv_cache = (idx, k_cache, v_cache)
q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.num_kv_heads)
logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)
if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
raise ValueError(
f"Attention mask with shape {attn_mask.shape} but shapes for q and k are: {q.shape} and {k.shape}"
)
# big_neg = jnp.finfo(logits.dtype).min
big_neg = -2.3819763e38 # See gemma/modules.py
masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)
probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)
encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")
return self.attn_vec_einsum("BTNH,NHD->BTD", encoded), kv_cache
@at.typecheck
class Block(nn.Module):
"""Transformer block."""
num_heads: int
num_kv_heads: int
embed_dim: int
head_dim: int
hidden_dim: int
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = ()
cache_dtype: str | None = None
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
def setup(self):
self.pre_attention_norm = RMSNorm()
self.attn = Attention(
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
features=self.embed_dim,
head_dim=self.head_dim,
cache_dtype=self.cache_dtype,
lora_config=self.lora_configs.get("attn"),
)
self.pre_ffw_norm = RMSNorm()
self.mlp = lora.FeedForward(
features=self.embed_dim, hidden_dim=self.hidden_dim, name="mlp", lora_config=self.lora_configs.get("ffn")
)
if self.dropout:
self.drop = nn.Dropout(self.dropout, self.dropout_bdims)
else:
self.drop = lambda x, _: x
def __call__(self, x, kv_cache, positions, attn_mask, decode, deterministic=True): # noqa: FBT002
x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
inputs_normalized = self.pre_attention_norm(x)
attn_output, kv_cache = self.attn(inputs_normalized, positions, attn_mask, kv_cache, decode, deterministic)
attn_output = self.drop(attn_output, deterministic)
attn_output += x
residual = attn_output
attn_output = self.pre_ffw_norm(attn_output)
outputs = self.mlp(attn_output)
outputs = self.drop(outputs, deterministic)
outputs = residual + outputs
return outputs, kv_cache
KVCache: TypeAlias = tuple[at.Int[at.Array, " b"], at.Float[at.Array, "b _t _k _h"], at.Float[at.Array, "b _t _v _h"]]
@at.typecheck
class Module(nn.Module):
"""gemma model."""
variant: str
width: int
depth: int
mlp_dim: int
num_heads: int
num_kv_heads: int
head_dim: int
norm_eps: float
vocab_size: int
embed_dtype: str
dropout: float = 0.0
dropout_bdims: tuple[int, ...] = () # Every float is dropped independently.
cache_dtype: str | None = None
scan: bool = False
remat_policy: str = "none"
lora_configs: ml_collections.ConfigDict = dataclasses.field(default_factory=ml_collections.ConfigDict)
@nn.compact
def __call__(
self,
tokens=None,
embedded_prefix=None,
embed_only=False, # noqa: FBT002
pre_logits=None,
positions=None,
mask=None,
decode=False, # noqa: FBT002
kv_cache=None,
deterministic=True, # noqa: FBT002
return_prelogits=False, # noqa: FBT002
):
"""Embed only, or complete forward pass.
Args:
tokens: Embedded and then appended to `embedded_prefix`. Can be None.
embedded_prefix: Optional prefix that is already embedded.
embed_only: Whether to compute embeddings only.
pre_logits: If present computes logits from pre_logits and returns.
positions: Optional `[B, T]` array specifying the absolute positions of the tokens.
mask: Optional attention mask `[B, T, S]`.
decode: Whether to use kv-cache. Caller must pass masks and positions.
deterministic: Forwarded to all dropout layers.
return_prelogits: Whether to return the pre-logits.
Returns:
If `embed_only=True`, the embeddings are returned.
If `pre_logits` is given, `(logits, out)` is returned.
If `return_prelogits=True`, `(pre_logits, kv_cache, out)` is returned.
Otherwise, `(logits, kv_cache, out)` is returned.
"""
out = {}
embedder = Embedder(vocab_size=self.vocab_size, embed_dim=self.width, name="embedder")
if pre_logits is not None:
x = out["pre_logits"] = pre_logits
logits = out["logits"] = embedder.decode(x)
return logits, out
x = []
if embedded_prefix is not None:
x.append(embedded_prefix)
if tokens is not None:
x.append(embedder.encode(tokens))
x = jnp.concatenate(x, axis=-2)
x = x.astype(self.embed_dtype)
batch_size, seq_len, width = x.shape
if embed_only:
return x
if decode:
assert positions is not None and mask is not None, ( # noqa: PT018
"Must explicitly pass positions and mask for decoding."
)
if positions is None:
positions = jnp.arange(seq_len).astype(jnp.int32)[None, :]
assert positions.shape[1] == x.shape[1], (positions.shape, x.shape)
if mask is None:
mask = nn.attention.make_causal_mask(jnp.ones([batch_size, seq_len]))
if mask.ndim == 3:
mask = mask[:, None, :, :]
cache_size = max(seq_len, mask.shape[-1])
assert mask.shape == (batch_size, 1, seq_len, cache_size), mask.shape
if self.remat_policy == "none":
block_cls = Block
else:
block_cls = nn.remat(
Block,
prevent_cse=not self.scan,
static_argnums=(5, 6), # 0=self, 5=decode, 6=deterministic
policy=getattr(jax.checkpoint_policies, self.remat_policy),
)
block_kw = {
"num_heads": self.num_heads,
"head_dim": self.head_dim,
"num_kv_heads": self.num_kv_heads,
"embed_dim": width,
"hidden_dim": self.mlp_dim,
"dropout": self.dropout,
"dropout_bdims": self.dropout_bdims,
"cache_dtype": self.cache_dtype,
"lora_configs": self.lora_configs,
}
layers = self.scope.push("layers")
blocks = [
nn.scan(
block_cls,
variable_axes={"params": 0},
split_rngs={"params": True, "dropout": True},
in_axes=(0, nn.broadcast, nn.broadcast, nn.broadcast, nn.broadcast), # 0=kv_cache, 1=positions, 2=mask, 3=decode, 4=deterministic
length=self.depth,
)(parent=layers, **block_kw)
]
for block in blocks:
x, kv_cache = block(x, kv_cache, positions, mask, decode, deterministic)
assert x.dtype == jnp.dtype(self.embed_dtype) # Sanity check.
out["encoded"] = x
x = RMSNorm(name="final_norm")(x)
out["pre_logits"] = x
if return_prelogits:
return x, kv_cache, out
x = embedder.decode(x)
out["logits"] = x
return x, kv_cache, out
def init(self):
"""Convenience method for initializing all parameters, necessary due to the quirks of linen."""
self(jnp.zeros((1, 1), dtype=jnp.int32))
def _apply_rope(x, *, positions, max_wavelength=10_000):
"""Applies RoPE positions [B, L] to x [B, L, H, D]."""
freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
timescale = max_wavelength**freq_exponents
radians = positions[..., None] / timescale[None, None, :]
radians = radians[..., None, :]
assert radians.dtype == jnp.float32
# radians.shape = [...,L,1,d=D/2]
sin, cos = jnp.sin(radians), jnp.cos(radians)
x1, x2 = jnp.split(x, 2, axis=-1)
res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
assert res.dtype == jnp.float32
return res
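A quick NumPy mirror of `_apply_rope` (illustrative only) confirms two properties of the rotation: position 0 is the identity, and vector norms are preserved:

```python
import numpy as np

def apply_rope_np(x, positions, max_wavelength=10_000):
    # Rotate each pair (x1[d], x2[d]) by an angle set by the token position
    # and the per-dimension timescale, mirroring _apply_rope above.
    d = x.shape[-1]
    freq_exponents = (2.0 / d) * np.arange(d // 2, dtype=np.float32)
    timescale = max_wavelength ** freq_exponents
    radians = (positions[..., None] / timescale[None, None, :])[..., None, :]
    sin, cos = np.sin(radians), np.cos(radians)
    x1, x2 = np.split(x, 2, axis=-1)
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

x = np.random.default_rng(0).normal(size=(1, 4, 2, 8)).astype(np.float32)  # [B,L,H,D]
positions = np.arange(4, dtype=np.float32)[None, :]
out = apply_rope_np(x, positions)
assert np.allclose(out[:, 0], x[:, 0], atol=1e-6)  # zero angle at position 0
assert np.allclose(np.linalg.norm(out, axis=-1), np.linalg.norm(x, axis=-1), atol=1e-4)
```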


@@ -0,0 +1,148 @@
import math
import re
import flax.linen as nn
import flax.struct as struct
import jax.numpy as jnp
import openpi.shared.array_typing as at
@struct.dataclass
class LoRAConfig:
"""Configuration for LoRA."""
# LoRA rank.
rank: int
# LoRA scaling factor.
alpha: float = 1.0
# Initialization function for LoRA parameters.
init_fn: nn.initializers.Initializer = nn.initializers.normal(stddev=0.01)
# Enable rank-stabilized LoRA: https://arxiv.org/pdf/2312.03732
rslora: bool = False
# Axes in the weight to apply LoRA to. Should typically be the last two axes.
axes: tuple[int, int] = (-2, -1)
# Axis label which is used by LoRA in einsum equations. Must not be present in the original equation.
label: str = "L"
@property
def scaling_value(self) -> float:
return self.alpha / math.sqrt(self.rank) if self.rslora else self.alpha / self.rank
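Standard LoRA scales the low-rank update by `alpha / rank`, while rank-stabilized LoRA (rsLoRA) uses `alpha / sqrt(rank)` so the update magnitude does not shrink as the rank grows. A standalone sketch of `scaling_value` with the `alpha=16, rank=16` setting used in the configs above:

```python
import math

def lora_scaling(alpha: float, rank: int, rslora: bool = False) -> float:
    # Mirrors LoRAConfig.scaling_value.
    return alpha / math.sqrt(rank) if rslora else alpha / rank

print(lora_scaling(16.0, 16))               # 1.0
print(lora_scaling(16.0, 16, rslora=True))  # 4.0
```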
class Einsum(nn.Module):
"""Einsum with LoRA support. Can be used as a drop-in replacement for the Gemma Einsum."""
# Shape of the weight.
shape: tuple[int, ...]
# Initialization function for the weight.
init_fn: nn.initializers.Initializer = nn.initializers.zeros
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None
def setup(self):
self.w = self.param("w", self.init_fn, self.shape)
if config := self.lora_config:
# Setup LoRA parameters.
shape_a, shape_b = list(self.shape), list(self.shape)
shape_a[config.axes[1]] = config.rank
shape_b[config.axes[0]] = config.rank
self.w_a = self.param("lora_a", config.init_fn, shape_a)
self.w_b = self.param("lora_b", config.init_fn, shape_b)
@nn.compact
def __call__(self, eqn: str, x):
dtype = x.dtype # original dtype, could be half-precision
result = jnp.einsum(eqn, x, self.w.astype(dtype))
if config := self.lora_config:
eqn_a, eqn_b = self._make_lora_eqns(eqn)
lora = jnp.einsum(eqn_a, x, self.w_a.astype(dtype))
lora = jnp.einsum(eqn_b, lora, self.w_b.astype(dtype))
result = result + lora * config.scaling_value
return result
def _make_lora_eqns(self, eqn: str) -> tuple[str, str]:
if "L" in eqn:
raise ValueError(f"L already in eqn: {eqn}")
if not (m := re.match("(.*),(.*)->(.*)", eqn)):
raise ValueError(f"Unsupported einsum eqn: {eqn}")
lhs, rhs, out = m.groups()
assert self.lora_config is not None
a_label, b_label = (rhs[x] for x in self.lora_config.axes)
label = self.lora_config.label
a_rhs = rhs.replace(b_label, label)
a_out = out.replace(b_label, label)
eqn_a = f"{lhs},{a_rhs}->{a_out}"
b_rhs = rhs.replace(a_label, label)
eqn_b = f"{a_out},{b_rhs}->{out}"
return eqn_a, eqn_b
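`_make_lora_eqns` factors the weight axes named by `config.axes` through the rank label `L`. A NumPy check with a simplified equation (hypothetical shapes) shows the two generated einsums are equivalent to contracting with the merged low-rank delta `w_a @ w_b`:

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, D, K, H, R = 2, 5, 6, 3, 4, 2
x = rng.normal(size=(B, S, D))
w_a = rng.normal(size=(K, D, R))  # "lora_a": output axis H replaced by rank L
w_b = rng.normal(size=(K, R, H))  # "lora_b": input axis D replaced by rank L

# eqn = "BSD,KDH->BSKH" decomposes into:
#   eqn_a = "BSD,KDL->BSKL"   (project the input down to rank R)
#   eqn_b = "BSKL,KLH->BSKH"  (expand back up to the head dimension)
low_rank = np.einsum("BSKL,KLH->BSKH", np.einsum("BSD,KDL->BSKL", x, w_a), w_b)

# Same result as contracting with the merged per-head delta w_a @ w_b.
merged = np.einsum("BSD,KDH->BSKH", x, np.einsum("KDL,KLH->KDH", w_a, w_b))
assert np.allclose(low_rank, merged)
```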
class FeedForward(nn.Module):
"""Feed forward module."""
features: int
hidden_dim: int
# If not None, apply LoRA to the weight.
lora_config: LoRAConfig | None = None
def setup(self):
self.w_gating = self.param(
"gating_einsum",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
(2, self.features, self.hidden_dim),
)
self.w_linear = self.param(
"linear",
nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
(self.hidden_dim, self.features),
)
self.w_gating_lora = None
self.w_linear_lora = None
if self.lora_config:
# Setup LoRA parameters.
# TODO: follow up with a simplified init_fn api.
self.w_gating_lora = (
self.param("gating_einsum_lora_a", self.lora_config.init_fn, (2, self.features, self.lora_config.rank)),
self.param(
"gating_einsum_lora_b", self.lora_config.init_fn, (2, self.lora_config.rank, self.hidden_dim)
),
)
self.w_linear_lora = (
self.param("linear_lora_a", self.lora_config.init_fn, (self.hidden_dim, self.lora_config.rank)),
self.param("linear_lora_b", self.lora_config.init_fn, (self.lora_config.rank, self.features)),
)
@nn.compact
def __call__(self, x):
dtype = x.dtype # original dtype, could be half-precision
ff_gate = self._dot(
x,
self.w_gating[0],
None if self.w_gating_lora is None else (self.w_gating_lora[0][0], self.w_gating_lora[1][0]),
)
gate_value = nn.gelu(ff_gate)
ff1 = self._dot(
x,
self.w_gating[1],
None if self.w_gating_lora is None else (self.w_gating_lora[0][1], self.w_gating_lora[1][1]),
)
activations = gate_value * ff1
outputs = self._dot(activations, self.w_linear, self.w_linear_lora)
assert outputs.dtype == dtype
return outputs
def _dot(self, x: at.Array, w: at.Array, lora_weights: tuple[at.Array, at.Array] | None) -> at.Array:
base = jnp.dot(x, w.astype(x.dtype))
if lora_weights is None:
return base
return base + jnp.dot(jnp.dot(x, lora_weights[0].astype(x.dtype)), lora_weights[1].astype(x.dtype))


@@ -0,0 +1,94 @@
import flax.linen as nn
import jax
import jax.numpy as jnp
import openpi.models.lora as lora
def test_lora_einsum_params_shape():
shape = (3, 8, 32, 4) # (3KDH)
einsum = lora.Einsum(shape)
lora0 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2))
lora1 = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, axes=(1, 2)))
key = jax.random.key(0)
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
eqn = "BSD,3KDH->3BSKH"
# Ensure that lora parameters are not initialized when LoRA is not used.
params = einsum.init(key, eqn, x)
assert "lora_a" not in params["params"]
assert "lora_b" not in params["params"]
# Check that default axes work.
params_lora0 = lora0.init(key, eqn, x)
assert params_lora0["params"]["lora_a"].shape == (3, 8, 32, 2)
assert params_lora0["params"]["lora_b"].shape == (3, 8, 2, 4)
# Check that user provided axes work.
params_lora1 = lora1.init(key, eqn, x)
assert params_lora1["params"]["lora_a"].shape == (3, 8, 2, 4)
assert params_lora1["params"]["lora_b"].shape == (3, 2, 32, 4)
def test_lora_einsum_same_output():
shape = (3, 8, 32, 4) # (3KDH)
einsum = lora.Einsum(shape)
einsum_lora = lora.Einsum(shape, lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros))
key = jax.random.key(0)
x = jax.random.normal(key, (8, 64, 32)) # (BSD)
eqn = "BSD,3KDH->3BSKH"
params = einsum.init(key, eqn, x)
output = einsum.apply(params, eqn, x)
params_lora = einsum_lora.init(key, eqn, x)
output_lora = einsum_lora.apply(params_lora, eqn, x)
# Results are the same since the LoRA parameters are initialized to zeros.
assert jnp.allclose(output, output_lora)
def test_lora_ffn_params_shape():
ffn = lora.FeedForward(features=8, hidden_dim=32)
ffn_lora = lora.FeedForward(
features=8,
hidden_dim=32,
lora_config=lora.LoRAConfig(rank=2),
)
key = jax.random.key(0)
x = jax.random.normal(key, (2, 8))
params = ffn.init(key, x)
assert params["params"]["gating_einsum"].shape == (2, 8, 32)
assert params["params"]["linear"].shape == (32, 8)
params_lora = ffn_lora.init(key, x)
assert params_lora["params"]["gating_einsum"].shape == (2, 8, 32)
assert params_lora["params"]["linear"].shape == (32, 8)
assert params_lora["params"]["gating_einsum_lora_a"].shape == (2, 8, 2)
assert params_lora["params"]["gating_einsum_lora_b"].shape == (2, 2, 32)
assert params_lora["params"]["linear_lora_a"].shape == (32, 2)
assert params_lora["params"]["linear_lora_b"].shape == (2, 8)
def test_lora_ffn_same_output():
ffn = lora.FeedForward(features=8, hidden_dim=32)
ffn_lora = lora.FeedForward(
features=8,
hidden_dim=32,
lora_config=lora.LoRAConfig(rank=2, init_fn=nn.initializers.zeros),
)
key = jax.random.key(0)
x = jax.random.normal(key, (2, 8))
params = ffn.init(key, x)
output = ffn.apply(params, x)
params_lora = ffn_lora.init(key, x)
output_lora = ffn_lora.apply(params_lora, x)
assert jnp.allclose(output, output_lora)
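The zero-init equivalence that `test_lora_einsum_same_output` and `test_lora_ffn_same_output` rely on can be sketched in plain numpy (toy shapes mirroring the FFN test above; the names here are illustrative, not openpi's API):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy shapes mirroring the FFN test: a (32, 8) "linear" weight with
# rank-2 adapters lora_a (32, 2) and lora_b (2, 8).
d_in, d_out, rank = 32, 8, 2
w = rng.normal(size=(d_in, d_out)).astype(np.float32)
lora_a = rng.normal(size=(d_in, rank)).astype(np.float32)
lora_b = np.zeros((rank, d_out), dtype=np.float32)  # zero init, as in the tests

x = rng.normal(size=(4, d_in)).astype(np.float32)

base = x @ w
adapted = x @ w + (x @ lora_a) @ lora_b  # low-rank update: W + A @ B

# With lora_b zero-initialized, the adapter contributes nothing, which is
# exactly the equivalence the tests above assert.
assert np.allclose(base, adapted)
```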


@@ -0,0 +1,332 @@
import abc
from collections.abc import Sequence
import dataclasses
import enum
import logging
import pathlib
from typing import Generic, TypeVar
import augmax
from flax import nnx
from flax import struct
from flax import traverse_util
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
from openpi.models_pytorch import pi0_pytorch
from openpi.shared import image_tools
import openpi.shared.array_typing as at
logger = logging.getLogger("openpi")
# Type variable for array types (JAX arrays, PyTorch tensors, or numpy arrays)
ArrayT = TypeVar("ArrayT", bound=jax.Array | torch.Tensor | np.ndarray)
class ModelType(enum.Enum):
"""Supported model types."""
PI0 = "pi0"
PI0_FAST = "pi0_fast"
PI05 = "pi05"
# The model always expects these images
IMAGE_KEYS = (
"base_0_rgb",
"left_wrist_0_rgb",
"right_wrist_0_rgb",
)
# This may need to change if we release a small model.
IMAGE_RESOLUTION = (224, 224)
# Data format
#
# Data transforms produce the model input as a nested dictionary which is later converted
# into `Observation` and `Actions` objects. See below.
#
# In dictionary form, this data should look like:
# {
# # Observation data.
# "image": {
# "base_0_rgb": (float32|uint8)[*b, h, w, 3], # RGB image in [-1, 1] or [0, 255]
# ... # Additional camera views
# },
# "image_mask": {
# "base_0_rgb": bool[*b], # True if image is valid
# ... # Masks for additional views
# },
# "state": float32[*b, s], # Low-dimensional robot state
# "tokenized_prompt": int32[*b, l], # Optional, tokenized language prompt
# "tokenized_prompt_mask": bool[*b, l], # Optional, mask for tokenized prompt
# "token_ar_mask": int32[*b, l], # Optional, autoregressive mask for FAST model
# "token_loss_mask": bool[*b, l], # Optional, loss mask for FAST model
#
# # Actions data.
#     "actions": float32[*b, ah, ad]
# }
# where:
#   *b = batch dimensions
#   h, w = image height/width
#   s = state dimension
#   l = sequence length
#   ah = action horizon (number of future action steps)
#   ad = action dimension
#
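# A minimal sketch of this dictionary, built with toy numpy arrays (the
# batch size, state dimension, and prompt length below are hypothetical),
# including the uint8-to-[-1, 1] rescaling that `Observation.from_dict` performs:

```python
import numpy as np

batch, h, w, s, l = 2, 224, 224, 7, 48  # hypothetical sizes

data = {
    "image": {"base_0_rgb": np.full((batch, h, w, 3), 128, dtype=np.uint8)},
    "image_mask": {"base_0_rgb": np.ones((batch,), dtype=bool)},
    "state": np.zeros((batch, s), dtype=np.float32),
    "tokenized_prompt": np.zeros((batch, l), dtype=np.int32),
    "tokenized_prompt_mask": np.ones((batch, l), dtype=bool),
}

# The same uint8 -> [-1, 1] float32 rescaling applied by Observation.from_dict.
img = data["image"]["base_0_rgb"].astype(np.float32) / 255.0 * 2.0 - 1.0
assert img.shape == (batch, h, w, 3)
assert -1.0 <= img.min() and img.max() <= 1.0
```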
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
"""Holds observations, i.e., inputs to the model.
See `Observation.from_dict` to see the expected dictionary form. This is the format
that should be produced by the data transforms.
"""
# Images, in [-1, 1] float32.
images: dict[str, at.Float[ArrayT, "*b h w c"]]
# Image masks, with same keys as images.
image_masks: dict[str, at.Bool[ArrayT, "*b"]]
# Low-dimensional robot state.
state: at.Float[ArrayT, "*b s"]
# Tokenized prompt.
tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
# Tokenized prompt mask.
tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None
# pi0-fast model specific fields.
# Token auto-regressive mask (for FAST autoregressive model).
token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
# Token loss mask (for FAST autoregressive model).
token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None
@classmethod
def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
"""This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format."""
# Ensure that tokenized_prompt and tokenized_prompt_mask are provided together.
if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")
# If images are uint8, convert them to [-1, 1] float32.
for key in data["image"]:
if data["image"][key].dtype == np.uint8:
data["image"][key] = data["image"][key].astype(np.float32) / 255.0 * 2.0 - 1.0
elif hasattr(data["image"][key], "dtype") and data["image"][key].dtype == torch.uint8:
data["image"][key] = data["image"][key].to(torch.float32).permute(0, 3, 1, 2) / 255.0 * 2.0 - 1.0
return cls(
images=data["image"],
image_masks=data["image_mask"],
state=data["state"],
tokenized_prompt=data.get("tokenized_prompt"),
tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
token_ar_mask=data.get("token_ar_mask"),
token_loss_mask=data.get("token_loss_mask"),
)
def to_dict(self) -> at.PyTree[ArrayT]:
"""Convert the Observation to a nested dict."""
result = dataclasses.asdict(self)
result["image"] = result.pop("images")
result["image_mask"] = result.pop("image_masks")
return result
# Defines the format of the actions. This field is included as "actions" inside the dictionary
# produced by the data transforms.
Actions = at.Float[ArrayT, "*b ah ad"]
def preprocess_observation(
rng: at.KeyArrayLike | None,
observation: Observation,
*,
train: bool = False,
image_keys: Sequence[str] = IMAGE_KEYS,
image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
"""Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
filling in a default image mask (if necessary).
"""
if not set(image_keys).issubset(observation.images):
raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
batch_shape = observation.state.shape[:-1]
out_images = {}
for key in image_keys:
image = observation.images[key]
if image.shape[1:3] != image_resolution:
logger.info(f"Resizing image {key} from {image.shape[1:3]} to {image_resolution}")
image = image_tools.resize_with_pad(image, *image_resolution)
if train:
# Convert from [-1, 1] to [0, 1] for augmax.
image = image / 2.0 + 0.5
transforms = []
if "wrist" not in key:
height, width = image.shape[1:3]
transforms += [
augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
augmax.Resize(width, height),
augmax.Rotate((-5, 5)),
]
transforms += [
augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5),
]
sub_rngs = jax.random.split(rng, image.shape[0])
image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)
# Back to [-1, 1].
image = image * 2.0 - 1.0
out_images[key] = image
# Obtain image masks, filling in defaults for missing keys.
out_masks = {}
for key in out_images:
if key not in observation.image_masks:
# do not mask by default
out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)
else:
out_masks[key] = jnp.asarray(observation.image_masks[key])
return Observation(
images=out_images,
image_masks=out_masks,
state=observation.state,
tokenized_prompt=observation.tokenized_prompt,
tokenized_prompt_mask=observation.tokenized_prompt_mask,
token_ar_mask=observation.token_ar_mask,
token_loss_mask=observation.token_loss_mask,
)
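The resize step above delegates to `image_tools.resize_with_pad`. A rough pure-numpy stand-in (nearest-neighbor interpolation and zero padding are assumptions here; the real implementation may differ in interpolation and pad value) looks like:

```python
import numpy as np

def resize_with_pad_sketch(img: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Aspect-preserving nearest-neighbor resize followed by centered zero padding.

    A rough stand-in for openpi's image_tools.resize_with_pad; interpolation
    and padding behavior are assumptions, not the library's exact semantics.
    """
    h, w = img.shape[:2]
    scale = min(target_h / h, target_w / w)
    new_h, new_w = int(h * scale), int(w * scale)
    # Nearest-neighbor index maps for rows and columns.
    rows = (np.arange(new_h) / scale).astype(int).clip(max=h - 1)
    cols = (np.arange(new_w) / scale).astype(int).clip(max=w - 1)
    resized = img[rows][:, cols]
    out = np.zeros((target_h, target_w, img.shape[2]), dtype=img.dtype)
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    out[top : top + new_h, left : left + new_w] = resized
    return out

frame = np.ones((120, 160, 3), dtype=np.float32)
padded = resize_with_pad_sketch(frame, 224, 224)
assert padded.shape == (224, 224, 3)
```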
@dataclasses.dataclass(frozen=True)
class BaseModelConfig(abc.ABC):
"""Configuration shared by all models. Specific models should inherit from this class, and implement the `create`
method to create the corresponding model.
"""
# Action space dimension.
action_dim: int
# Action sequence length.
action_horizon: int
# Tokenized prompt maximum length.
max_token_len: int
@property
@abc.abstractmethod
def model_type(self) -> ModelType:
"""The model type."""
@abc.abstractmethod
def create(self, rng: at.KeyArrayLike) -> "BaseModel":
"""Create a new model, initializing parameters."""
def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
"""Create a model with the given parameters."""
model = nnx.eval_shape(self.create, jax.random.key(0))
graphdef, state = nnx.split(model)
if remove_extra_params:
params = ocp.transform_utils.intersect_trees(state.to_pure_dict(), params)
at.check_pytree_equality(expected=state.to_pure_dict(), got=params, check_shapes=True, check_dtypes=False)
state.replace_by_pure_dict(params)
return nnx.merge(graphdef, state)
def load_pytorch(self, train_config, weight_path: str):
logger.info(f"train_config: {train_config}")
model = pi0_pytorch.PI0Pytorch(config=train_config.model)
safetensors.torch.load_model(model, weight_path)
return model
@abc.abstractmethod
def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
"""Returns the input specification for the model. Values are jax.ShapeDtypeStruct."""
def fake_obs(self, batch_size: int = 1) -> Observation:
observation_spec, _ = self.inputs_spec(batch_size=batch_size)
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), observation_spec)
def fake_act(self, batch_size: int = 1) -> Actions:
_, action_spec = self.inputs_spec(batch_size=batch_size)
return jax.tree.map(lambda x: jnp.ones(x.shape, x.dtype), action_spec)
@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
"""Base class for all model implementations. Specific models should inherit from this class. They should call
super().__init__() to initialize the shared attributes (action_dim, action_horizon, and max_token_len).
"""
action_dim: int
action_horizon: int
max_token_len: int
@abc.abstractmethod
def compute_loss(
self,
rng: at.KeyArrayLike,
observation: Observation,
actions: Actions,
*,
train: bool = False,
) -> at.Float[at.Array, "*b ah"]: ...
@abc.abstractmethod
def sample_actions(self, rng: at.KeyArrayLike, observation: Observation, **kwargs) -> Actions: ...
def restore_params(
params_path: pathlib.Path | str,
*,
restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
dtype: jnp.dtype | None = None,
sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
"""Restores unstructured params PyTree from a checkpoint.
This works with checkpoints saved with `save_state` during openpi training (see `training/checkpoints.py`) as
well as pre-trained checkpoints released for openpi.
Args:
params_path: The local path to the checkpoint directory.
restore_type: The type to restore the params as. Can be set to `np.ndarray` to load the params as a numpy array.
dtype: The dtype to restore all params as. If not provided, will use the original dtype from the checkpoint.
sharding: The sharding to use for the params. If not provided, the params will be replicated across all devices.
Returns:
The restored params.
"""
params_path = pathlib.Path(params_path).resolve() if not str(params_path).startswith("gs://") else params_path
if restore_type is jax.Array and sharding is None:
mesh = jax.sharding.Mesh(jax.devices(), ("x",))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
with ocp.PyTreeCheckpointer() as ckptr:
metadata = ckptr.metadata(params_path)
item = {"params": metadata["params"]}
params = ckptr.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree.map(
lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype), item
),
),
)["params"]
# If the params were saved with `save_state` during openpi training, every key path will end with "value", which is
# added by `nnx.State`. We remove the "value" suffix here and always return what NNX calls a "pure dict".
flat_params = traverse_util.flatten_dict(params)
if all(kp[-1] == "value" for kp in flat_params):
flat_params = {kp[:-1]: v for kp, v in flat_params.items()}
return traverse_util.unflatten_dict(flat_params)
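The `"value"`-suffix stripping at the end of `restore_params` can be illustrated with toy stand-ins for flax's `traverse_util` (the `flatten`/`unflatten` helpers below are simplified hypotheticals, not the flax implementations):

```python
def flatten(d: dict, prefix: tuple = ()) -> dict:
    """Flatten a nested dict into {key_path_tuple: leaf}."""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten(v, (*prefix, k)))
        else:
            out[(*prefix, k)] = v
    return out

def unflatten(flat: dict) -> dict:
    """Inverse of flatten: rebuild the nested dict from key-path tuples."""
    out = {}
    for kp, v in flat.items():
        node = out
        for k in kp[:-1]:
            node = node.setdefault(k, {})
        node[kp[-1]] = v
    return out

# nnx.State wraps every leaf under a trailing "value" key; restore_params
# strips that suffix to return a "pure dict".
params = {"encoder": {"kernel": {"value": 1.0}, "bias": {"value": 2.0}}}
flat = flatten(params)
if all(kp[-1] == "value" for kp in flat):
    flat = {kp[:-1]: v for kp, v in flat.items()}
pure = unflatten(flat)
assert pure == {"encoder": {"kernel": 1.0, "bias": 2.0}}
```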

View File

@@ -0,0 +1,94 @@
from flax import nnx
import jax
import pytest
from openpi.models import model as _model
from openpi.models import pi0_config
from openpi.models import pi0_fast
from openpi.shared import download
from openpi.shared import nnx_utils
def test_pi0_model():
key = jax.random.key(0)
config = pi0_config.Pi0Config()
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size, config.action_horizon)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
def test_pi0_lora_model():
key = jax.random.key(0)
config = pi0_config.Pi0Config(paligemma_variant="gemma_2b_lora")
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size, config.action_horizon)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs, num_steps=10)
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)
def test_pi0_fast_model():
key = jax.random.key(0)
config = pi0_fast.Pi0FASTConfig()
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size,)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
assert actions.shape == (batch_size, 256)
def test_pi0_fast_lora_model():
key = jax.random.key(0)
config = pi0_fast.Pi0FASTConfig(paligemma_variant="gemma_2b_lora")
model = config.create(key)
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
loss = nnx_utils.module_jit(model.compute_loss)(key, obs, act)
assert loss.shape == (batch_size,)
actions = nnx_utils.module_jit(model.sample_actions)(key, obs)
assert actions.shape == (batch_size, 256)
lora_filter = nnx_utils.PathRegex(".*lora.*")
model_state = nnx.state(model)
lora_state_elems = list(model_state.filter(lora_filter))
assert len(lora_state_elems) > 0
@pytest.mark.manual
def test_model_restore():
key = jax.random.key(0)
config = pi0_config.Pi0Config()
batch_size = 2
obs, act = config.fake_obs(batch_size), config.fake_act(batch_size)
model = config.load(
_model.restore_params(download.maybe_download("gs://openpi-assets/checkpoints/pi0_base/params"))
)
loss = model.compute_loss(key, obs, act)
assert loss.shape == (batch_size, config.action_horizon)
actions = model.sample_actions(key, obs, num_steps=10)
assert actions.shape == (batch_size, model.action_horizon, model.action_dim)

Some files were not shown because too many files have changed in this diff.