Add OpenPi, Pi0 and Pi0.5 (#1910)

* initial commit

* change device in test

* do detailed import

* adhere to python 3.11 syntax

* fix autodocstring

* additionally

* do same in other files

* add model. prefix to all keys in state dict

* use dummy stats

* add pi05

* also shorten action_steps

* fix test

* all tests pass! and fix tokenizer max length between 05 and 0

* remove test

* fix transformer dependency

* fix test

* split pi0 and pi05 policy in separate files

* fix test

* fix push to hub test

* add some comments, license and readme

* remove warning in config

* add pi05 to factory

* remove check

* rename action_horizon to chunk_size

* clean up padding of state and action (more in line with lerobot pi0)

* add openpi image transforms for training and add more flexibility to _preprocess_images similar to lerobot pi0

* fix key match from pytorch state dict (similar keys to openpi implementation now)

* also for pi05

* update to python 3.11

* revert to openpi transformer replace python 3.11

* fix(modeling pi0): nit: warning message

* use safeauto_docstring

* fix: remove unused param

* fix from pretrained

* add preprocess tests

* also compile forward method

* Do not add model prefix to normalization

* use same name for action and state dim as lerobot pi0 and remove fixed image keys

* load from pretrained_path

* temp: hardcode base model

* fix override self.pretrained_path = None overwrite

* rename to loss

* remove additional image augmentations, lerobot dataset already does this

* Add docs

* put tests in test folder

* Add test to instantiate all base models

* go back to python 3.10

* update docs

* adapt docs pi05

* change docs: finetune base model options

* minor docs fixes and dependencies

* remove todo

* cast float64 to float32 for mps

* skip if no transformers

* fix tests

* add new models to modelcard

* add back init

* fix circular import

* feat: only run pi test on GPU

* remove require_nightly_gpu

* replace decorator test_pi0_openpi

* rename action_dim, state_dim to max_action_dim, max_state_dim

* fix doc and constants

* cleanup tests

* fix from pretrained

* fix tests

* add comment pi0 pi05 tests, add image features to pi0 pi05 hub tests

* fix: state is included in the language model, not in the flow head

* Move test to specific folder

* end paligemma task with newline

* remove add_special_tokens, not needed

* address PR feedback

* Remove previous pi0 and rename pi0_openpi and pi05_openpi

* Add Quantile stats to LeRobotDataset (#1985)

* - Add RunningQuantileStats class for efficient histogram-based quantile computation
- Integrate quantile parameters (compute_quantiles, quantiles) into LeRobotDataset
- Support quantile computation during episode collection and aggregation
- Add comprehensive function-based test suite (24 tests) for quantile functionality
- Maintain full backward compatibility with existing stats computation
- Enable configurable quantiles (default: [0.01, 0.99]) for robust normalization
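The histogram-based estimate described above can be illustrated with a self-contained sketch (illustrative only; the `RunningQuantileStats` class added in this PR additionally tracks running mean/std and re-bins when the observed data range grows):

```python
import numpy as np


def histogram_quantile(values: np.ndarray, q: float, num_bins: int = 5000) -> float:
    """Approximate the q-th quantile from a histogram, interpolating inside the bin."""
    hist, edges = np.histogram(values, bins=num_bins)
    cumsum = np.cumsum(hist)
    target = q * len(values)  # how many samples lie at or below the quantile
    idx = np.searchsorted(cumsum, target)  # first bin whose cumulative count reaches target
    if idx == 0:
        return float(edges[0])
    if idx >= len(cumsum):
        return float(edges[-1])
    count_before = cumsum[idx - 1]
    in_bin = cumsum[idx] - count_before
    if in_bin == 0:
        return float(edges[idx])
    # Linear interpolation within the selected bin rather than using the bin edge
    frac = (target - count_before) / in_bin
    return float(edges[idx] + frac * (edges[idx + 1] - edges[idx]))
```

With 5000 bins the estimate stays close to `np.quantile`, while the running version only ever stores the histogram counts instead of all samples.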

* style fixes; compute quantiles by default for new datasets

* fix tests

* - Added DEFAULT_QUANTILES=[0.01, 0.10, 0.50, 0.90, 0.99] to be computed for each feature instead of being chosen by the user
- Fortified tests.

* - add helper functions to reshape stats
- add missing test for quantiles

* - Add QUANTILE normalization mode to normalize the data with the 1st and 99th percentiles.
- Add QUANTILE10 normalization mode to normalize the data with the 10th and 90th percentiles.
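In code, the idea behind both modes can be sketched roughly as follows (a minimal illustration, assuming the [-1, 1] output range this PR settles on in later commits; the function name is hypothetical):

```python
import numpy as np


def quantile_normalize(x: np.ndarray, q_low: np.ndarray, q_high: np.ndarray) -> np.ndarray:
    """Map values so that q_low -> -1 and q_high -> +1, per feature dimension.

    QUANTILE mode would pass the q01/q99 stats, QUANTILE10 the q10/q90 stats.
    (A real implementation would also guard against q_high == q_low.)
    """
    return 2.0 * (x - q_low) / (q_high - q_low) - 1.0


x = np.array([1.0, 5.0, 9.0])
print(quantile_normalize(x, q_low=np.array([1.0]), q_high=np.array([9.0])))  # [-1.  0.  1.]
```

Unlike MEAN_STD, this keeps a few extreme outliers from dominating the normalization, since only the chosen percentiles define the scale.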

* style fixes

* Added missing license

* Simplify compute_stats

* - added script `augment_dataset_quantile_stats.py` so that we can add quantile stats to existing v3 datasets that don't have quantiles
- modified quantile computation: instead of using the bin edge for the value, interpolate within the bin

* rename pi0/pi05 files

* Remove openpi patch and use custom transformers branch for now

* renaming

* fix

* Revert "fix"

This reverts commit 1ea65730ac2cbca6e5869df734fbd4392561b3c6.

* fix naming

* feat(pi0/pi0.5): add pipeline (#2009)

* feat(processor): convert openpi model with processor

* TODO: Make test works

* fix(modeling_pi0openpi): update attention mask value and time scaling; improve task handling in tests

- Changed the attention mask value from `self.config.attention_mask_value` to a fixed value of `-2.3819763e38`.
- Updated time scaling in the `sample_noise` method to use a constant factor of `0.999` and an offset of `0.001`.
- Enhanced task handling in tests to ensure proper formatting and batch size consistency.
- Cleaned up commented-out test code for clarity.

* refactor(pi0): rename PI0OpenPIConfig and PI0OpenPIPolicy to PI0Config and PI0Policy

- Updated imports and references throughout the codebase to reflect the new naming convention.
- Introduced a new processor file for PI0 to handle pre-processing and post-processing steps.
- Adjusted tests to utilize the renamed classes, ensuring consistency and functionality.
- Enhanced clarity and maintainability by removing outdated naming conventions.

* refactor(pi05): rename PI0OpenPIPolicy to PI0Policy and update configuration

- Renamed `PI0OpenPIPolicy` to `PI0Policy` for consistency with naming conventions.
- Updated the `PI05OpenPIConfig` to include a new `tokenizer_max_length` attribute and changed the normalization mode for state from `MEAN_STD` to `QUANTILES`.
- Simplified model initialization in `PI05OpenPIPolicy` by removing unused `dataset_stats` parameter.
- Added a new processor class for `Pi05PrepareStateTokenizerProcessorStep` with `@dataclass` for improved readability.
- Introduced a test script to compare the integration of the PI0OpenPI policy with the original implementation, ensuring local testing compatibility.

* refactor(pi05): update imports and rename configuration classes

- Changed imports to reflect the new naming convention for PI05 configuration and policy classes.
- Renamed `PI05OpenPIConfig` to `PI05Config` and `PI05OpenPIPolicy` to `PI05Policy` for consistency.
- Introduced a new processor file for PI05, implementing pre-processing and post-processing steps.
- Updated tests to utilize the renamed classes, ensuring functionality and consistency across the codebase.

* update(pi05): increase tokenizer_max_length for improved processing

- Changed the `tokenizer_max_length` from 48 to 200 to enhance the model's capability in handling longer sequences.
- This adjustment aims to improve the overall performance and flexibility of the PI05 configuration.

* add default for state (max_state_dim)

* correct naming

* fix import

* cleanup code

* remove unused test

* use quantiles for action

* move to device

* remove discrete state assert

* fix pi05 test

* move pi05 to device

* use base models in comparison tests

* small renames for tests

* change number of tokens pi05 test

* fix openpi tokenization in test

* fix hub test

* fix test

* assert lerobot vs openpi tests

---------

Co-authored-by: Pepijn <pepijn@huggingface.co>

* add headers

* add back previously removed imports

* update if statement load processor with dataset stats

* remove to avoid circular import

* inject dataset stats for pretrained models

* check normalization before applying

* add link to quantile augment script

* fix(policies): transformers import for ci in PI0 & PI05 (#2039)

* fix(policies): transformers import for ci in PI0

* fix(policies): transformers import for ci in PI05

* test(processor): fix expected raise when normalization types are missing (#2040)

* switch normalization order in the pipeline for pi05

* Fix/quantiles script (#2064)

* refactor augment-stats-with-quantiles script:
add parallelization for faster processing
shift the quantile normalization into [-1, 1]

* fix replay buffer tests

* fix comment

* overwrite the pipeline normalization features with the policy features

* remove double normalization overwrite

* cleanup from pretrained

* remove typo

* also set norm_map

* fix(augment_quantiles): images incorrectly divided by 255

* clamp quantiles

* link to lerobot base models

* rename tests

* incorporate PR feedback

* update docstring for RunningQuantileStats

* update doc links

* Revert "clamp quantiles"

This reverts commit 172207471c8f2cb62958e9a9e6a0535ba3ff67d4.

* fix self.paligemma

* fix tests related to quantiles that were scaled to [0,1], the new range is [-1, 1]

* fix libero doc and use different transformer branch

* use fix branch instead of feat

* update results libero

* add new line

* fix formatting

* precommit

* update results libero

* update libero doc

* update title

* final changes

* add quantiles to test

* run pre commit

---------

Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in: Pepijn, 2025-10-02 13:14:45 +02:00, committed by GitHub.
Parent b6c528a438, commit abde7be3b3.
43 changed files with 5886 additions and 2288 deletions.

@@ -28,11 +28,14 @@
title: "Datasets"
- sections:
- local: smolvla
title: Finetune SmolVLA
title: SmolVLA
- local: pi0
title: π₀ (Pi0)
- local: pi05
title: π₀.₅ (Pi05)
- local: libero
title: Using Libero
title: "Policies"
- sections:
- local: introduction_processors
title: Introduction to Robot Processors

@@ -125,3 +125,42 @@ lerobot-train \
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
## Reproducing π₀.₅ results
We reproduce the results of π₀.₅ on the LIBERO benchmark using the LeRobot implementation. We take the Physical Intelligence LIBERO base model (`pi05_libero`) and finetune it for an additional 6k steps in bfloat16, with a batch size of 256 on 8 H100 GPUs, using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
The finetuned model can be found here:
- **π₀.₅ LIBERO**: [lerobot/pi05_libero_finetuned](https://huggingface.co/lerobot/pi05_libero_finetuned)
We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command:
```bash
python src/lerobot/scripts/eval.py \
--env.type=libero \
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--policy.path=lerobot/pi05_libero_finetuned \
--policy.n_action_steps=10 \
--output_dir=./eval_logs/ \
--env.max_parallel_tasks=1
```
**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
### Results
We obtain the following results on the LIBERO benchmark:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | -------- |
| **π₀.₅** | 97.0 | 99.0 | 98.0 | 96.0 | **97.5** |
These results are consistent with the original [results](https://github.com/Physical-Intelligence/openpi/tree/main/examples/libero#results) reported by Physical Intelligence:
| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average |
| -------- | -------------- | ------------- | ----------- | --------- | --------- |
| **π₀.₅** | 98.8 | 98.2 | 98.0 | 92.4 | **96.85** |

docs/source/pi0.mdx (new file, 79 lines)
@@ -0,0 +1,79 @@
# π₀ (Pi0)
π₀ is a **Vision-Language-Action model for general robot control**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
## Model Overview
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
### The Vision for Physical Intelligence
As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models.
### Architecture and Approach
π₀ combines several key innovations:
- **Flow Matching**: Uses a novel method to augment pre-trained VLMs with continuous action outputs via flow matching (a variant of diffusion models)
- **Cross-Embodiment Training**: Trained on data from 8 distinct robot platforms including UR5e, Bimanual UR5e, Franka, Bimanual Trossen, Bimanual ARX, Mobile Trossen, and Mobile Fibocom
- **Internet-Scale Pre-training**: Inherits semantic knowledge from a pre-trained 3B parameter Vision-Language Model
- **High-Frequency Control**: Outputs motor commands at up to 50 Hz for real-time dexterous manipulation
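The sampling loop implied by flow matching can be sketched with a minimal, hypothetical illustration (not the LeRobot implementation; `velocity` stands in for the model's learned velocity field, and the chunk/action sizes are placeholder values): starting from Gaussian noise, the action chunk is carried toward the data distribution by Euler integration.

```python
import numpy as np


def sample_actions(velocity, chunk_size=50, action_dim=32, num_steps=10, seed=0):
    """Euler-integrate a velocity field from noise (t=0) to an action chunk (t=1)."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((chunk_size, action_dim))  # start from Gaussian noise
    dt = 1.0 / num_steps
    t = 0.0
    for _ in range(num_steps):
        x = x + dt * velocity(x, t)  # follow the field toward the action distribution
        t += dt
    return x


# Toy velocity field that pulls every sample toward a fixed target chunk.
target = np.ones((50, 32))
actions = sample_actions(lambda x, t: target - x)
```

In the real model the velocity field is a neural network conditioned on images, language, and state; here a toy linear field is enough to show the integration loop.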
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install Pi0 dependencies by running:
```bash
pip install -e ".[pi]"
```
## Training Data and Capabilities
π₀ is trained on the largest robot interaction dataset to date, combining three key data sources:
1. **Internet-Scale Pre-training**: Vision-language data from the web for semantic understanding
2. **Open X-Embodiment Dataset**: Open-source robot manipulation datasets
3. **Physical Intelligence Dataset**: Large and diverse dataset of dexterous tasks across 8 distinct robots
## Usage
To use π₀ in LeRobot, specify the policy type as:
```python
policy.type=pi0
```
## Training
For training π₀, you can use the standard LeRobot training script with the appropriate configuration:
```bash
python src/lerobot/scripts/train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
--job_name=pi0_training \
--policy.pretrained_path=lerobot/pi0_base \
--policy.repo_id=your_repo_id \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--policy.dtype=bfloat16 \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
### Key Training Parameters
- **`--policy.compile_model=true`**: Enables model compilation for faster training
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
- **`--policy.pretrained_path=lerobot/pi0_base`**: The base π₀ model you want to finetune, options are:
- [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base)
- [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset)
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

docs/source/pi05.mdx (new file, 98 lines)
@@ -0,0 +1,98 @@
# π₀.₅ (Pi05) Policy
π₀.₅ is a **Vision-Language-Action model with open-world generalization**, from Physical Intelligence. The LeRobot implementation is adapted from their open source [OpenPI](https://github.com/Physical-Intelligence/openpi) repository.
## Model Overview
π₀.₅ represents a significant evolution from π₀, developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi05) to address a big challenge in robotics: **open-world generalization**. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations that were never seen during training.
### The Generalization Challenge
As Physical Intelligence explains, the fundamental challenge isn't performing tasks that demand agility or dexterity, but generalization: the ability to correctly perform tasks in new settings with new objects. Consider a robot cleaning different homes: each home has different objects in different places. Generalization must occur at multiple levels:
- **Physical Level**: Understanding how to pick up a spoon (by the handle) or plate (by the edge), even with unseen objects in cluttered environments
- **Semantic Level**: Understanding task semantics, where to put clothes and shoes (laundry hamper, not on the bed), and what tools are appropriate for cleaning spills
- **Environmental Level**: Adapting to "messy" real-world environments like homes, grocery stores, offices, and hospitals
### Co-Training on Heterogeneous Data
The breakthrough innovation in π₀.₅ is **co-training on heterogeneous data sources**. The model learns from:
1. **Multimodal Web Data**: Image captioning, visual question answering, object detection
2. **Verbal Instructions**: Humans coaching robots through complex tasks step-by-step
3. **Subtask Commands**: High-level semantic behavior labels (e.g., "pick up the pillow" for an unmade bed)
4. **Cross-Embodiment Robot Data**: Data from various robot platforms with different capabilities
5. **Multi-Environment Data**: Static robots deployed across many different homes
6. **Mobile Manipulation Data**: ~400 hours of mobile robot demonstrations
This diverse training mixture creates a "curriculum" that enables generalization across physical, visual, and semantic levels simultaneously.
## Installation Requirements
1. Install LeRobot by following our [Installation Guide](./installation).
2. Install Pi0.5 dependencies by running:
```bash
pip install -e ".[pi]"
```
## Usage
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
```python
policy.type=pi05
```
## Training
### Training Command Example
Here's a complete training command for finetuning the base π₀.₅ model on your own dataset:
```bash
python src/lerobot/scripts/train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi05 \
--output_dir=./outputs/pi05_training \
--job_name=pi05_training \
--policy.pretrained_path=lerobot/pi05_base \
--policy.repo_id=your_repo_id \
--policy.compile_model=true \
--policy.gradient_checkpointing=true \
--wandb.enable=true \
--policy.dtype=bfloat16 \
--steps=3000 \
--policy.device=cuda \
--batch_size=32
```
### Key Training Parameters
- **`--policy.compile_model=true`**: Enables model compilation for faster training
- **`--policy.gradient_checkpointing=true`**: Reduces memory usage significantly during training
- **`--policy.dtype=bfloat16`**: Use mixed precision training for efficiency
- **`--batch_size=32`**: Batch size for training, adapt this based on your GPU memory
- **`--policy.pretrained_path=lerobot/pi05_base`**: The base π₀.₅ model you want to finetune, options are:
- [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base)
- [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset)
## Performance Results
### Libero Benchmark Results
π₀.₅ has demonstrated strong performance on the Libero benchmark suite. To compare and test its LeRobot implementation, we finetuned the libero base model for an additional 6k steps on the Libero dataset and compared the results to the OpenPI reference results.
| Benchmark | LeRobot Implementation | OpenPI Reference |
| ------------------ | ---------------------- | ---------------- |
| **Libero Spatial** | 97.0% | 98.8% |
| **Libero Object** | 99.0% | 98.2% |
| **Libero Goal** | 98.0% | 98.0% |
| **Libero 10** | 96.0% | 92.4% |
| **Average** | 97.5% | 96.85% |
These results demonstrate π₀.₅'s strong generalization capabilities across diverse robotic manipulation tasks. To reproduce these results, you can follow the instructions in the [Libero](https://huggingface.co/docs/lerobot/libero) section.
## License
This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

@@ -1,4 +1,4 @@
# Finetune SmolVLA
# SmolVLA
SmolVLA is Hugging Face's lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development!

@@ -94,7 +94,7 @@ dependencies = [
# Common
pygame-dep = ["pygame>=2.5.1"]
placo-dep = ["placo>=0.9.6"]
transformers-dep = ["transformers>=4.52.0"]
transformers-dep = ["transformers>=4.53.0"]
grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"]
# Motors
@@ -119,7 +119,7 @@ phone = ["hebi-py>=2.8.0", "teleop>=0.1.0"]
# ] # TODO: Currently not supported
# Policies
pi0 = ["lerobot[transformers-dep]"]
pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"]
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.11", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
@@ -147,7 +147,7 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
"lerobot[pi]",
"lerobot[smolvla]",
"lerobot[hilserl]",
"lerobot[async]",

@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet"]
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]

@@ -25,7 +25,14 @@ from lerobot.configs.types import PolicyFeature
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.policies import ( # noqa: F401
ACTConfig,
DiffusionConfig,
PI0Config,
PI05Config,
SmolVLAConfig,
VQBeTConfig,
)
from lerobot.robots.robot import Robot
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import init_logging

@@ -71,9 +71,11 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
tags: list[str] | None = None
# Add tags to your policy on the hub.
license: str | None = None
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
pretrained_path: str | None = None
def __post_init__(self):
self.pretrained_path = None
if not self.device or not is_torch_device_available(self.device):
auto_device = auto_select_torch_device()
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")

@@ -35,6 +35,8 @@ class NormalizationMode(str, Enum):
MIN_MAX = "MIN_MAX"
MEAN_STD = "MEAN_STD"
IDENTITY = "IDENTITY"
QUANTILES = "QUANTILES"
QUANTILE10 = "QUANTILE10"
@dataclass

@@ -17,6 +17,179 @@ import numpy as np
from lerobot.datasets.utils import load_image_as_numpy
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]
class RunningQuantileStats:
    """
    Maintains running statistics for batches of vectors, including mean,
    standard deviation, min, max, and approximate quantiles.

    Statistics are computed per feature dimension and updated incrementally
    as new batches are observed. Quantiles are estimated using histograms,
    which adapt dynamically if the observed data range expands.
    """

    def __init__(self, quantile_list: list[float] | None = None, num_quantile_bins: int = 5000):
        self._count = 0
        self._mean = None
        self._mean_of_squares = None
        self._min = None
        self._max = None
        self._histograms = None
        self._bin_edges = None
        self._num_quantile_bins = num_quantile_bins

        self._quantile_list = quantile_list
        if self._quantile_list is None:
            self._quantile_list = DEFAULT_QUANTILES
        self._quantile_keys = [f"q{int(q * 100):02d}" for q in self._quantile_list]

    def update(self, batch: np.ndarray) -> None:
        """Update the running statistics with a batch of vectors.

        Args:
            batch: An array where all dimensions except the last are batch dimensions.
        """
        batch = batch.reshape(-1, batch.shape[-1])
        num_elements, vector_length = batch.shape

        if self._count == 0:
            self._mean = np.mean(batch, axis=0)
            self._mean_of_squares = np.mean(batch**2, axis=0)
            self._min = np.min(batch, axis=0)
            self._max = np.max(batch, axis=0)
            self._histograms = [np.zeros(self._num_quantile_bins) for _ in range(vector_length)]
            self._bin_edges = [
                np.linspace(self._min[i] - 1e-10, self._max[i] + 1e-10, self._num_quantile_bins + 1)
                for i in range(vector_length)
            ]
        else:
            if vector_length != self._mean.size:
                raise ValueError("The length of new vectors does not match the initialized vector length.")

            new_max = np.max(batch, axis=0)
            new_min = np.min(batch, axis=0)
            max_changed = np.any(new_max > self._max)
            min_changed = np.any(new_min < self._min)
            self._max = np.maximum(self._max, new_max)
            self._min = np.minimum(self._min, new_min)

            if max_changed or min_changed:
                self._adjust_histograms()

        self._count += num_elements
        batch_mean = np.mean(batch, axis=0)
        batch_mean_of_squares = np.mean(batch**2, axis=0)

        # Update running mean and mean of squares
        self._mean += (batch_mean - self._mean) * (num_elements / self._count)
        self._mean_of_squares += (batch_mean_of_squares - self._mean_of_squares) * (
            num_elements / self._count
        )

        self._update_histograms(batch)

    def get_statistics(self) -> dict[str, np.ndarray]:
        """Compute and return the statistics of the vectors processed so far.

        Returns:
            Dictionary containing the computed statistics, including one entry
            per configured quantile (e.g. "q01", "q99").
        """
        if self._count < 2:
            raise ValueError("Cannot compute statistics for less than 2 vectors.")

        variance = self._mean_of_squares - self._mean**2
        stddev = np.sqrt(np.maximum(0, variance))

        stats = {
            "min": self._min.copy(),
            "max": self._max.copy(),
            "mean": self._mean.copy(),
            "std": stddev,
            "count": np.array([self._count]),
        }

        quantile_results = self._compute_quantiles()
        for i, q in enumerate(self._quantile_keys):
            stats[q] = quantile_results[i]

        return stats

    def _adjust_histograms(self):
        """Adjust histograms when min or max changes."""
        for i in range(len(self._histograms)):
            old_edges = self._bin_edges[i]
            old_hist = self._histograms[i]

            # Create new edges with small padding to ensure range coverage
            padding = (self._max[i] - self._min[i]) * 1e-10
            new_edges = np.linspace(
                self._min[i] - padding, self._max[i] + padding, self._num_quantile_bins + 1
            )

            # Redistribute existing histogram counts to new bins
            # We need to map each old bin center to the new bins
            old_centers = (old_edges[:-1] + old_edges[1:]) / 2
            new_hist = np.zeros(self._num_quantile_bins)

            for old_center, count in zip(old_centers, old_hist, strict=False):
                if count > 0:
                    # Find which new bin this old center belongs to
                    bin_idx = np.searchsorted(new_edges, old_center) - 1
                    bin_idx = max(0, min(bin_idx, self._num_quantile_bins - 1))
                    new_hist[bin_idx] += count

            self._histograms[i] = new_hist
            self._bin_edges[i] = new_edges

    def _update_histograms(self, batch: np.ndarray) -> None:
        """Update histograms with new vectors."""
        for i in range(batch.shape[1]):
            hist, _ = np.histogram(batch[:, i], bins=self._bin_edges[i])
            self._histograms[i] += hist

    def _compute_quantiles(self) -> list[np.ndarray]:
        """Compute quantiles based on histograms."""
        results = []
        for q in self._quantile_list:
            target_count = q * self._count
            q_values = []
            for hist, edges in zip(self._histograms, self._bin_edges, strict=True):
                q_value = self._compute_single_quantile(hist, edges, target_count)
                q_values.append(q_value)
            results.append(np.array(q_values))
        return results

    def _compute_single_quantile(self, hist: np.ndarray, edges: np.ndarray, target_count: float) -> float:
        """Compute a single quantile value from histogram and bin edges."""
        cumsum = np.cumsum(hist)
        idx = np.searchsorted(cumsum, target_count)

        if idx == 0:
            return edges[0]
        if idx >= len(cumsum):
            return edges[-1]

        # If not edge case, interpolate within the bin
        count_before = cumsum[idx - 1]
        count_in_bin = cumsum[idx] - count_before

        # If no samples in this bin, use the bin edge
        if count_in_bin == 0:
            return edges[idx]

        # Linear interpolation within the bin
        fraction = (target_count - count_before) / count_in_bin
        return edges[idx] + fraction * (edges[idx + 1] - edges[idx])
def estimate_num_samples(
dataset_len: int, min_num_samples: int = 100, max_num_samples: int = 10_000, power: float = 0.75
@@ -72,33 +245,282 @@ def sample_images(image_paths: list[str]) -> np.ndarray:
return images
def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[str, np.ndarray]:
    return {
        "min": np.min(array, axis=axis, keepdims=keepdims),
        "max": np.max(array, axis=axis, keepdims=keepdims),
        "mean": np.mean(array, axis=axis, keepdims=keepdims),
        "std": np.std(array, axis=axis, keepdims=keepdims),
        "count": np.array([len(array)]),
    }
def _reshape_stats_by_axis(
    stats: dict[str, np.ndarray],
    axis: int | tuple[int, ...] | None,
    keepdims: bool,
    original_shape: tuple[int, ...],
) -> dict[str, np.ndarray]:
    """Reshape all statistics to match NumPy's output conventions.

    Applies consistent reshaping to all statistics (except 'count') based on the
    axis and keepdims parameters. This ensures statistics have the correct shape
    for broadcasting with the original data.

    Args:
        stats: Dictionary of computed statistics
        axis: Axis or axes along which statistics were computed
        keepdims: Whether to keep reduced dimensions as size-1 dimensions
        original_shape: Shape of the original array

    Returns:
        Dictionary with reshaped statistics

    Note:
        The 'count' statistic is never reshaped as it represents metadata
        rather than per-feature statistics.
    """
    if axis == (1,) and not keepdims:
        return stats

    result = {}
    for key, value in stats.items():
        if key == "count":
            result[key] = value
        else:
            result[key] = _reshape_single_stat(value, axis, keepdims, original_shape)
    return result


def _reshape_for_image_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for image data (axis=(0, 2, 3))."""
    if keepdims and value.ndim == 1:
        return value.reshape(1, -1, 1, 1)
    return value


def _reshape_for_vector_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray:
    """Reshape statistics for vector data (axis=0 or axis=(0,))."""
    if not keepdims:
        return value
    if len(original_shape) == 1 and value.ndim > 0:
        return value.reshape(1)
    elif len(original_shape) >= 2 and value.ndim == 1:
        return value.reshape(1, -1)
    return value


def _reshape_for_feature_stats(value: np.ndarray, keepdims: bool) -> np.ndarray:
    """Reshape statistics for feature-wise computation (axis=(1,))."""
    if not keepdims:
        return value
    if value.ndim == 0:
        return value.reshape(1, 1)
    elif value.ndim == 1:
        return value.reshape(-1, 1)
    return value


def _reshape_for_global_stats(
    value: np.ndarray, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Reshape statistics for global reduction (axis=None)."""
    if keepdims:
        target_shape = tuple(1 for _ in original_shape)
        return value.reshape(target_shape)
    # Keep at least 1-D arrays to satisfy validator
    return np.atleast_1d(value)


def _reshape_single_stat(
    value: np.ndarray, axis: int | tuple[int, ...] | None, keepdims: bool, original_shape: tuple[int, ...]
) -> np.ndarray | float:
    """Apply appropriate reshaping to a single statistic array.

    This function transforms statistic arrays to match expected output shapes
    based on the axis configuration and keepdims parameter.

    Args:
        value: The statistic array to reshape
        axis: Axis or axes that were reduced during computation
        keepdims: Whether to maintain reduced dimensions as size-1 dimensions
        original_shape: Shape of the original data before reduction

    Returns:
        Reshaped array following NumPy broadcasting conventions
    """
    if axis == (0, 2, 3):
        return _reshape_for_image_stats(value, keepdims)
    if axis in [0, (0,)]:
        return _reshape_for_vector_stats(value, keepdims, original_shape)
if axis == (1,):
return _reshape_for_feature_stats(value, keepdims)
if axis is None:
return _reshape_for_global_stats(value, keepdims, original_shape)
return value
def _prepare_array_for_stats(array: np.ndarray, axis: int | tuple[int, ...] | None) -> tuple[np.ndarray, int]:
"""Prepare array for statistics computation by reshaping according to axis.
Args:
array: Input data array
axis: Axis or axes along which to compute statistics
Returns:
Tuple of (reshaped_array, sample_count)
"""
if axis == (0, 2, 3): # Image data
batch_size, channels, height, width = array.shape
reshaped = array.transpose(0, 2, 3, 1).reshape(-1, channels)
return reshaped, batch_size
if axis == 0 or axis == (0,): # Vector data
reshaped = array
if array.ndim == 1:
reshaped = array.reshape(-1, 1)
return reshaped, array.shape[0]
if axis == (1,): # Feature-wise statistics
return array.T, array.shape[1]
if axis is None: # Global statistics
reshaped = array.reshape(-1, 1)
# For backward compatibility, count represents the first dimension size
return reshaped, array.shape[0] if array.ndim > 0 else 1
raise ValueError(f"Unsupported axis configuration: {axis}")
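As an illustrative sketch of the image branch (plain numpy, not calling lerobot): flattening (batch, C, H, W) to (B·H·W, C) treats every pixel as a sample and every channel as a feature, so per-channel statistics reduce to ordinary axis-0 statistics:

```python
import numpy as np

images = np.arange(2 * 3 * 4 * 4, dtype=np.float32).reshape(2, 3, 4, 4)  # (B, C, H, W)
flat = images.transpose(0, 2, 3, 1).reshape(-1, 3)  # (B*H*W, C): one row per pixel

# Per-channel mean over all pixels equals a reduction over axes (0, 2, 3)
assert np.allclose(flat.mean(axis=0), images.mean(axis=(0, 2, 3)))
print(flat.shape)  # (32, 3)
```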
def _compute_basic_stats(
array: np.ndarray, sample_count: int, quantile_list: list[float] | None = None
) -> dict[str, np.ndarray]:
"""Compute basic statistics for arrays with insufficient samples for quantiles.
Args:
array: Reshaped array ready for statistics computation
sample_count: Number of samples represented in the data
quantile_list: Quantile fractions used to name the quantile keys (defaults to DEFAULT_QUANTILES)
Returns:
Dictionary with basic statistics and quantiles set to mean values
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
quantile_list_keys = [f"q{int(q * 100):02d}" for q in quantile_list]
stats = {
"min": np.min(array, axis=0),
"max": np.max(array, axis=0),
"mean": np.mean(array, axis=0),
"std": np.std(array, axis=0),
"count": np.array([sample_count]),
}
for q in quantile_list_keys:
stats[q] = stats["mean"].copy()
return stats
def get_feature_stats(
array: np.ndarray,
axis: int | tuple[int, ...] | None,
keepdims: bool,
quantile_list: list[float] | None = None,
) -> dict[str, np.ndarray]:
"""Compute comprehensive statistics for array features along specified axes.
This function calculates min, max, mean, std, and quantiles (by default 1%, 10%, 50%, 90%, 99%)
for the input array along the specified axes. It handles different data layouts:
- Image data: axis=(0,2,3) computes per-channel statistics
- Vector data: axis=0 computes per-feature statistics
- Feature-wise: axis=1 computes statistics across features
- Global: axis=None computes statistics over entire array
Args:
array: Input data array with shape appropriate for the specified axis
axis: Axis or axes along which to compute statistics
- (0, 2, 3): For image data (batch, channels, height, width)
- 0 or (0,): For vector/tabular data (samples, features)
- (1,): For computing across features
- None: For global statistics over entire array
keepdims: If True, reduced axes are kept as dimensions with size 1
quantile_list: Quantile fractions to compute (defaults to DEFAULT_QUANTILES)
Returns:
Dictionary containing:
- 'min': Minimum values
- 'max': Maximum values
- 'mean': Mean values
- 'std': Standard deviation
- 'count': Number of samples (always shape (1,))
- 'q01', 'q10', 'q50', 'q90', 'q99': Quantile values
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
original_shape = array.shape
reshaped, sample_count = _prepare_array_for_stats(array, axis)
if reshaped.shape[0] < 2:
stats = _compute_basic_stats(reshaped, sample_count, quantile_list)
else:
running_stats = RunningQuantileStats()
running_stats.update(reshaped)
stats = running_stats.get_statistics()
stats["count"] = np.array([sample_count])
stats = _reshape_stats_by_axis(stats, axis, keepdims, original_shape)
return stats
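The keepdims conventions above exist so the returned statistics broadcast against the original data. A minimal numpy illustration for 2-D vector data (plain numpy, not calling lerobot):

```python
import numpy as np

data = np.random.default_rng(1).normal(size=(100, 7))  # (samples, features)
mean = data.mean(axis=0)  # shape (7,): per-feature statistic
std = data.std(axis=0)

# Per-feature stats broadcast directly against the (100, 7) data
normalized = (data - mean) / std
assert normalized.shape == (100, 7)
assert np.allclose(normalized.mean(axis=0), 0.0, atol=1e-8)
```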
def compute_episode_stats(
episode_data: dict[str, list[str] | np.ndarray],
features: dict,
quantile_list: list[float] | None = None,
) -> dict:
"""Compute comprehensive statistics for all features in an episode.
Processes different data types appropriately:
- Images/videos: Samples from paths, computes per-channel stats, normalizes to [0,1]
- Numerical arrays: Computes per-feature statistics
- Strings: Skipped (no statistics computed)
Args:
episode_data: Dictionary mapping feature names to data
- For images/videos: list of file paths
- For numerical data: numpy arrays
features: Dictionary describing each feature's dtype and shape
Returns:
Dictionary mapping feature names to their statistics dictionaries.
Each statistics dictionary contains min, max, mean, std, count, and quantiles.
Note:
Image statistics are normalized to [0,1] range and have shape (3,1,1) for
per-channel values when dtype is 'image' or 'video'.
"""
if quantile_list is None:
quantile_list = DEFAULT_QUANTILES
ep_stats = {}
for key, data in episode_data.items():
if features[key]["dtype"] == "string":
continue  # HACK: we should receive np.arrays of strings
if features[key]["dtype"] in ["image", "video"]:
ep_ft_array = sample_images(data)  # data is a list of image paths
axes_to_reduce = (0, 2, 3)  # keep channel dim
keepdims = True
else:
ep_ft_array = data  # data is already a np.ndarray
axes_to_reduce = 0  # compute stats over the first axis
keepdims = data.ndim == 1  # keep as np.array
ep_stats[key] = get_feature_stats(
ep_ft_array, axis=axes_to_reduce, keepdims=keepdims, quantile_list=quantile_list
)
# finally, we normalize and remove batch dim for images
if features[key]["dtype"] in ["image", "video"]:
ep_stats[key] = {
k: v if k == "count" else np.squeeze(v / 255.0, axis=0) for k, v in ep_stats[key].items()
@@ -107,20 +529,37 @@ def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], featu
return ep_stats
def _validate_stat_value(value: np.ndarray, key: str, feature_key: str) -> None:
"""Validate a single statistic value."""
if not isinstance(value, np.ndarray):
raise ValueError(
f"Stats must be composed of numpy array, but key '{key}' of feature '{feature_key}' "
f"is of type '{type(value)}' instead."
)
if value.ndim == 0:
raise ValueError("Number of dimensions must be at least 1, and is 0 instead.")
if key == "count" and value.shape != (1,):
raise ValueError(f"Shape of 'count' must be (1), but is {value.shape} instead.")
if "image" in feature_key and key != "count" and value.shape != (3, 1, 1):
raise ValueError(f"Shape of quantile '{key}' must be (3,1,1), but is {value.shape} instead.")
def _assert_type_and_shape(stats_list: list[dict[str, dict]]):
"""Validate that all statistics have correct types and shapes.
Args:
stats_list: List of statistics dictionaries to validate
Raises:
ValueError: If any statistic has incorrect type or shape
"""
for stats in stats_list:
for feature_key, feature_stats in stats.items():
for stat_key, stat_value in feature_stats.items():
_validate_stat_value(stat_value, stat_key, feature_key)
def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
@@ -143,7 +582,7 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
weighted_variances = (variances + delta_means**2) * counts
total_variance = weighted_variances.sum(axis=0) / total_count
aggregated = {
"min": np.min(np.stack([s["min"] for s in stats_ft_list]), axis=0),
"max": np.max(np.stack([s["max"] for s in stats_ft_list]), axis=0),
"mean": total_mean,
@@ -151,6 +590,17 @@ def aggregate_feature_stats(stats_ft_list: list[dict[str, dict]]) -> dict[str, d
"count": total_count,
}
if stats_ft_list:
quantile_keys = [k for k in stats_ft_list[0] if k.startswith("q") and k[1:].isdigit()]
for q_key in quantile_keys:
if all(q_key in s for s in stats_ft_list):
quantile_values = np.stack([s[q_key] for s in stats_ft_list])
weighted_quantiles = quantile_values * counts
aggregated[q_key] = weighted_quantiles.sum(axis=0) / total_count
return aggregated
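The count-weighted mean/variance pooling above recovers exact full-dataset statistics, whereas the count-weighted quantile average is only an approximation (quantiles do not combine exactly). A standalone numpy check of the pooling identity, independent of lerobot:

```python
import numpy as np

rng = np.random.default_rng(0)
chunks = [rng.normal(loc=i, size=(50, 4)) for i in range(3)]  # three "episodes"

counts = np.array([[c.shape[0]] for c in chunks])        # (3, 1), like the (1,)-shaped count stats
means = np.stack([c.mean(axis=0) for c in chunks])       # (3, 4)
variances = np.stack([c.var(axis=0) for c in chunks])    # (3, 4)

total_count = counts.sum()
total_mean = (means * counts).sum(axis=0) / total_count
delta_means = means - total_mean
# Pooled variance: within-chunk variance plus between-chunk mean shift, count-weighted
total_variance = ((variances + delta_means**2) * counts).sum(axis=0) / total_count

full = np.concatenate(chunks)
assert np.allclose(total_mean, full.mean(axis=0))
assert np.allclose(total_variance, full.var(axis=0))
```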
def aggregate_stats(stats_list: list[dict[str, dict]]) -> dict[str, dict[str, np.ndarray]]:
"""Aggregate stats from multiple compute_stats outputs into a single set of stats.


@@ -0,0 +1,225 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script augments existing LeRobot datasets with quantile statistics.
Most datasets created before the quantile feature was added do not contain
quantile statistics (q01, q10, q50, q90, q99) in their metadata. This script:
1. Loads an existing LeRobot dataset in v3.0 format
2. Checks if it already contains quantile statistics
3. If missing, computes quantile statistics for all features
4. Updates the dataset metadata with the new quantile statistics
Usage:
```bash
python src/lerobot/datasets/v30/augment_dataset_quantile_stats.py \
--repo-id=lerobot/pusht \
```
"""
import argparse
import concurrent.futures
import logging
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES, aggregate_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import write_stats
from lerobot.utils.utils import init_logging
def has_quantile_stats(stats: dict[str, dict] | None, quantile_list_keys: list[str] | None = None) -> bool:
"""Check if dataset statistics already contain quantile information.
Args:
stats: Dataset statistics dictionary
quantile_list_keys: Quantile stat keys to look for (defaults to keys derived from DEFAULT_QUANTILES)
Returns:
True if quantile statistics are present, False otherwise
"""
if quantile_list_keys is None:
quantile_list_keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
if stats is None:
return False
for feature_stats in stats.values():
if any(q_key in feature_stats for q_key in quantile_list_keys):
return True
return False
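The key-derivation expression above maps each quantile fraction to a two-digit key. A quick sketch (the DEFAULT_QUANTILES values here are assumed to be the q01..q99 set used elsewhere in this file):

```python
# Assumed default quantile fractions, matching the q01..q99 keys used in the stats code
DEFAULT_QUANTILES = [0.01, 0.10, 0.50, 0.90, 0.99]

keys = [f"q{int(q * 100):02d}" for q in DEFAULT_QUANTILES]
print(keys)  # ['q01', 'q10', 'q50', 'q90', 'q99']
```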
def process_single_episode(dataset: LeRobotDataset, episode_idx: int) -> dict:
"""Process a single episode and return its statistics.
Args:
dataset: The LeRobot dataset
episode_idx: Index of the episode to process
Returns:
Dictionary containing episode statistics
"""
logging.info(f"Computing stats for episode {episode_idx}")
start_idx = dataset.meta.episodes[episode_idx]["dataset_from_index"]
end_idx = dataset.meta.episodes[episode_idx]["dataset_to_index"]
ep_stats = {}
for key, data in dataset.hf_dataset[start_idx:end_idx].items():
if dataset.features[key]["dtype"] == "string":
continue
data = torch.stack(data).cpu().numpy()
if dataset.features[key]["dtype"] in ["image", "video"]:
axes_to_reduce = (0, 2, 3)
keepdims = True
else:
axes_to_reduce = 0
keepdims = data.ndim == 1
ep_stats[key] = get_feature_stats(
data, axis=axes_to_reduce, keepdims=keepdims, quantile_list=DEFAULT_QUANTILES
)
if dataset.features[key]["dtype"] in ["image", "video"]:
for k, v in ep_stats[key].items():
if dataset.features[key]["dtype"] == "video":
v = v / 255.0
if k != "count":
v = np.squeeze(v, axis=0)
ep_stats[key][k] = v
return ep_stats
def compute_quantile_stats_for_dataset(dataset: LeRobotDataset) -> dict[str, dict]:
"""Compute quantile statistics for all episodes in the dataset.
Args:
dataset: The LeRobot dataset to compute statistics for
Returns:
Dictionary containing aggregated statistics with quantiles
"""
logging.info(f"Computing quantile statistics for dataset with {dataset.num_episodes} episodes")
episode_stats_list = []
max_workers = min(dataset.num_episodes, 16)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_episode = {
executor.submit(process_single_episode, dataset, episode_idx): episode_idx
for episode_idx in range(dataset.num_episodes)
}
episode_results = {}
with tqdm(total=dataset.num_episodes, desc="Processing episodes") as pbar:
for future in concurrent.futures.as_completed(future_to_episode):
episode_idx = future_to_episode[future]
ep_stats = future.result()
episode_results[episode_idx] = ep_stats
pbar.update(1)
for episode_idx in range(dataset.num_episodes):
if episode_idx in episode_results:
episode_stats_list.append(episode_results[episode_idx])
if not episode_stats_list:
raise ValueError("No episode data found for computing statistics")
logging.info(f"Aggregating statistics from {len(episode_stats_list)} episodes")
return aggregate_stats(episode_stats_list)
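The submit/as_completed/re-order pattern above can be sketched in isolation; `work` here is a stand-in for `process_single_episode`:

```python
import concurrent.futures

def work(i: int) -> int:
    return i * i  # stand-in for per-episode stats computation

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    # Submit everything up front, keeping a future -> index map
    future_to_idx = {executor.submit(work, i): i for i in range(8)}
    results = {}
    for future in concurrent.futures.as_completed(future_to_idx):
        results[future_to_idx[future]] = future.result()  # completion order is arbitrary

ordered = [results[i] for i in range(8)]  # restore episode order before aggregation
print(ordered)  # [0, 1, 4, 9, 16, 25, 36, 49]
```

Collecting into a dict keyed by episode index and re-ordering afterwards is what makes the parallel version deterministic despite out-of-order completion.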
def augment_dataset_with_quantile_stats(
repo_id: str,
root: str | Path | None = None,
overwrite: bool = False,
) -> None:
"""Augment a dataset with quantile statistics if they are missing.
Args:
repo_id: Repository ID of the dataset
root: Local root directory for the dataset
overwrite: Overwrite existing quantile statistics if they already exist
"""
logging.info(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(
repo_id=repo_id,
root=root,
)
if not overwrite and has_quantile_stats(dataset.meta.stats):
logging.info("Dataset already contains quantile statistics. No action needed.")
return
logging.info("Dataset does not contain quantile statistics. Computing them now...")
new_stats = compute_quantile_stats_for_dataset(dataset)
logging.info("Updating dataset metadata with new quantile statistics")
dataset.meta.stats = new_stats
write_stats(new_stats, dataset.meta.root)
logging.info("Successfully updated dataset with quantile statistics")
dataset.push_to_hub()
def main():
"""Main function to run the augmentation script."""
parser = argparse.ArgumentParser(description="Augment LeRobot dataset with quantile statistics")
parser.add_argument(
"--repo-id",
type=str,
required=True,
help="Repository ID of the dataset (e.g., 'lerobot/pusht')",
)
parser.add_argument(
"--root",
type=str,
help="Local root directory for the dataset",
)
parser.add_argument(
"--overwrite",
action="store_true",
help="Overwrite existing quantile statistics if they already exist",
)
args = parser.parse_args()
root = Path(args.root) if args.root else None
init_logging()
augment_dataset_with_quantile_stats(
repo_id=args.repo_id,
root=root,
overwrite=args.overwrite,
)
if __name__ == "__main__":
main()


@@ -15,7 +15,7 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .pi05.configuration_pi05 import PI05Config as PI05Config
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -25,6 +25,7 @@ __all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"PI05Config",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",


@@ -32,6 +32,7 @@ from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -81,14 +82,18 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
return PI05Policy
elif name == "sac":
from lerobot.policies.sac.modeling_sac import SACPolicy
@@ -132,10 +137,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi05":
return PI05Config(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
elif policy_type == "smolvla":
@@ -253,6 +260,14 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
processors = make_pi0fast_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
@@ -261,10 +276,10 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI05Config):
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors
processors = make_pi05_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)


@@ -0,0 +1,49 @@
# π₀ (pi0)
This repository contains the Hugging Face port of **π₀**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model for general robot control**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{black2024pi0visionlanguageactionflowmodel,
title = {π₀: A Vision-Language-Action Flow Model for General Robot Control},
author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
year = {2024},
eprint = {2410.24164},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2410.24164},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).


@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi0 import PI0Config
from .modeling_pi0 import PI0Policy
from .processor_pi0 import make_pi0_pre_post_processors
__all__ = ["PI0Config", "PI0Policy", "make_pi0_pre_post_processors"]


@@ -1,4 +1,6 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,20 +19,40 @@ from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10 # Number of denoising steps during inference
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0
# Normalization
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
@@ -39,94 +61,75 @@ class PI0Config(PreTrainedConfig):
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
# TODO: Add EMA
tokenizer_max_length: int = 48 # see openpi `__post_init__`
def __post_init__(self):
super().__post_init__()
"""Input validation (not exhaustive)."""
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"{OBS_IMAGES}.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
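A hedged sketch of what a cosine decay schedule with linear warmup typically computes from the settings above; this is an illustration of the schedule shape, not the exact `CosineDecayWithWarmupSchedulerConfig` implementation:

```python
import math

def cosine_decay_with_warmup(
    step: int,
    peak_lr: float = 2.5e-5,
    warmup_steps: int = 1_000,
    decay_steps: int = 30_000,
    decay_lr: float = 2.5e-6,
) -> float:
    if step < warmup_steps:
        return peak_lr * step / warmup_steps  # linear warmup from 0 to peak_lr
    # Cosine from peak_lr down to decay_lr, clamped after decay_steps
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return decay_lr + (peak_lr - decay_lr) * cosine

print(cosine_decay_with_warmup(1_000))   # peak: 2.5e-05
print(cosine_decay_with_warmup(30_000))  # floor: 2.5e-06
```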
def get_scheduler_preset(self):


@@ -1,82 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy

torch.backends.cudnn.benchmark = True


def main():
    device = "cuda"
    dataset_repo_id = "danaaubakirova/koch_test"
    # model_name = "pi0_base"
    # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    ckpt_torch_dir = "lerobot/pi0"

    dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
    dataloader = torch.utils.data.DataLoader(
        dataset,
        num_workers=0,
        batch_size=1,
    )
    batch = next(iter(dataloader))

    # To device
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device=device, dtype=torch.float32)

    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
    cfg.pretrained_path = ckpt_torch_dir
    policy = make_policy(cfg, ds_meta=dataset.meta)

    # policy = torch.compile(policy, mode="reduce-overhead")

    warmup_iters = 10
    benchmark_iters = 30

    # Warmup
    for _ in range(warmup_iters):
        torch.cuda.synchronize()
        policy.select_action(batch)
        policy.reset()
        torch.cuda.synchronize()

    # Benchmark
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(benchmark_iters):
        policy.select_action(batch)
        policy.reset()
    end_event.record()

    # Synchronize and measure time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    avg_time_per_iter = elapsed_time_ms / benchmark_iters
    print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")


if __name__ == "__main__":
    with torch.inference_mode():
        main()
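The warmup-then-average pattern above is GPU-accurate because it uses CUDA events. For a quick CPU-side sanity check of the same structure, a wall-clock sketch can look like this (hypothetical helper, not part of the script; use CUDA events for real GPU timings):

```python
import time

# Wall-clock analogue of the CUDA-event benchmark above: warm up first so
# one-time costs don't pollute the measurement, then average over iterations.
def benchmark(fn, warmup_iters=10, benchmark_iters=30):
    for _ in range(warmup_iters):
        fn()
    start = time.perf_counter()
    for _ in range(benchmark_iters):
        fn()
    elapsed_ms = (time.perf_counter() - start) * 1000
    return elapsed_ms / benchmark_iters

avg = benchmark(lambda: sum(range(10_000)))
print(f"Average execution time per iteration: {avg:.3f} ms")
```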


@@ -1,132 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
from pathlib import Path

import torch

from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy
from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE


def display(tensor: torch.Tensor):
    if tensor.dtype == torch.bool:
        tensor = tensor.float()
    print(f"Shape: {tensor.shape}")
    print(f"Mean: {tensor.mean().item()}")
    print(f"Std: {tensor.std().item()}")
    print(f"Min: {tensor.min().item()}")
    print(f"Max: {tensor.max().item()}")


def main():
    num_motors = 14
    device = "cuda"
    # model_name = "pi0_aloha_towel"
    model_name = "pi0_aloha_sim"

    if model_name == "pi0_aloha_towel":
        dataset_repo_id = "lerobot/aloha_static_towel"
    else:
        dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"

    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
    ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
    save_dir = Path(f"../openpi/data/{model_name}/save")

    with open(save_dir / "example.pkl", "rb") as f:
        example = pickle.load(f)
    with open(save_dir / "outputs.pkl", "rb") as f:
        outputs = pickle.load(f)
    with open(save_dir / "noise.pkl", "rb") as f:
        noise = pickle.load(f)
    with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
        norm_stats = json.load(f)

    # Override stats
    dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
    dataset_meta.stats[OBS_STATE]["mean"] = torch.tensor(
        norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
    )
    dataset_meta.stats[OBS_STATE]["std"] = torch.tensor(
        norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
    )

    # Create LeRobot batch from Jax
    batch = {}
    for cam_key, uint_chw_array in example["images"].items():
        batch[f"{OBS_IMAGES}.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
    batch[OBS_STATE] = torch.from_numpy(example["state"])
    batch[ACTION] = torch.from_numpy(outputs["actions"])
    batch["task"] = example["prompt"]

    if model_name == "pi0_aloha_towel":
        del batch[f"{OBS_IMAGES}.cam_low"]
    elif model_name == "pi0_aloha_sim":
        batch[f"{OBS_IMAGES}.top"] = batch[f"{OBS_IMAGES}.cam_high"]
        del batch[f"{OBS_IMAGES}.cam_high"]

    # Batchify
    for key in batch:
        if isinstance(batch[key], torch.Tensor):
            batch[key] = batch[key].unsqueeze(0)
        elif isinstance(batch[key], str):
            batch[key] = [batch[key]]
        else:
            raise ValueError(f"{key}, {batch[key]}")

    # To device
    for k in batch:
        if isinstance(batch[k], torch.Tensor):
            batch[k] = batch[k].to(device=device, dtype=torch.float32)
    noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)

    from lerobot import policies  # noqa

    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
    cfg.pretrained_path = ckpt_torch_dir
    policy = make_policy(cfg, dataset_meta)

    # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
    # loss_dict["loss"].backward()
    # print("losses")
    # display(loss_dict["losses_after_forward"])
    # print("pi_losses")
    # display(pi_losses)

    actions = []
    for _ in range(50):
        action = policy.select_action(batch, noise=noise)
        actions.append(action)

    actions = torch.stack(actions, dim=1)
    pi_actions = batch[ACTION]
    print("actions")
    display(actions)
    print()
    print("pi_actions")
    display(pi_actions)
    print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
    print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
    print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))


if __name__ == "__main__":
    main()


@@ -1,84 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import GemmaConfig, PaliGemmaConfig


def get_paligemma_config(precision: str):
    config = {
        "image_token_index": None,
        "pad_token_id": 0,
        "bos_token_id": 2,
        "eos_token_id": 1,
    }

    # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
    image_size = 224  # image_sizes[variant]
    patch_size = 14
    num_image_tokens = (image_size**2) // (patch_size**2)

    config["image_token_index"] = 257152
    text_config = {
        "vocab_size": 257152,
        "num_hidden_layers": 18,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "torch_dtype": precision,
        "hidden_size": 2048,
        "hidden_activation": "gelu_pytorch_tanh",
        "num_attention_heads": 8,
        "intermediate_size": 16384,
        "is_encoder_decoder": False,
    }
    vision_config = {
        "torch_dtype": precision,
        "image_size": image_size,
        "patch_size": patch_size,
        "num_image_tokens": num_image_tokens,
        "hidden_size": 1152,
        "intermediate_size": 4304,
        "num_hidden_layers": 27,
        "num_attention_heads": 16,
        "projector_hidden_act": "gelu_fast",
        "vision_use_head": False,
    }
    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
    return final_config


def get_gemma_config(precision: str):
    config = {
        "image_token_index": None,
        "pad_token_id": 0,
        "bos_token_id": 2,
        "eos_token_id": 1,
    }

    config["image_token_index"] = 257152
    text_config = {
        "vocab_size": 257152,
        "num_hidden_layers": 18,
        "num_key_value_heads": 1,
        "head_dim": 256,
        "torch_dtype": precision,
        "hidden_size": 1024,
        "hidden_activation": "gelu_pytorch_tanh",
        "num_attention_heads": 8,
        "intermediate_size": 4096,
        "is_encoder_decoder": False,
    }
    final_config = GemmaConfig()
    final_config.update(text_config)
    return final_config
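As a sanity check on the config arithmetic above (illustrative, not part of the conversion utils): a 224x224 image split into 14x14 patches yields exactly the 256 image tokens that both the text and vision configs assume.

```python
# 224 / 14 = 16 patches per side, so 16 * 16 = 256 image tokens, matching
# num_image_tokens in the configs above.
image_size = 224
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
print(num_image_tokens)  # 256
```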


@@ -1,437 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
```bash
cd ~/code/openpi
source .venv/bin/activate
```
Example downloading parameters:
```bash
python
>>> import openpi.shared.download as download
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
>>> download.maybe_download(path)
```
Converting pi0_base:
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
```
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
```
"""
import argparse
import pathlib

import jax
import numpy as np
import orbax.checkpoint as ocp
import torch
from jax.sharding import SingleDeviceSharding

from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
    get_gemma_config,
    get_paligemma_config,
)
from lerobot.policies.pi0.modeling_pi0 import PI0Policy

PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
def slice_paligemma_state_dict(state_dict, config):
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # fmt: off
    # patch embeddings
    state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
        3, 2, 0, 1
    )
    state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
    # positional embeddings
    state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
        -1, config.vision_config.hidden_size
    )

    # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
    encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
    encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
    encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
    encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
    encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
    encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
    encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")

    for i in range(config.vision_config.num_hidden_layers):
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]

        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]

        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
    state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")

    # multimodal projector
    state_dict["paligemma.multi_modal_projector.linear.weight"] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
    state_dict["paligemma.multi_modal_projector.linear.bias"] = state_dict.pop(f"img/head/bias{suffix}")

    # text decoder (gemma)
    embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
    state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector

    # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
    # TODO verify correctness of layer norm loading
    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        # llm_attention_q_einsum[i].shape = (8, 2048, 256)
        q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
        state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped

        # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
        # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped

        # output projection.
        # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
        o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
        state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
        # mlp layers
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
        state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
        state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]

    state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
    state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector  # weights are tied.
    # fmt: on

    expert_dict = {}
    final_state_dict = {}
    for key, value in state_dict.items():
        if key not in [
            f"llm/final_norm_1/scale{suffix}",
            f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
            f"llm/layers/attn/kv_einsum_1/w{suffix}",
            f"llm/layers/attn/q_einsum_1/w{suffix}",
            f"llm/layers/mlp_1/gating_einsum{suffix}",
            f"llm/layers/mlp_1/linear{suffix}",
            f"llm/layers/pre_attention_norm_1/scale{suffix}",
            f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        ]:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, num_expert=1):
    # fmt: off
    # text decoder (gemma)
    # no embedding vector, the expert just has the decoder layers
    embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
    state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector

    # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
    # TODO verify correctness of layer norm loading
    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped

        # output projection.
        # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
        o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
        state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
        # mlp layers
        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
        state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
        state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
        state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]

    state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
    state_dict["gemma_expert.lm_head.weight"] = embedding_vector  # weights are tied. (and zeros here)
    # fmt: on

    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value
    return final_state_dict
def flatten_for_memory(tree, parent_key=""):
    out = {}
    for k, v in tree.items():
        new_key = f"{parent_key}/{k}" if parent_key else k
        if isinstance(v, dict):
            out.update(flatten_for_memory(v, new_key))
        else:
            out[new_key] = np.array(v)  # Ensure conversion to np.array for consistency
    return out


def flatten_for_npz(tree, parent_key=""):
    out = {}
    for k, v in tree.items():
        new_key = f"{parent_key}/{k}" if parent_key else k
        if isinstance(v, dict):
            out.update(flatten_for_npz(v, new_key))
        else:
            # bf16/f32 here?
            out[new_key] = np.array(v)
    return out
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
    params_path = pathlib.Path(checkpoint_dir).resolve()
    checkpointer = ocp.PyTreeCheckpointer()

    metadata = checkpointer.metadata(params_path)
    print("Metadata keys:", list(metadata.keys()))

    params_name = "params"

    item = {params_name: metadata[params_name]}
    device = jax.local_devices()[0]  # Use the first local device
    sharding = SingleDeviceSharding(device)
    restored = checkpointer.restore(
        params_path,
        ocp.args.PyTreeRestore(
            item=item,
            restore_args=jax.tree_util.tree_map(
                lambda _: ocp.ArrayRestoreArgs(
                    restore_type=jax.Array,  # or np.ndarray, but bf16 is annoying about it
                    sharding=sharding,
                ),
                item,
            ),
            transforms={},
        ),
    )
    params = restored[params_name]

    # get params for PaliGemma
    pali_params = params["PaliGemma"]
    del params["PaliGemma"]
    pali_params_flat = flatten_for_npz(pali_params)

    return {"paligemma_params": pali_params_flat, "projection_params": params}


def update_keys_with_prefix(d: dict, prefix: str) -> dict:
    """Update dictionary keys by adding a prefix."""
    return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
    # Break down orbax ckpts - they are in OCDBT
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)

    # process projection params
    keys = [
        "state_proj",
        "action_in_proj",
        "action_out_proj",
        "action_time_mlp_in",
        "action_time_mlp_out",
    ]
    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        if isinstance(kernel_params, dict):
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params
        projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
        projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))

    # Process PaliGemma weights
    paligemma_config = get_paligemma_config(precision)
    paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
        initial_params["paligemma_params"], paligemma_config
    )

    # Process Gemma weights (at this stage they are unused)
    gemma_config = get_gemma_config(precision)
    gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)

    # Instantiate model from configs
    if "pi0_aloha_sim" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=2,
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=False,
        )
    elif "pi0_aloha_towel" in checkpoint_dir:
        pi0_config = PI0Config(
            adapt_to_pi_aloha=True,
            use_delta_joint_actions_aloha=True,
        )
    elif "pi0_base" in checkpoint_dir:
        pi0_config = PI0Config(
            empty_cameras=0,
            adapt_to_pi_aloha=False,
            use_delta_joint_actions_aloha=False,
        )
    else:
        raise ValueError(f"Unrecognized checkpoint dir: {checkpoint_dir}")

    # gemma_config=gemma_config, paligemma_config=paligemma_config)
    pi0_model = PI0Policy(pi0_config)

    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
    projection_params = update_keys_with_prefix(projection_params, "model.")

    # load state dict
    torch_dtype = PRECISIONS[precision]
    pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
    pi0_model = pi0_model.to(torch_dtype)
    # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    pi0_model.save_pretrained(output_path, safe_serialization=True)
    # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)

    # assert that model loads properly
    del pi0_model
    PI0Policy.from_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
type=str,
help="Path to the ocdbt checkpoint",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
# tokenizer is identical to paligemma, it appears
parser.add_argument(
"--tokenizer_hub_id",
default="google/paligemma-3b-pt-224",
type=str,
help="Hub path to the tokenizer to save",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Path to save converted weights to",
)
args = parser.parse_args()
convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir,
precision=args.precision,
tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path,
)
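The conversion above leans on the "/"-joined key flattening from `flatten_for_npz` to turn the nested Jax param tree into flat state-dict-style keys. A minimal standalone illustration of that scheme, with plain values standing in for the numpy arrays:

```python
# Minimal illustration of the key-flattening scheme used by the conversion
# script: nested dicts collapse into "/"-joined keys.
def flatten(tree, parent_key=""):
    out = {}
    for k, v in tree.items():
        new_key = f"{parent_key}/{k}" if parent_key else k
        if isinstance(v, dict):
            out.update(flatten(v, new_key))
        else:
            out[new_key] = v
    return out

params = {"img": {"embedding": {"kernel": 1, "bias": 2}}, "llm": {"final_norm": {"scale": 3}}}
print(flatten(params))
# {'img/embedding/kernel': 1, 'img/embedding/bias': 2, 'llm/final_norm/scale': 3}
```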


@@ -1,141 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F  # noqa: N812
from packaging.version import Version

if Version(torch.__version__) > Version("2.5.0"):
    # Flex attention is only available from torch 2.5 onwards
    from torch.nn.attention.flex_attention import (
        _mask_mod_signature,
        _round_up_to_multiple,
        create_block_mask,
        create_mask,
        flex_attention,
    )


# @torch.compile(dynamic=False)
def flex_attention_forward(
    attention_mask: torch.Tensor,
    batch_size: int,
    head_dim: int,
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    scaling=None,
):
    """
    This is defined out of classes to make compile happy.
    """
    original_dtype = query_states.dtype
    num_att_heads = 8
    num_key_value_heads = 1
    num_key_value_groups = num_att_heads // num_key_value_heads

    key_states = key_states[:, :, :, None, :]
    key_states = key_states.expand(
        batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    key_states = key_states.reshape(
        batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    value_states = value_states[:, :, :, None, :]
    value_states = value_states.expand(
        batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
    )
    value_states = value_states.reshape(
        batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
    )

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    query_states = query_states.to(torch.float32)
    key_states = key_states.to(torch.float32)
    value_states = value_states.to(torch.float32)

    causal_mask = attention_mask
    if causal_mask is not None:
        causal_mask = causal_mask[:, None, :, : key_states.shape[2]]

        if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
            causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)

    def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
        def mask_mod(b, h, q_idx, kv_idx):
            # Danger zone: if b, h, q_idx, kv_idx exceed the shape, a device-side assert occurs.
            return precomputed_mask[b][h][q_idx][kv_idx]

        return mask_mod

    b_mask, h_mask, q_len, kv_len = causal_mask.shape  # The shape of your mask

    block_size = 128
    q_len_rounded = _round_up_to_multiple(q_len, block_size)
    kv_len_rounded = _round_up_to_multiple(kv_len, block_size)

    # *CRITICAL* we do need to expand here, else we get a CUDA index error
    pad_q = q_len_rounded - q_len
    pad_k = kv_len_rounded - kv_len

    padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
    mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)

    mask_4d = create_mask(
        mod_fn=mask_mod_fn_orig,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        device=causal_mask.device,
        _compile=False,
    )

    mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
    block_mask = create_block_mask(
        mask_mod=mask_mod_fn_padded,
        B=b_mask,
        H=h_mask,
        Q_LEN=q_len_rounded,
        KV_LEN=kv_len_rounded,
        BLOCK_SIZE=block_size,
        device=causal_mask.device,
        _compile=False,
    )

    # mask is applied inside the kernel, ideally more efficiently than score_mod.
    attn_output, attention_weights = flex_attention(
        query_states,
        key_states,
        value_states,
        block_mask=block_mask,
        enable_gqa=True,  # because we shaped query/key states for GQA
        scale=head_dim**-0.5 if scaling is None else scaling,
        return_lse=True,
    )

    attn_output = attn_output.to(dtype=original_dtype)
    attn_output = attn_output.transpose(1, 2).contiguous()  # [B, Q_LEN, H, head_dim]
    attn_output = attn_output.reshape(
        batch_size,
        -1,
        attn_output.shape[2] * attn_output.shape[3],  # merges [H, head_dim]
    )
    return attn_output
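The block-mask construction above pads query and key/value lengths up to a multiple of the 128-element block size before building the mask. A standalone sketch of that rounding step (mirroring what torch's private `_round_up_to_multiple` helper is used for here):

```python
# Sequence lengths are padded up to the next multiple of the block size so
# the flex-attention block mask tiles evenly; pad_q/pad_k are the amounts
# of zero-padding F.pad must add to the mask.
def round_up_to_multiple(x, multiple):
    return ((x + multiple - 1) // multiple) * multiple

q_len, block_size = 300, 128
q_len_rounded = round_up_to_multiple(q_len, block_size)
pad_q = q_len_rounded - q_len
print(q_len_rounded, pad_q)  # 384 84
```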

File diff suppressed because it is too large


@@ -1,420 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.version
from torch import nn
from transformers import (
    AutoConfig,
    Cache,  # KV cache base class comes from transformers (was mistakenly imported from pytest)
    GemmaForCausalLM,
    PaliGemmaForConditionalGeneration,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING

from lerobot.policies.pi0.flex_attention import flex_attention_forward
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
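For intuition, the same rotation can be mirrored in NumPy and checked for its defining norm-preserving property. This is a sketch for verification only, not the implementation used at runtime:

```python
import numpy as np

def apply_rope_np(x, positions, max_wavelength=10_000):
    """NumPy mirror of apply_rope: rotate each coordinate pair (x1[i], x2[i])
    by a position-dependent angle; purely for sanity checking."""
    d_half = x.shape[-1] // 2
    freq_exponents = (2.0 / x.shape[-1]) * np.arange(d_half, dtype=np.float64)
    timescale = max_wavelength ** freq_exponents
    radians = (positions[..., None] / timescale)[..., None, :]  # [B, L, 1, d_half]
    sin, cos = np.sin(radians), np.cos(radians)
    x1, x2 = x[..., :d_half], x[..., d_half:]
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

x = np.random.randn(2, 6, 4, 8)                    # [B, L, H, D]
positions = np.broadcast_to(np.arange(6), (2, 6))  # RoPE positions [B, L]
rotated = apply_rope_np(x, positions)
# A 2D rotation per coordinate pair leaves the overall vector norm unchanged
assert np.allclose(np.linalg.norm(rotated, axis=-1), np.linalg.norm(x, axis=-1))
```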
class PaliGemmaWithExpertConfig(PretrainedConfig):
model_type = "PaliGemmaWithExpertModel"
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
def __init__(
self,
paligemma_config: dict | None = None,
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_implementation = attention_implementation
if paligemma_config is None:
# Default config from Pi0
self.paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": "float32",
"vocab_size": 257152,
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_fast",
"torch_dtype": "float32",
"vision_use_head": False,
},
)
elif isinstance(paligemma_config, dict):
# Override Pi0 default config for PaliGemma
if "model_type" not in paligemma_config:
paligemma_config["model_type"] = "paligemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.paligemma_config = cfg_cls(**paligemma_config)
if gemma_expert_config is None:
# Default config from Pi0
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation="gelu_pytorch_tanh",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.48.1",
use_cache=True,
vocab_size=257152,
)
elif isinstance(gemma_expert_config, dict):
# Override Pi0 default config for Gemma Expert
if "model_type" not in gemma_expert_config:
gemma_expert_config["model_type"] = "gemma"
cfg_cls = CONFIG_MAPPING[gemma_expert_config["model_type"]]
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
super().__init__(**kwargs)
def __post_init__(self):
super().__post_init__()
if self.train_expert_only and not self.freeze_vision_encoder:
raise ValueError(
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
)
if self.attention_implementation not in ["eager", "fa2", "flex"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for params in self.paligemma.vision_tower.parameters():
params.requires_grad = False
if self.config.train_expert_only:
self.paligemma.eval()
for params in self.paligemma.parameters():
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.config.train_expert_only:
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
params_to_change_dtype = [
"language_model.model.layers",
"gemma_expert.model.layers",
"vision_tower",
"multi_modal",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.paligemma, "get_image_features"):
return self.paligemma.get_image_features(image)
else:
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
inputs_embeds: list[torch.FloatTensor] | None = None,
use_cache: bool | None = None,
fill_kv_cache: bool | None = None,
):
models = [self.paligemma.language_model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.paligemma.config.text_config.num_hidden_layers
head_dim = self.paligemma.config.text_config.head_dim
for layer_idx in range(num_layers):
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
continue
layer = models[i].layers[layer_idx]
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
query_states = apply_rope(query_states, position_ids)
key_states = apply_rope(key_states, position_ids)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
if hidden_states is not None:
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
# TODO: first dropout (by default 0.0)
# first residual
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
# TODO: second dropout (by default 0.0)
# second residual
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
if self.config.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.config.attention_implementation == "flex":
attention_interface = flex_attention_forward
else:
attention_interface = self.eager_attention_forward
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
# probs: batch_size, num_att_heads, sequence_length, sequence_length
# value_states: batch_size, sequence_length, num_att_heads, head_dim
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
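The key/value expansion above implements grouped-query attention by repeating each KV head for its group of query heads. A small NumPy sketch of the same expand-and-reshape (dimensions are illustrative):

```python
import numpy as np

# GQA expansion sketch: repeat each of the num_key_value_heads
# num_key_value_groups times so keys/values align with the query heads.
batch, seq, kv_heads, groups, head_dim = 1, 3, 2, 4, 8
k = np.random.randn(batch, seq, kv_heads, head_dim)
k_exp = np.broadcast_to(
    k[:, :, :, None, :], (batch, seq, kv_heads, groups, head_dim)
).reshape(batch, seq, kv_heads * groups, head_dim)
# Expanded heads 0..3 are copies of KV head 0; heads 4..7 copy KV head 1
assert np.allclose(k_exp[:, :, 0], k[:, :, 0])
assert np.allclose(k_exp[:, :, groups], k[:, :, 1])
```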


@@ -0,0 +1,49 @@
# π₀.₅ (pi05)
This repository contains the Hugging Face port of **π₀.₅**, adapted from [OpenPI](https://github.com/Physical-Intelligence/openpi) by Physical Intelligence.
It is designed as a **Vision-Language-Action model with open-world generalization**.
---
## Model Overview
| Feature | π₀ | π₀.₅ |
| -------------------- | ------------------------------------------------------ | ----------------------------------------- |
| Time Conditioning | Concatenates time with actions via `action_time_mlp_*` | Uses `time_mlp_*` for AdaRMS conditioning |
| AdaRMS | Not used | Used in action expert |
| Tokenizer Length | 48 tokens | 200 tokens |
| Discrete State Input | False (Uses `state_proj` layer) | True |
| Parameter Count | Higher (includes state embedding) | Lower (no state embedding) |
---
## Citation
If you use this work, please cite both **OpenPI** and the π₀.₅ paper:
```bibtex
@misc{openpi2024,
author = {Physical Intelligence Lab},
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
year = {2024},
publisher = {GitHub},
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
license = {Apache-2.0}
}
@misc{intelligence2025pi05visionlanguageactionmodelopenworld,
title = {π₀.₅: a Vision-Language-Action Model with Open-World Generalization},
author = {Physical Intelligence and Kevin Black and Noah Brown and James Darpinian and Karan Dhabalia and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Manuel Y. Galliker and Dibya Ghosh and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Devin LeBlanc and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Allen Z. Ren and Lucy Xiaoyang Shi and Laura Smith and Jost Tobias Springenberg and Kyle Stachowicz and James Tanner and Quan Vuong and Homer Walke and Anna Walling and Haohuan Wang and Lili Yu and Ury Zhilinsky},
year = {2025},
eprint = {2504.16054},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
url = {https://arxiv.org/abs/2504.16054},
}
```
---
## License
This port follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).


@@ -0,0 +1,21 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .configuration_pi05 import PI05Config
from .modeling_pi05 import PI05Policy
from .processor_pi05 import make_pi05_pre_post_processors
__all__ = ["PI05Config", "PI05Policy", "make_pi05_pre_post_processors"]


@@ -0,0 +1,153 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05Config(PreTrainedConfig):
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
dtype: str = "float32" # Options: "bfloat16", "float32"
n_obs_steps: int = 1
chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon"
n_action_steps: int = 50 # Number of action steps to execute
# Shorter state and action vectors will be padded to these dimensions
max_state_dim: int = 32
max_action_dim: int = 32
# Flow matching parameters: see openpi `PI0Pytorch`
num_inference_steps: int = 10
time_sampling_beta_alpha: float = 1.5
time_sampling_beta_beta: float = 1.0
time_sampling_scale: float = 0.999
time_sampling_offset: float = 0.001
min_period: float = 4e-3
max_period: float = 4.0
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
# Number of empty (dummy) cameras to add when no image features are present.
empty_cameras: int = 0
tokenizer_max_length: int = 200 # see openpi `__post_init__`
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for state
"ACTION": NormalizationMode.QUANTILES, # Pi0.5 uses quantiles for action
}
)
# Training settings
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
device: str | None = None # Device to use for the model (None = auto-detect)
# Optimizer settings: see openpi `AdamW`
optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr`
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
# Scheduler settings: see openpi `CosineDecaySchedule`
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
def __post_init__(self):
super().__post_init__()
# Validate configuration
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
)
if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")
if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]:
raise ValueError(f"Invalid action_expert_variant: {self.action_expert_variant}")
if self.dtype not in ["bfloat16", "float32"]:
raise ValueError(f"Invalid dtype: {self.dtype}")
def validate_features(self) -> None:
"""Validate and set up input/output features."""
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, *self.image_resolution), # Use configured image resolution
)
self.input_features[key] = empty_camera
if "observation.state" not in self.input_features:
state_feature = PolicyFeature(
type=FeatureType.STATE,
shape=(self.max_state_dim,), # Padded to max_state_dim
)
self.input_features["observation.state"] = state_feature
if "action" not in self.output_features:
action_feature = PolicyFeature(
type=FeatureType.ACTION,
shape=(self.max_action_dim,), # Padded to max_action_dim
)
self.output_features["action"] = action_feature
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
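The flow-matching fields in this config (`time_sampling_beta_alpha`, `time_sampling_beta_beta`, `time_sampling_scale`, `time_sampling_offset`) suggest the following sampling scheme for the flow-matching time variable. This is a hedged sketch of how the parameters compose, not necessarily the exact openpi formula:

```python
import random

def sample_flow_time(alpha=1.5, beta=1.0, scale=0.999, offset=0.001):
    # Draw t ~ Beta(alpha, beta) on [0, 1], then squeeze it into
    # [offset, scale + offset] to keep t strictly away from 0.
    t = random.betavariate(alpha, beta)
    return t * scale + offset

samples = [sample_flow_time() for _ in range(2000)]
assert all(0.0 < t <= 1.0 for t in samples)
```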

File diff suppressed because it is too large


@@ -0,0 +1,171 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import numpy as np
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.processor.core import EnvTransition, TransitionKey
from lerobot.utils.constants import (
OBS_STATE,
POLICY_POSTPROCESSOR_DEFAULT_NAME,
POLICY_PREPROCESSOR_DEFAULT_NAME,
)
@ProcessorStepRegistry.register(name="pi05_prepare_state_tokenizer_processor_step")
@dataclass
class Pi05PrepareStateTokenizerProcessorStep(ProcessorStep):
"""
Processor step to prepare the state and tokenize the language input.
"""
max_state_dim: int = 32
task_key: str = "task"
def __call__(self, transition: EnvTransition) -> EnvTransition:
transition = transition.copy()
state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
if state is None:
raise ValueError("State is required for PI05")
tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
if tasks is None:
raise ValueError("No task found in complementary data")
# TODO: check if this copy is necessary
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
state = pad_vector(state, self.max_state_dim)
# State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
return transition
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
"""
return features
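Standalone, the state-to-prompt conversion performed by this step works as follows (illustrative values; the real step reads the state and task from an `EnvTransition`):

```python
import numpy as np

# Discretize a normalized state in [-1, 1] into 256 bins and render the
# π0.5-style text prompt (see openpi `PaligemmaTokenizer.tokenize()`).
state = np.array([-1.0, 0.0, 0.999])
bins = np.linspace(-1, 1, 256 + 1)[:-1]           # 256 left bin edges
discretized = np.digitize(state, bins=bins) - 1   # integers in [0, 255]
task = "pick up the cube"
prompt = f"Task: {task}, State: {' '.join(map(str, discretized))};\nAction: "
assert list(discretized) == [0, 128, 255]
```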
def make_pi05_pre_post_processors(
config: PI05Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Build the pre-processing pipeline
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)


@@ -1,3 +1,19 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig


@@ -303,6 +303,65 @@ def clean_state_dict(
return new_state_dict
def load_state_dict_with_missing_key_handling(
policy: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
policy_type: str,
known_missing_keys_whitelist: dict[str, list[str]],
) -> list[str]:
"""
Load state dict into policy with graceful handling of missing keys.
This function loads the state dict with strict=False, filters out whitelisted
missing keys, and provides detailed reporting about any issues found.
Args:
policy: The policy model to load the state dict into.
state_dict: The cleaned state dictionary to load.
policy_type: The type of policy (used for whitelist lookup).
known_missing_keys_whitelist: Dictionary mapping policy types to lists of
known acceptable missing keys.
Returns:
List of problematic missing keys that weren't in the whitelist.
"""
# Load the cleaned state dict with strict=False to capture missing/unexpected keys
load_result = policy.load_state_dict(state_dict, strict=False)
# Check for missing keys
missing_keys = load_result.missing_keys
unexpected_keys = load_result.unexpected_keys
# Filter out whitelisted missing keys
policy_type_lower = policy_type.lower()
whitelisted_keys = known_missing_keys_whitelist.get(policy_type_lower, [])
problematic_missing_keys = [key for key in missing_keys if key not in whitelisted_keys]
if missing_keys:
if problematic_missing_keys:
print(f"WARNING: Found {len(problematic_missing_keys)} unexpected missing keys:")
for key in problematic_missing_keys:
print(f" - {key}")
if len(missing_keys) > len(problematic_missing_keys):
whitelisted_missing = [key for key in missing_keys if key in whitelisted_keys]
print(f"INFO: Found {len(whitelisted_missing)} expected missing keys (whitelisted):")
for key in whitelisted_missing:
print(f" - {key}")
if unexpected_keys:
print(f"WARNING: Found {len(unexpected_keys)} unexpected keys:")
for key in unexpected_keys:
print(f" - {key}")
if not missing_keys and not unexpected_keys:
print("Successfully loaded cleaned state dict into policy model (all keys matched)")
else:
print("State dict loaded with some missing/unexpected keys (see details above)")
return problematic_missing_keys
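The whitelist filtering above reduces to a simple list comprehension over the keys reported by `load_state_dict(strict=False)`. A minimal sketch (the second missing key is hypothetical, for illustration only):

```python
# Missing keys reported by load_state_dict(strict=False), compared against a
# per-policy whitelist; only un-whitelisted keys are surfaced to the caller.
missing_keys = [
    "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight",
    "model.some_new_head.weight",  # hypothetical key, illustrative only
]
whitelist = {
    "pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
}
whitelisted = whitelist.get("pi0", [])
problematic = [k for k in missing_keys if k not in whitelisted]
assert problematic == ["model.some_new_head.weight"]
```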
def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[str, PolicyFeature]:
"""
Converts a feature dictionary from the old config format to the new `PolicyFeature` format.
@@ -336,9 +395,45 @@ def convert_features_to_policy_features(features_dict: dict[str, dict]) -> dict[
return converted_features
def display_migration_summary_with_warnings(problematic_missing_keys: list[str]) -> None:
"""
Display final migration summary with warnings about problematic missing keys.
Args:
problematic_missing_keys: List of missing keys that weren't in the whitelist.
"""
if not problematic_missing_keys:
return
print("\n" + "=" * 60)
print("IMPORTANT: MIGRATION COMPLETED WITH WARNINGS")
print("=" * 60)
print(
f"The migration was successful, but {len(problematic_missing_keys)} unexpected missing keys were found:"
)
print()
for key in problematic_missing_keys:
print(f" - {key}")
print()
print("These missing keys may indicate:")
print(" • The model architecture has changed")
print(" • Some components were not properly saved in the original model")
print(" • The migration script needs to be updated for this policy type")
print()
print("What to do next:")
print(" 1. Test your migrated model carefully to ensure it works as expected")
print(" 2. If you encounter issues, please open an issue at:")
print(" https://github.com/huggingface/lerobot/issues")
print(" 3. Include this migration log and the missing keys listed above")
print()
print("If the model works correctly despite these warnings, the missing keys")
print("might be expected for your policy type and can be added to the whitelist.")
print("=" * 60)
def load_model_from_hub(
repo_id: str, revision: str | None = None
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any]]:
) -> tuple[dict[str, torch.Tensor], dict[str, Any], dict[str, Any] | None]:
"""
Downloads and loads a model's state_dict and configs from the Hugging Face Hub.
@@ -348,13 +443,12 @@ def load_model_from_hub(
Returns:
A tuple containing the model's state dictionary, the policy configuration,
and the training configuration.
and the training configuration (None if train_config.json is not found).
"""
# Download files.
safetensors_path = hf_hub_download(repo_id=repo_id, filename="model.safetensors", revision=revision)
config_path = hf_hub_download(repo_id=repo_id, filename="config.json", revision=revision)
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
# Load state_dict
state_dict = load_safetensors(safetensors_path)
@@ -363,8 +457,14 @@ def load_model_from_hub(
with open(config_path) as f:
config = json.load(f)
with open(train_config_path) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
try:
train_config_path = hf_hub_download(repo_id=repo_id, filename="train_config.json", revision=revision)
with open(train_config_path) as f:
train_config = json.load(f)
except FileNotFoundError:
print("train_config.json not found - continuing without training configuration")
return state_dict, config, train_config
@@ -410,8 +510,15 @@ def main():
state_dict = load_safetensors(os.path.join(args.pretrained_path, "model.safetensors"))
with open(os.path.join(args.pretrained_path, "config.json")) as f:
config = json.load(f)
with open(os.path.join(args.pretrained_path, "train_config.json")) as f:
train_config = json.load(f)
# Try to load train_config (optional)
train_config = None
train_config_path = os.path.join(args.pretrained_path, "train_config.json")
if os.path.exists(train_config_path):
with open(train_config_path) as f:
train_config = json.load(f)
else:
print("train_config.json not found - continuing without training configuration")
else:
# Hub repository
state_dict, config, train_config = load_model_from_hub(args.pretrained_path, args.revision)
@@ -488,10 +595,20 @@ def main():
policy_class = get_policy_class(policy_type)
policy = policy_class(policy_config)
# Load the cleaned state dict
policy.load_state_dict(new_state_dict, strict=True)
print("Successfully loaded cleaned state dict into policy model")
# Define a whitelist of known missing keys that are acceptable (e.g. due to weight tying) for certain policy types
known_missing_keys_whitelist = {
"pi0": ["model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"],
# Add other policy types and their known missing keys here as needed
}
# Load state dict with graceful missing key handling
problematic_missing_keys = load_state_dict_with_missing_key_handling(
policy=policy,
state_dict=new_state_dict,
policy_type=policy_type,
known_missing_keys_whitelist=known_missing_keys_whitelist,
)
policy.to(torch.float32)
# Create preprocessor and postprocessor using the factory
print("Creating preprocessor and postprocessor using make_pre_post_processors...")
preprocessor, postprocessor = make_pre_post_processors(policy_cfg=policy_config, dataset_stats=stats)
@@ -521,7 +638,9 @@ def main():
# Generate and save model card
print("Generating model card...")
# Get metadata from original config
dataset_repo_id = train_config.get("repo_id", "unknown")
dataset_repo_id = "unknown"
if train_config is not None:
dataset_repo_id = train_config.get("repo_id", "unknown")
license = config.get("license", "apache-2.0")
tags = config.get("tags", ["robotics", "lerobot", policy_type]) or ["robotics", "lerobot", policy_type]
@@ -552,25 +671,25 @@ def main():
if create_pr:
# Separate commit description for PR body
commit_description = """🤖 **Automated Policy Migration to PolicyProcessorPipeline**
commit_description = """**Automated Policy Migration to PolicyProcessorPipeline**
This PR migrates your model to the new LeRobot policy format using the modern PolicyProcessorPipeline architecture.
## What Changed
### **New Architecture - PolicyProcessorPipeline**
### **New Architecture - PolicyProcessorPipeline**
Your model now uses external PolicyProcessorPipeline components for data processing instead of built-in normalization layers. This provides:
- **Modularity**: Separate preprocessing and postprocessing pipelines
- **Flexibility**: Easy to swap, configure, and debug processing steps
- **Compatibility**: Works with the latest LeRobot ecosystem
### 🔧 **Normalization Extraction**
### **Normalization Extraction**
We've extracted normalization statistics from your model's state_dict and removed the built-in normalization layers:
- **Extracted patterns**: `normalize_inputs.*`, `unnormalize_outputs.*`, `normalize.*`, `unnormalize.*`, `input_normalizer.*`, `output_normalizer.*`
- **Statistics preserved**: Mean, std, min, max values for all features
- **Clean model**: State dict now contains only core model weights
### 📦 **Files Added**
### **Files Added**
- **preprocessor_config.json**: Configuration for input preprocessing pipeline
- **postprocessor_config.json**: Configuration for output postprocessing pipeline
- **model.safetensors**: Clean model weights without normalization layers
@@ -578,13 +697,13 @@ We've extracted normalization statistics from your model's state_dict and remove
- **train_config.json**: Training configuration
- **README.md**: Updated model card with migration information
### 🚀 **Benefits**
### **Benefits**
- **Backward Compatible**: Your model behavior remains identical
- **Future Ready**: Compatible with latest LeRobot features and updates
- **Debuggable**: Easy to inspect and modify processing steps
- **Portable**: Processors can be shared and reused across models
### 💻 **Usage**
### **Usage**
```python
# Load your migrated model
from lerobot.policies import get_policy_class
@@ -642,6 +761,9 @@ final_action = postprocessor(action)
else:
print(f"\nView the changes at: https://huggingface.co/{hub_repo_id}")
# Display final summary about any problematic missing keys
display_migration_summary_with_warnings(problematic_missing_keys)
if __name__ == "__main__":
main()


@@ -281,8 +281,14 @@ class _NormalizationMixin:
"""
Core logic to apply a normalization or unnormalization transformation to a tensor.
This method selects the appropriate normalization mode (e.g., mean/std, min/max)
based on the feature type and applies the corresponding mathematical operation.
This method selects the appropriate normalization mode based on the feature type
and applies the corresponding mathematical operation.
Normalization Modes:
- MEAN_STD: Centers data around zero with unit variance.
- MIN_MAX: Scales data to [-1, 1] range using actual min/max values.
- QUANTILES: Scales data to [-1, 1] range using 1st and 99th percentiles (q01/q99).
- QUANTILE10: Scales data to [-1, 1] range using 10th and 90th percentiles (q10/q90).
Args:
tensor: The input tensor to transform.
@@ -300,7 +306,12 @@ class _NormalizationMixin:
if norm_mode == NormalizationMode.IDENTITY or key not in self._tensor_stats:
return tensor
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
if norm_mode not in (
NormalizationMode.MEAN_STD,
NormalizationMode.MIN_MAX,
NormalizationMode.QUANTILES,
NormalizationMode.QUANTILE10,
):
raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
@@ -311,7 +322,14 @@ class _NormalizationMixin:
stats = self._tensor_stats[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
if norm_mode == NormalizationMode.MEAN_STD:
mean = stats.get("mean", None)
std = stats.get("std", None)
if mean is None or std is None:
raise ValueError(
"MEAN_STD normalization mode requires mean and std stats, please update the dataset with the correct stats"
)
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
@@ -319,7 +337,14 @@ class _NormalizationMixin:
return tensor * std + mean
return (tensor - mean) / denom
if norm_mode == NormalizationMode.MIN_MAX and "min" in stats and "max" in stats:
if norm_mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min", None)
max_val = stats.get("max", None)
if min_val is None or max_val is None:
raise ValueError(
"MIN_MAX normalization mode requires min and max stats, please update the dataset with the correct stats"
)
min_val, max_val = stats["min"], stats["max"]
denom = max_val - min_val
# When min_val == max_val, substitute the denominator with a small epsilon
@@ -334,6 +359,40 @@ class _NormalizationMixin:
# Map from [min, max] to [-1, 1]
return 2 * (tensor - min_val) / denom - 1
if norm_mode == NormalizationMode.QUANTILES:
q01 = stats.get("q01", None)
q99 = stats.get("q99", None)
if q01 is None or q99 is None:
raise ValueError(
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
)
denom = q99 - q01
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
if inverse:
return (tensor + 1.0) * denom / 2.0 + q01
return 2.0 * (tensor - q01) / denom - 1.0
if norm_mode == NormalizationMode.QUANTILE10:
q10 = stats.get("q10", None)
q90 = stats.get("q90", None)
if q10 is None or q90 is None:
raise ValueError(
"QUANTILE10 normalization mode requires q10 and q90 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
)
denom = q90 - q10
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
if inverse:
return (tensor + 1.0) * denom / 2.0 + q10
return 2.0 * (tensor - q10) / denom - 1.0
# If necessary stats are missing, return input unchanged.
return tensor
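Both quantile branches above apply the same affine map between the quantile band and [-1, 1]. A minimal standalone sketch of that transform (plain tensors instead of the stats dict; `quantile_normalize` is a hypothetical helper, not part of the mixin):

```python
import torch

def quantile_normalize(x, q_low, q_high, eps=1e-8, inverse=False):
    """Map x from [q_low, q_high] to [-1, 1] (or back when inverse=True)."""
    denom = q_high - q_low
    # Guard against identical quantiles, mirroring the torch.where above.
    denom = torch.where(denom == 0, torch.full_like(denom, eps), denom)
    if inverse:
        return (x + 1.0) * denom / 2.0 + q_low
    return 2.0 * (x - q_low) / denom - 1.0

x = torch.tensor([0.0, 5.0, 10.0])
q01, q99 = torch.tensor(0.0), torch.tensor(10.0)
normed = quantile_normalize(x, q01, q99)  # maps to [-1, 0, 1]
restored = quantile_normalize(normed, q01, q99, inverse=True)  # round-trips back to x
```

The forward and inverse maps compose to the identity, which is what the pipeline relies on when the unnormalizer inverts the normalizer.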


@@ -180,7 +180,8 @@ def train(cfg: TrainPipelineConfig):
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
if not (cfg.resume and cfg.policy.pretrained_path):
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -190,17 +191,22 @@ def train(cfg: TrainPipelineConfig):
"normalizer_processor": {
"stats": dataset.meta.stats,
"features": {**policy.config.input_features, **policy.config.output_features},
"norm_map": policy.config.normalization_mapping,
},
}
processor_kwargs["postprocessor_overrides"] = {
postprocessor_kwargs["postprocessor_overrides"] = {
"unnormalizer_processor": {
"stats": dataset.meta.stats,
"features": {**policy.config.input_features, **policy.config.output_features},
"features": policy.config.output_features,
"norm_map": policy.config.normalization_mapping,
},
}
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, **processor_kwargs
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
logging.info("Creating optimizer and scheduler")


@@ -19,10 +19,28 @@
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
{% elif model_name == "pi0" %}
[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer.
{% elif model_name == "pi0fast" %}
[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
{% elif model_name == "pi0" %}
**π₀ (Pi0)**
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of robots across diverse tasks.
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
{% elif model_name == "pi05" %}
**π₀.₅ (Pi05) Policy**
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
**Model Overview**
π₀.₅ represents a significant evolution from π₀, developed by Physical Intelligence to address a central challenge in robotics: open-world generalization. While robots can perform impressive tasks in controlled environments, π₀.₅ is designed to generalize to entirely new environments and situations never seen during training.
For more details, see the [Physical Intelligence π₀.₅ blog post](https://www.physicalintelligence.company/blog/pi05).
{% elif model_name == "sac" %}
[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments.
{% elif model_name == "reward_classifier" %}


@@ -67,3 +67,6 @@ HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibr
# streaming datasets
LOOKBACK_BACKTRACKTABLE = 100
LOOKAHEAD_BACKTRACKTABLE = 100
# openpi
OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38 # TODO(pepijn): Modify this when extending support to fp8 models
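This constant is a large negative additive bias for masked attention logits: filling masked positions with it drives their softmax weight to zero while staying finite, unlike `float('-inf')`, which yields NaNs when an entire row is masked. A hedged sketch of how such a value is typically applied (the helper below is illustrative, not the OpenPI implementation):

```python
import torch

OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38  # value from the constant above

def masked_attention_weights(logits, mask):
    """Apply a boolean mask (True = attend) as an additive bias before softmax.

    Masked logits are replaced with a huge-but-finite negative number, so
    their softmax weight underflows to ~0 while a fully-masked row still
    produces finite (uniform) weights instead of NaNs.
    """
    fill = torch.full_like(logits, OPENPI_ATTENTION_MASK_VALUE)
    biased = torch.where(mask, logits, fill)
    return torch.softmax(biased, dim=-1)

logits = torch.zeros(1, 4)
mask = torch.tensor([[True, True, False, False]])
w = masked_attention_weights(logits, mask)  # masked positions get ~0 weight
```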


@@ -19,6 +19,7 @@ import numpy as np
import pytest
from lerobot.datasets.compute_stats import (
RunningQuantileStats,
_assert_type_and_shape,
aggregate_feature_stats,
aggregate_stats,
@@ -102,6 +103,9 @@ def test_get_feature_stats_axis_1(sample_array):
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=(1,), keepdims=False)
# Check that basic stats are correct (quantiles are also included now)
assert set(expected.keys()).issubset(set(result.keys()))
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
@@ -115,6 +119,9 @@ def test_get_feature_stats_no_axis(sample_array):
"count": np.array([3]),
}
result = get_feature_stats(sample_array, axis=None, keepdims=False)
# Check that basic stats are correct (quantiles are also included now)
assert set(expected.keys()).issubset(set(result.keys()))
for key in expected:
np.testing.assert_allclose(result[key], expected[key])
@@ -308,3 +315,520 @@ def test_aggregate_stats():
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])
def test_running_quantile_stats_initialization():
"""Test proper initialization of RunningQuantileStats."""
running_stats = RunningQuantileStats()
assert running_stats._count == 0
assert running_stats._mean is None
assert running_stats._num_quantile_bins == 5000
# Test custom bin size
running_stats_custom = RunningQuantileStats(num_quantile_bins=1000)
assert running_stats_custom._num_quantile_bins == 1000
def test_running_quantile_stats_single_batch_update():
"""Test updating with a single batch."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 3))
running_stats = RunningQuantileStats()
running_stats.update(data)
assert running_stats._count == 100
assert running_stats._mean.shape == (3,)
assert len(running_stats._histograms) == 3
assert len(running_stats._bin_edges) == 3
# Verify basic statistics are reasonable
np.testing.assert_allclose(running_stats._mean, np.mean(data, axis=0), atol=1e-10)
def test_running_quantile_stats_multiple_batch_updates():
"""Test updating with multiple batches."""
np.random.seed(42)
data1 = np.random.normal(0, 1, (100, 2))
data2 = np.random.normal(1, 1, (150, 2))
running_stats = RunningQuantileStats()
running_stats.update(data1)
running_stats.update(data2)
assert running_stats._count == 250
# Verify running mean is correct
combined_data = np.vstack([data1, data2])
expected_mean = np.mean(combined_data, axis=0)
np.testing.assert_allclose(running_stats._mean, expected_mean, atol=1e-10)
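The multi-batch test above relies on the standard count-weighted incremental mean update. A minimal sketch of just that update (a hypothetical `RunningMean`, not the actual `RunningQuantileStats` implementation):

```python
import numpy as np

class RunningMean:
    """Count-weighted running mean over batches of vectors (sketch only)."""

    def __init__(self):
        self.count = 0
        self.mean = None

    def update(self, batch: np.ndarray) -> None:
        n = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        if self.mean is None:
            self.count, self.mean = n, batch_mean
        else:
            total = self.count + n
            # The new mean is the count-weighted average of the old and batch means.
            self.mean = (self.count * self.mean + n * batch_mean) / total
            self.count = total

rng = np.random.default_rng(0)
a, b = rng.normal(size=(100, 2)), rng.normal(1.0, 1.0, size=(150, 2))
rm = RunningMean()
rm.update(a)
rm.update(b)
# rm.mean now matches the mean of the concatenated data, up to float error.
```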
def test_running_quantile_stats_get_statistics_basic():
"""Test getting basic statistics without quantiles."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 2))
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# Should have basic stats
expected_keys = {"min", "max", "mean", "std", "count"}
assert expected_keys.issubset(set(stats.keys()))
# Verify values
np.testing.assert_allclose(stats["mean"], np.mean(data, axis=0), atol=1e-10)
np.testing.assert_allclose(stats["std"], np.std(data, axis=0), atol=1e-6)
np.testing.assert_equal(stats["count"], np.array([100]))
def test_running_quantile_stats_get_statistics_with_quantiles():
"""Test getting statistics with quantiles."""
np.random.seed(42)
data = np.random.normal(0, 1, (1000, 2))
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# Should have basic stats plus quantiles
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert expected_keys.issubset(set(stats.keys()))
# Verify quantile values are reasonable
from lerobot.datasets.compute_stats import DEFAULT_QUANTILES
for i, q in enumerate(DEFAULT_QUANTILES):
q_key = f"q{int(q * 100):02d}"
assert q_key in stats
assert stats[q_key].shape == (2,)
# Check that quantiles are in reasonable order
if i > 0:
prev_q_key = f"q{int(DEFAULT_QUANTILES[i - 1] * 100):02d}"
assert np.all(stats[prev_q_key] <= stats[q_key])
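The quantile estimates being tested come from a histogram sketch of the data. The core idea, approximately inverting the empirical CDF built from fixed bins, can be shown with numpy alone (illustrative only; the real class also rebins when the observed range grows):

```python
import numpy as np

def histogram_quantiles(values, qs, num_bins=5000):
    """Estimate quantiles of 1-D data from a fixed-bin histogram.

    Sketch of the idea only: build a histogram, take its cumulative
    distribution, then interpolate the bin centers at the requested levels.
    """
    counts, edges = np.histogram(values, bins=num_bins)
    cdf = np.cumsum(counts) / counts.sum()
    centers = (edges[:-1] + edges[1:]) / 2
    return np.interp(qs, cdf, centers)

rng = np.random.default_rng(42)
data = rng.normal(size=100_000)
q01, q50, q99 = histogram_quantiles(data, [0.01, 0.5, 0.99])
# Estimates land close to the exact np.quantile values for this sample.
```

The accuracy is bounded by the bin width, which is why `num_quantile_bins` trades memory for precision.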
def test_running_quantile_stats_histogram_adjustment():
"""Test that histograms adjust when min/max change."""
running_stats = RunningQuantileStats()
# Initial data with small range
data1 = np.array([[0.0, 1.0], [0.1, 1.1], [0.2, 1.2]])
running_stats.update(data1)
initial_edges_0 = running_stats._bin_edges[0].copy()
initial_edges_1 = running_stats._bin_edges[1].copy()
# Add data with much larger range
data2 = np.array([[10.0, -10.0], [11.0, -11.0]])
running_stats.update(data2)
# Bin edges should have changed
assert not np.array_equal(initial_edges_0, running_stats._bin_edges[0])
assert not np.array_equal(initial_edges_1, running_stats._bin_edges[1])
# New edges should cover the expanded range
# First dimension: min should still be ~0.0, max should be ~11.0
assert running_stats._bin_edges[0][0] <= 0.0
assert running_stats._bin_edges[0][-1] >= 11.0
# Second dimension: min should be ~-11.0, max should be ~1.2
assert running_stats._bin_edges[1][0] <= -11.0
assert running_stats._bin_edges[1][-1] >= 1.2
def test_running_quantile_stats_insufficient_data_error():
"""Test error when trying to get stats with insufficient data."""
running_stats = RunningQuantileStats()
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
running_stats.get_statistics()
# Single vector should also fail
running_stats.update(np.array([[1.0]]))
with pytest.raises(ValueError, match="Cannot compute statistics for less than 2 vectors"):
running_stats.get_statistics()
def test_running_quantile_stats_vector_length_consistency():
"""Test error when vector lengths don't match."""
running_stats = RunningQuantileStats()
running_stats.update(np.array([[1.0, 2.0], [3.0, 4.0]]))
with pytest.raises(ValueError, match="The length of new vectors does not match"):
running_stats.update(np.array([[1.0, 2.0, 3.0]])) # Different length
def test_running_quantile_stats_reshape_handling():
"""Test that various input shapes are handled correctly."""
running_stats = RunningQuantileStats()
# Test 3D input (e.g., images)
data_3d = np.random.normal(0, 1, (10, 32, 32))
running_stats.update(data_3d)
assert running_stats._count == 10 * 32
assert running_stats._mean.shape == (32,)
# Test 1D input
running_stats_1d = RunningQuantileStats()
data_1d = np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
running_stats_1d.update(data_1d)
assert running_stats_1d._count == 5
assert running_stats_1d._mean.shape == (1,)
def test_get_feature_stats_quantiles_enabled_by_default():
"""Test that quantiles are computed by default."""
data = np.random.normal(0, 1, (100, 5))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
def test_get_feature_stats_quantiles_with_vector_data():
"""Test quantile computation with vector data."""
np.random.seed(42)
data = np.random.normal(0, 1, (100, 5))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
# Verify shapes
assert stats["q01"].shape == (5,)
assert stats["q99"].shape == (5,)
# Verify quantiles are reasonable
assert np.all(stats["q01"] < stats["q99"])
def test_get_feature_stats_quantiles_with_image_data():
"""Test quantile computation with image data."""
np.random.seed(42)
data = np.random.normal(0, 1, (50, 3, 32, 32)) # batch, channels, height, width
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats.keys()) == expected_keys
# Verify shapes for images (should be (1, channels, 1, 1))
assert stats["q01"].shape == (1, 3, 1, 1)
assert stats["q50"].shape == (1, 3, 1, 1)
assert stats["q99"].shape == (1, 3, 1, 1)
def test_get_feature_stats_fixed_quantiles():
"""Test that fixed quantiles are always computed."""
data = np.random.normal(0, 1, (200, 3))
stats = get_feature_stats(data, axis=0, keepdims=False)
expected_quantile_keys = {"q01", "q10", "q50", "q90", "q99"}
assert expected_quantile_keys.issubset(set(stats.keys()))
def test_get_feature_stats_unsupported_axis_error():
"""Test error for unsupported axis configuration."""
data = np.random.normal(0, 1, (10, 5))
with pytest.raises(ValueError, match="Unsupported axis configuration"):
get_feature_stats(
data,
axis=(1, 2), # Unsupported axis
keepdims=False,
)
def test_compute_episode_stats_backward_compatibility():
"""Test that existing functionality is preserved."""
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(0, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
stats = compute_episode_stats(episode_data, features)
for key in ["action", "observation.state"]:
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats[key].keys()) == expected_keys
def test_compute_episode_stats_with_custom_quantiles():
"""Test quantile computation with custom quantile values."""
np.random.seed(42)
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(2, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
stats = compute_episode_stats(episode_data, features)
# Should have quantiles
for key in ["action", "observation.state"]:
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(stats[key].keys()) == expected_keys
# Verify shapes
assert stats[key]["q01"].shape == (features[key]["shape"][0],)
assert stats[key]["q99"].shape == (features[key]["shape"][0],)
def test_compute_episode_stats_with_image_data():
"""Test quantile computation with image features."""
image_paths = [f"image_{i}.jpg" for i in range(50)]
episode_data = {
"observation.image": image_paths,
"action": np.random.normal(0, 1, (50, 5)),
}
features = {
"observation.image": {"dtype": "image"},
"action": {"dtype": "float32", "shape": (5,)},
}
with patch("lerobot.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy):
stats = compute_episode_stats(episode_data, features)
# Image quantiles should be normalized and have correct shape
assert "q01" in stats["observation.image"]
assert "q50" in stats["observation.image"]
assert "q99" in stats["observation.image"]
assert stats["observation.image"]["q01"].shape == (3, 1, 1)
assert stats["observation.image"]["q50"].shape == (3, 1, 1)
assert stats["observation.image"]["q99"].shape == (3, 1, 1)
# Action quantiles should have correct shape
assert stats["action"]["q01"].shape == (5,)
assert stats["action"]["q50"].shape == (5,)
assert stats["action"]["q99"].shape == (5,)
def test_compute_episode_stats_string_features_skipped():
"""Test that string features are properly skipped."""
episode_data = {
"task": ["pick_apple"] * 100, # String feature
"action": np.random.normal(0, 1, (100, 5)),
}
features = {
"task": {"dtype": "string"},
"action": {"dtype": "float32", "shape": (5,)},
}
stats = compute_episode_stats(
episode_data,
features,
)
# String features should be skipped
assert "task" not in stats
assert "action" in stats
assert "q01" in stats["action"]
def test_aggregate_feature_stats_with_quantiles():
"""Test aggregating feature stats that include quantiles."""
stats_ft_list = [
{
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([100]),
"q01": np.array([1.5]),
"q99": np.array([9.5]),
},
{
"min": np.array([2.0]),
"max": np.array([12.0]),
"mean": np.array([6.0]),
"std": np.array([2.5]),
"count": np.array([150]),
"q01": np.array([2.5]),
"q99": np.array([11.5]),
},
]
result = aggregate_feature_stats(stats_ft_list)
# Should preserve quantiles
assert "q01" in result
assert "q99" in result
# Verify quantile aggregation (weighted average)
expected_q01 = (1.5 * 100 + 2.5 * 150) / 250 # ≈ 2.1
expected_q99 = (9.5 * 100 + 11.5 * 150) / 250 # ≈ 10.7
np.testing.assert_allclose(result["q01"], np.array([expected_q01]), atol=1e-6)
np.testing.assert_allclose(result["q99"], np.array([expected_q99]), atol=1e-6)
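The expected values above are count-weighted averages of the per-episode quantiles. A generic sketch of that aggregation (a hypothetical helper; averaging quantiles is an approximation, since exact aggregation would require the underlying histograms):

```python
import numpy as np

def aggregate_quantile(stats_list, q_key):
    """Count-weighted average of one quantile key across episode stats dicts."""
    counts = np.array([s["count"][0] for s in stats_list], dtype=float)
    values = np.stack([s[q_key] for s in stats_list])
    # Weight each episode's quantile by its sample count.
    return (values * counts[:, None]).sum(axis=0) / counts.sum()

episodes = [
    {"count": np.array([100]), "q01": np.array([1.5])},
    {"count": np.array([150]), "q01": np.array([2.5])},
]
agg = aggregate_quantile(episodes, "q01")  # (1.5*100 + 2.5*150) / 250 = 2.1
```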
def test_aggregate_stats_mixed_quantiles():
"""Test aggregating stats where some have quantiles and some don't."""
stats_with_quantiles = {
"feature1": {
"min": np.array([1.0]),
"max": np.array([10.0]),
"mean": np.array([5.0]),
"std": np.array([2.0]),
"count": np.array([100]),
"q01": np.array([1.5]),
"q99": np.array([9.5]),
}
}
stats_without_quantiles = {
"feature2": {
"min": np.array([0.0]),
"max": np.array([5.0]),
"mean": np.array([2.5]),
"std": np.array([1.5]),
"count": np.array([50]),
}
}
all_stats = [stats_with_quantiles, stats_without_quantiles]
result = aggregate_stats(all_stats)
# Feature1 should keep its quantiles
assert "q01" in result["feature1"]
assert "q99" in result["feature1"]
# Feature2 should not have quantiles
assert "q01" not in result["feature2"]
assert "q99" not in result["feature2"]
def test_assert_type_and_shape_with_quantiles():
"""Test validation works correctly with quantile keys."""
# Valid stats with quantiles
valid_stats = [
{
"observation.image": {
"min": np.array([0.0, 0.0, 0.0]).reshape(3, 1, 1),
"max": np.array([1.0, 1.0, 1.0]).reshape(3, 1, 1),
"mean": np.array([0.5, 0.5, 0.5]).reshape(3, 1, 1),
"std": np.array([0.2, 0.2, 0.2]).reshape(3, 1, 1),
"count": np.array([100]),
"q01": np.array([0.1, 0.1, 0.1]).reshape(3, 1, 1),
"q99": np.array([0.9, 0.9, 0.9]).reshape(3, 1, 1),
}
}
]
# Should not raise error
_assert_type_and_shape(valid_stats)
# Invalid shape for quantile
invalid_stats = [
{
"observation.image": {
"count": np.array([100]),
"q01": np.array([0.1, 0.2]), # Wrong shape for image quantile
}
}
]
with pytest.raises(ValueError, match="Shape of quantile 'q01' must be \\(3,1,1\\)"):
_assert_type_and_shape(invalid_stats)
def test_quantile_integration_single_value_quantiles():
"""Test quantile computation with single repeated value."""
data = np.ones((100, 3)) # All ones
running_stats = RunningQuantileStats()
running_stats.update(data)
stats = running_stats.get_statistics()
# All quantiles should be approximately 1.0
np.testing.assert_allclose(stats["q01"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
np.testing.assert_allclose(stats["q50"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
np.testing.assert_allclose(stats["q99"], np.array([1.0, 1.0, 1.0]), atol=1e-6)
def test_quantile_integration_fixed_quantiles():
"""Test that fixed quantiles are computed."""
np.random.seed(42)
data = np.random.normal(0, 1, (1000, 2))
stats = get_feature_stats(data, axis=0, keepdims=False)
# Check all fixed quantiles are present
assert "q01" in stats
assert "q10" in stats
assert "q50" in stats
assert "q90" in stats
assert "q99" in stats
def test_quantile_integration_large_dataset_quantiles():
"""Test quantile computation efficiency with large datasets."""
np.random.seed(42)
large_data = np.random.normal(0, 1, (10000, 5))
running_stats = RunningQuantileStats(num_quantile_bins=1000) # Reduced bins for speed
running_stats.update(large_data)
stats = running_stats.get_statistics()
# Should complete without issues and produce reasonable results
assert stats["count"][0] == 10000
assert len(stats["q01"]) == 5
def test_fixed_quantiles_always_computed():
"""Test that the fixed quantiles [0.01, 0.10, 0.50, 0.90, 0.99] are always computed."""
np.random.seed(42)
# Test with vector data
vector_data = np.random.normal(0, 1, (100, 5))
vector_stats = get_feature_stats(vector_data, axis=0, keepdims=False)
# Check all fixed quantiles are present
expected_quantiles = ["q01", "q10", "q50", "q90", "q99"]
for q_key in expected_quantiles:
assert q_key in vector_stats
assert vector_stats[q_key].shape == (5,)
# Test with image data
image_data = np.random.randint(0, 256, (50, 3, 32, 32), dtype=np.uint8)
image_stats = get_feature_stats(image_data, axis=(0, 2, 3), keepdims=True)
# Check all fixed quantiles are present for images
for q_key in expected_quantiles:
assert q_key in image_stats
assert image_stats[q_key].shape == (1, 3, 1, 1)
# Test with episode data
episode_data = {
"action": np.random.normal(0, 1, (100, 7)),
"observation.state": np.random.normal(0, 1, (100, 10)),
}
features = {
"action": {"dtype": "float32", "shape": (7,)},
"observation.state": {"dtype": "float32", "shape": (10,)},
}
episode_stats = compute_episode_stats(episode_data, features)
# Check all fixed quantiles are present in episode stats
for key in ["action", "observation.state"]:
for q_key in expected_quantiles:
assert q_key in episode_stats[key]
assert episode_stats[key][q_key].shape == (features[key]["shape"][0],)


@@ -0,0 +1,212 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for quantile functionality in LeRobotDataset."""
import numpy as np
import pytest
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def mock_load_image_as_numpy(path, dtype, channel_first):
"""Mock image loading for consistent test results."""
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
@pytest.fixture
def simple_features():
"""Simple feature configuration for testing."""
return {
"action": {
"dtype": "float32",
"shape": (4,),
"names": ["arm_x", "arm_y", "arm_z", "gripper"],
},
"observation.state": {
"dtype": "float32",
"shape": (10,),
"names": [f"joint_{i}" for i in range(10)],
},
}
def test_create_dataset_with_fixed_quantiles(tmp_path, simple_features):
"""Test creating dataset with fixed quantiles."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_fixed_quantiles",
fps=30,
features=simple_features,
root=tmp_path / "create_fixed_quantiles",
)
# Dataset should be created successfully
assert dataset is not None
def test_save_episode_computes_all_quantiles(tmp_path, simple_features):
"""Test that all fixed quantiles are computed when saving an episode."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_save_episode",
fps=30,
features=simple_features,
root=tmp_path / "save_episode_quantiles",
)
# Add some frames
for _ in range(10):
dataset.add_frame(
{
"action": np.random.randn(4).astype(np.float32), # Correct shape for action
"observation.state": np.random.randn(10).astype(np.float32),
"task": "test_task",
}
)
dataset.save_episode()
# Check that all fixed quantiles were computed
stats = dataset.meta.stats
for key in ["action", "observation.state"]:
assert "q01" in stats[key]
assert "q10" in stats[key]
assert "q50" in stats[key]
assert "q90" in stats[key]
assert "q99" in stats[key]
def test_quantile_values_ordering(tmp_path, simple_features):
"""Test that quantile values are properly ordered."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_quantile_ordering",
fps=30,
features=simple_features,
root=tmp_path / "quantile_ordering",
)
# Add data with known distribution
np.random.seed(42)
for _ in range(100):
dataset.add_frame(
{
"action": np.random.randn(4).astype(np.float32), # Correct shape for action
"observation.state": np.random.randn(10).astype(np.float32),
"task": "test_task",
}
)
dataset.save_episode()
stats = dataset.meta.stats
# Verify quantile ordering
for key in ["action", "observation.state"]:
assert np.all(stats[key]["q01"] <= stats[key]["q10"])
assert np.all(stats[key]["q10"] <= stats[key]["q50"])
assert np.all(stats[key]["q50"] <= stats[key]["q90"])
assert np.all(stats[key]["q90"] <= stats[key]["q99"])
def test_save_episode_with_fixed_quantiles(tmp_path, simple_features):
"""Test saving episode always computes fixed quantiles."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_save_fixed",
fps=30,
features=simple_features,
root=tmp_path / "save_fixed_quantiles",
)
# Add frames to episode
np.random.seed(42)
for _ in range(50):
frame = {
"action": np.random.normal(0, 1, (4,)).astype(np.float32),
"observation.state": np.random.normal(0, 1, (10,)).astype(np.float32),
"task": "test_task",
}
dataset.add_frame(frame)
dataset.save_episode()
# Check that all fixed quantiles are included
stats = dataset.meta.stats
for key in ["action", "observation.state"]:
feature_stats = stats[key]
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(feature_stats.keys()) == expected_keys
def test_quantile_aggregation_across_episodes(tmp_path, simple_features):
"""Test quantile aggregation across multiple episodes."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_aggregation",
fps=30,
features=simple_features,
root=tmp_path / "quantile_aggregation",
)
# Add frames to episode
np.random.seed(42)
for _ in range(100):
frame = {
"action": np.random.normal(0, 1, (4,)).astype(np.float32),
"observation.state": np.random.normal(2, 1, (10,)).astype(np.float32),
"task": "test_task",
}
dataset.add_frame(frame)
dataset.save_episode()
# Check stats include all fixed quantiles
stats = dataset.meta.stats
for key in ["action", "observation.state"]:
feature_stats = stats[key]
expected_keys = {"min", "max", "mean", "std", "count", "q01", "q10", "q50", "q90", "q99"}
assert set(feature_stats.keys()) == expected_keys
assert feature_stats["q01"].shape == (simple_features[key]["shape"][0],)
assert feature_stats["q50"].shape == (simple_features[key]["shape"][0],)
assert feature_stats["q99"].shape == (simple_features[key]["shape"][0],)
assert np.all(feature_stats["q01"] <= feature_stats["q50"])
assert np.all(feature_stats["q50"] <= feature_stats["q99"])
def test_save_multiple_episodes_with_quantiles(tmp_path, simple_features):
"""Test quantile aggregation across multiple episodes."""
dataset = LeRobotDataset.create(
repo_id="test_dataset_multiple_episodes",
fps=30,
features=simple_features,
root=tmp_path / "multiple_episodes",
)
# Save multiple episodes
np.random.seed(42)
for episode_idx in range(3):
for _ in range(50):
frame = {
"action": np.random.normal(episode_idx * 2.0, 1, (4,)).astype(np.float32),
"observation.state": np.random.normal(-episode_idx * 1.5, 1, (10,)).astype(np.float32),
"task": f"task_{episode_idx}",
}
dataset.add_frame(frame)
dataset.save_episode()
# Verify final stats include properly aggregated quantiles
stats = dataset.meta.stats
for key in ["action", "observation.state"]:
feature_stats = stats[key]
assert "q01" in feature_stats and "q99" in feature_stats
assert feature_stats["count"][0] == 150 # 3 episodes * 50 frames
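The fixed-quantile invariants these tests assert (per-dimension shape, q01 ≤ q50 ≤ q99 ordering) can be reproduced in isolation. The following is a hedged sketch in plain NumPy, not the LeRobot stats implementation — `compute_fixed_quantiles` is a hypothetical helper introduced here for illustration:

```python
import numpy as np

# Sketch only: per-dimension fixed quantiles over the frame axis, mirroring the
# q01/q10/q50/q90/q99 keys the tests assert on (not the LeRobot implementation).
FIXED_QUANTILES = {"q01": 0.01, "q10": 0.10, "q50": 0.50, "q90": 0.90, "q99": 0.99}

def compute_fixed_quantiles(frames: np.ndarray) -> dict[str, np.ndarray]:
    """One quantile value per feature dimension, computed over all frames."""
    return {name: np.quantile(frames, q, axis=0) for name, q in FIXED_QUANTILES.items()}

rng = np.random.default_rng(42)
actions = rng.normal(size=(100, 4)).astype(np.float32)  # 100 frames, 4-dim action
stats = compute_fixed_quantiles(actions)

assert stats["q01"].shape == (4,)  # one value per action dimension
assert np.all(stats["q01"] <= stats["q50"]) and np.all(stats["q50"] <= stats["q99"])
```

The ordering check holds exactly because `np.quantile` is monotone in the quantile level for fixed data.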


@@ -0,0 +1,117 @@
#!/usr/bin/env python
"""Test script to verify PI0 policy integration with LeRobot, only meant to be run locally!"""
import os
import pytest
import torch
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0 import ( # noqa: E402
PI0Config,
PI0Policy,
make_pi0_pre_post_processors, # noqa: E402
)
from lerobot.utils.random_utils import set_seed # noqa: E402
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_policy_instantiation():
# Create config
set_seed(42)
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = PI0Policy(config)
preprocessor, postprocessor = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
# Test forward pass with dummy data
batch_size = 1
device = config.device
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
batch = preprocessor(batch)
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
try:
with torch.no_grad():
action = policy.select_action(batch)
action = postprocessor(action)
print(f"Action: {action}")
print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"Action prediction failed: {e}")
raise
@require_cuda
def test_config_creation():
"""Test policy config creation through factory."""
try:
config = make_policy_config(
policy_type="pi0",
max_action_dim=7,
max_state_dim=14,
)
print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}")
except Exception as e:
print(f"Config creation failed: {e}")
raise
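A note on the dummy dataset stats used above: with zero mean and unit std, mean/std normalization is the identity transform, which is why the random batch can flow through the preprocessor without distortion. A minimal sketch of that assumption:

```python
import torch

# Sketch: with mean 0 and std 1 (as in the dummy stats above), z-score
# normalization is the identity, so normalized inputs equal the raw inputs.
mean, std = torch.zeros(14), torch.ones(14)
x = torch.randn(3, 14)
normalized = (x - mean) / std

assert torch.allclose(normalized, x)
```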


@@ -0,0 +1,154 @@
#!/usr/bin/env python
"""Test script to verify PI0.5 (pi05) support in PI0 policy, only meant to be run locally!"""
import os
import pytest
import torch
from lerobot.utils.random_utils import set_seed
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi05 import ( # noqa: E402
PI05Config,
PI05Policy,
make_pi05_pre_post_processors, # noqa: E402
)
from tests.utils import require_cuda # noqa: E402
@require_cuda
def test_policy_instantiation():
# Create config
set_seed(42)
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
# Set up input_features and output_features in the config
from lerobot.configs.types import FeatureType, PolicyFeature
config.input_features = {
"observation.state": PolicyFeature(
type=FeatureType.STATE,
shape=(14,),
),
"observation.images.base_0_rgb": PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 224, 224),
),
}
config.output_features = {
"action": PolicyFeature(
type=FeatureType.ACTION,
shape=(7,),
),
}
assert config.tokenizer_max_length == 200, (
f"Expected tokenizer_max_length=200 for pi05, got {config.tokenizer_max_length}"
)
# Create dummy dataset stats
dataset_stats = {
"observation.state": {
"mean": torch.zeros(14),
"std": torch.ones(14),
"min": torch.zeros(14),
"max": torch.ones(14),
"q01": torch.zeros(14),
"q99": torch.ones(14),
},
"action": {
"mean": torch.zeros(7),
"std": torch.ones(7),
"min": torch.zeros(7),
"max": torch.ones(7),
"q01": torch.zeros(7),
"q99": torch.ones(7),
},
"observation.images.base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
}
# Instantiate policy
policy = PI05Policy(config)
# Test forward pass with dummy data
batch_size = 1
preprocessor, postprocessor = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
device = config.device
batch = {
"observation.state": torch.randn(batch_size, 14, dtype=torch.float32, device=device),
"action": torch.randn(batch_size, config.chunk_size, 7, dtype=torch.float32, device=device),
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
), # Use rand for [0,1] range
"task": ["Pick up the object"] * batch_size,
}
batch = preprocessor(batch)
try:
loss, loss_dict = policy.forward(batch)
print(f"Forward pass successful. Loss: {loss_dict['loss']:.4f}")
except Exception as e:
print(f"Forward pass failed: {e}")
raise
try:
with torch.no_grad():
action = policy.select_action(batch)
action = postprocessor(action)
print(f"Action: {action}")
print(f"Action prediction successful. Action shape: {action.shape}")
except Exception as e:
print(f"Action prediction failed: {e}")
raise
# Verify pi05 model components exist
# Check that time_mlp layers exist (for AdaRMS conditioning)
assert hasattr(policy.model, "time_mlp_in"), "Missing time_mlp_in layer for pi05"
assert hasattr(policy.model, "time_mlp_out"), "Missing time_mlp_out layer for pi05"
# Check that action_time_mlp layers don't exist (pi0 only)
assert not hasattr(policy.model, "action_time_mlp_in"), "action_time_mlp_in should not exist in pi05 mode"
assert not hasattr(policy.model, "action_time_mlp_out"), (
"action_time_mlp_out should not exist in pi05 mode"
)
# Check that state_proj doesn't exist in pi05 mode
assert not hasattr(policy.model, "state_proj"), "state_proj should not exist in pi05 mode"
# Check AdaRMS configuration in the underlying model
adarms_config = policy.model.paligemma_with_expert.paligemma.config.text_config.use_adarms
assert adarms_config is False, f"PaliGemma should not use AdaRMS, got {adarms_config}"
adarms_expert_config = policy.model.paligemma_with_expert.gemma_expert.config.use_adarms
assert adarms_expert_config is True, (
f"Action expert should use AdaRMS in pi05, got {adarms_expert_config}"
)
@require_cuda
def test_config_creation():
"""Test policy config creation through factory."""
try:
config = make_policy_config(
policy_type="pi05",
max_action_dim=7,
max_state_dim=14,
)
print("Config created successfully through factory")
print(f" Config type: {type(config).__name__}")
print(f" PaliGemma variant: {config.paligemma_variant}")
print(f" Action expert variant: {config.action_expert_variant}")
except Exception as e:
print(f"Config creation failed: {e}")
raise


@@ -0,0 +1,419 @@
"""Test script to verify PI0.5 (pi05) policy integration with LeRobot vs the original OpenPI implementation, only meant to be run locally!"""
import os
from copy import deepcopy
from typing import Any
import numpy as np
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
# TODO: add default image features to the config
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
"base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
"left_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
"right_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
},
}
class PI05BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
precision: str = "float32"
pi05: bool = True
dtype: str = "float32"
def instantiate_lerobot_pi05(
from_pretrained: bool = False,
) -> tuple[
PI05Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI05Policy.from_pretrained(pretrained_name_or_path="lerobot/pi05_base", strict=True)
else:
config = PI05Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI05Policy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi05_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
def instantiate_original_pi05(from_pretrained: bool = False, model_path: str | None = None):
config = PI05BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi05_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi05_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot: one entry per batch element
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
class PI05Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing with pi05 state tokenizer."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description (pi05 processor handles all text formatting)
tasks = batch.get("task", ["Pick up the object"] * batch_size)
if isinstance(tasks, str):
tasks = [tasks] * batch_size
elif len(tasks) == 1:
tasks = tasks * batch_size
# Use pi05 state and input tokenizer logic (same as Pi05PrepareStateTokenizerProcessorStep)
state = batch["observation.state"]
state = deepcopy(state)
# Prepare state (pad to max_state_dim)
from lerobot.policies.pi05.modeling_pi05 import pad_vector
state = pad_vector(state, DUMMY_STATE_DIM)
# State is assumed to already be normalized to [-1, 1] by the normalize_inputs step
# Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
state_np = state.cpu().numpy()
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
# Create pi05-formatted prompts that include state information
full_prompts = []
for i, task in enumerate(tasks):
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
state_str = " ".join(map(str, discretized_states[i]))
full_prompt = f"Task: {cleaned_text}, State: {state_str};\nAction: "
full_prompts.append(full_prompt)
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
full_prompts,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI05Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi05_original_vs_lerobot():
"""Test PI05 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi05, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi05(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi05(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi05.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
with torch.no_grad():
lerobot_actions_own = lerobot_pi05.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
assert torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)
assert torch.abs(lerobot_actions_own - openpi_actions).max().item() < 1e-4
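The 256-bin state discretization replicated in this test (following openpi's `PaligemmaTokenizer.tokenize()`) can be sanity-checked on its own. A small NumPy sketch, under the assumption that states are pre-normalized to [-1, 1]:

```python
import numpy as np

# Sketch: discretize normalized state values into 256 integer bins, using the
# same np.digitize formula as the pi05 prompt construction in the test above.
bins = np.linspace(-1, 1, 256 + 1)[:-1]  # 256 left bin edges over [-1, 1)
state = np.array([-1.0, -0.5, 0.0, 0.5, 0.999])
discretized = np.digitize(state, bins=bins) - 1

assert discretized.min() >= 0 and discretized.max() <= 255
assert np.all(np.diff(discretized) >= 0)  # monotone in the state value
```

Every in-range value lands in bin 0..255, so the resulting tokens can be spliced into the "Task: ..., State: ...;" prompt as plain integers.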


@@ -0,0 +1,410 @@
"""Test script to verify PI0 policy integration with LeRobot vs the original implementation, only meant to be run locally!"""
import os
from copy import deepcopy
from typing import Any
import pytest
import torch
# Skip if openpi or transformers is not available
pytest.importorskip("openpi")
pytest.importorskip("transformers")
# Skip this entire module in CI
pytestmark = pytest.mark.skipif(
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
reason="This test requires local OpenPI installation and is not meant for CI",
)
from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing # noqa: E402
# NOTE: Assumes PYTHONPATH is set to include OpenPI src as per instructions.
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors # noqa: E402
from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402
# TODO: add default image features to the config
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 48 # Default for PI0 (non-pi05)
DEVICE = "cpu" # Use CPU to avoid memory issues for testing
DUMMY_DATASET_STATS = {
"observation.state": {
"mean": torch.zeros(DUMMY_STATE_DIM),
"std": torch.ones(DUMMY_STATE_DIM),
"q01": torch.zeros(DUMMY_STATE_DIM),
"q99": torch.ones(DUMMY_STATE_DIM),
},
"action": {
"mean": torch.zeros(DUMMY_ACTION_DIM),
"std": torch.ones(DUMMY_ACTION_DIM),
"q01": torch.zeros(DUMMY_ACTION_DIM),
"q99": torch.ones(DUMMY_ACTION_DIM),
},
"images": {
"base_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
"left_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
"right_wrist_0_rgb": {
"mean": torch.zeros(3, 224, 224),
"std": torch.ones(3, 224, 224),
"q01": torch.zeros(3, 224, 224),
"q99": torch.ones(3, 224, 224),
},
},
}
class PI0BaseOriginalConfig:
action_dim: int = DUMMY_ACTION_DIM
action_horizon: int = DUMMY_ACTION_HORIZON
paligemma_variant: str = "gemma_2b"
action_expert_variant: str = "gemma_300m"
precision: str = "float32"
pi05: bool = False
dtype: str = "float32"
def instantiate_lerobot_pi0(
from_pretrained: bool = False,
) -> tuple[
PI0Policy,
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
if from_pretrained:
# Load the policy first
policy = PI0Policy.from_pretrained(pretrained_name_or_path="lerobot/pi0_base", strict=True)
else:
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0Policy(config)
policy.to(DEVICE)
policy.config.device = DEVICE
preprocessor, postprocessor = make_pi0_pre_post_processors(
config=policy.config, dataset_stats=DUMMY_DATASET_STATS
)
return (policy, preprocessor, postprocessor)
def instantiate_original_pi0(from_pretrained: bool = False, model_path: str | None = None):
config = PI0BaseOriginalConfig()
policy = PI0Pytorch(config)
if from_pretrained:
try:
print("Loading converted PyTorch weights from HuggingFace Hub (lerobot/pi0_base)...")
# Download the model from HuggingFace Hub
import safetensors.torch
from huggingface_hub import snapshot_download
# Download the entire repository
if model_path and os.path.exists(model_path):
cache_dir = model_path
print(f"Using cached model from: {cache_dir}")
else:
cache_dir = snapshot_download(repo_id="lerobot/pi0_base", repo_type="model")
print(f"Downloaded model to: {cache_dir}")
# Try to load safetensors format first
model_file = os.path.join(cache_dir, "model.safetensors")
if os.path.exists(model_file):
state_dict = safetensors.torch.load_file(model_file)
print(f"Loaded {len(state_dict)} parameters from safetensors")
else:
raise FileNotFoundError(f"No safetensors file found in {cache_dir}")
# Load the state dict into the model
missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
if missing_keys:
print(f"Missing keys: {len(missing_keys)}")
if len(missing_keys) <= 5:
for key in missing_keys:
print(f" - {key}")
else:
for key in missing_keys[:5]:
print(f" - {key}")
print(f" ... and {len(missing_keys) - 5} more")
if unexpected_keys:
print(f"Unexpected keys: {len(unexpected_keys)}")
if len(unexpected_keys) <= 5:
for key in unexpected_keys:
print(f" - {key}")
else:
for key in unexpected_keys[:5]:
print(f" - {key}")
print(f" ... and {len(unexpected_keys) - 5} more")
if not missing_keys and not unexpected_keys:
print("All pretrained weights loaded successfully!")
else:
print("Pretrained weights loaded with some missing/unexpected keys (this may be normal)")
except Exception as e:
print(f"Failed to load pretrained weights: {e}")
print(" Using randomly initialized weights...")
import traceback
traceback.print_exc()
policy.to(DEVICE)
return policy
def create_dummy_data():
batch_size = 2 # Reduce batch size for testing
device = DEVICE
# Use the exact same prompt for both implementations
prompt = "Pick up the red block and place it in the bin"
batch = {
"observation.state": torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device),
"action": torch.randn(
batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=device
),
# Create images in [0, 1] range as expected by LeRobot (will be converted to [-1, 1] internally)
"observation.images.base_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.left_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
"observation.images.right_wrist_0_rgb": torch.rand(
batch_size, 3, 224, 224, dtype=torch.float32, device=device
),
# Add the task prompt for LeRobot - one copy per batch element
"task": [prompt for _ in range(batch_size)],
}
return batch
def extract_lerobot_processed_inputs(lerobot_pi0, batch):
"""Extract the exact same processed inputs that LeRobot uses internally."""
# Get the tokenized language from LeRobot's internal method
lang_tokens, lang_masks = lerobot_pi0._tokenize_language(batch)
# Get the preprocessed images from LeRobot's internal method
images, img_masks = lerobot_pi0._preprocess_images(batch, train=False)
# Create dummy token_ar_mask and token_loss_mask for original implementation
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
return images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask
class PI0Observation:
"""Observation class that matches the original OpenPI format."""
def __init__(
self,
state,
images,
image_masks,
tokenized_prompt,
tokenized_prompt_mask,
token_ar_mask,
token_loss_mask,
):
self.state = state
self.images = images
self.image_masks = image_masks
self.tokenized_prompt = tokenized_prompt
self.tokenized_prompt_mask = tokenized_prompt_mask
self.token_ar_mask = token_ar_mask
self.token_loss_mask = token_loss_mask
def create_original_observation_with_openpi_preprocessing(batch):
"""Create observation object for OpenPI using OpenPI's own preprocessing."""
batch_size = batch["observation.state"].shape[0]
device = batch["observation.state"].device
# Create tokenizer for OpenPI (same as LeRobot uses)
tokenizer = AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")
# Get task description
if "task" in batch:
tasks = batch["task"]
if isinstance(tasks, str):
# Single string: add newline if not present, then convert to list
if not tasks.endswith("\n"):
tasks = f"{tasks}\n"
tasks = [tasks]
elif isinstance(tasks, list) and all(isinstance(t, str) for t in tasks):
# List of strings: add newline to each if not present
tasks = [t if t.endswith("\n") else f"{t}\n" for t in tasks]
if len(tasks) == 1:
# Expand to batch size
tasks = tasks * batch_size
if len(tasks) != batch_size:
raise ValueError(f"Expected batch size {batch_size}, got {len(tasks)}")
# If task is neither string nor list of strings, leave unchanged
else:
# Default task if not provided
tasks = ["Pick up the object\n"] * batch_size
# Tokenize with max_length padding to match OpenPI's expected format
tokenized = tokenizer(
tasks,
padding="max_length",
padding_side="right",
truncation=True,
max_length=DUMMY_MAX_TOKEN_LEN,
return_tensors="pt",
)
lang_tokens = tokenized["input_ids"].to(device)
lang_masks = tokenized["attention_mask"].to(device, dtype=torch.bool)
# Create dummy token_ar_mask and token_loss_mask for OpenPI
token_ar_mask = torch.zeros_like(lang_tokens, dtype=torch.int32)
token_loss_mask = torch.ones_like(lang_masks, dtype=torch.bool)
# Convert LeRobot images format to OpenPI format (convert [0,1] to [-1,1] range)
image_dict = {
"base_0_rgb": batch["observation.images.base_0_rgb"] * 2.0 - 1.0,
"left_wrist_0_rgb": batch["observation.images.left_wrist_0_rgb"] * 2.0 - 1.0,
"right_wrist_0_rgb": batch["observation.images.right_wrist_0_rgb"] * 2.0 - 1.0,
}
# Create image masks (all ones for real images)
image_masks_dict = {}
for key in image_dict:
image_masks_dict[key] = torch.ones(batch_size, dtype=torch.bool, device=device)
# Create raw observation object (before preprocessing)
raw_observation = PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
# Now use OpenPI's preprocessing
processed_obs = openpi_preprocessing.preprocess_observation_pytorch(raw_observation, train=False)
return processed_obs
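The image-range conversion above (LeRobot's [0, 1] pixels to OpenPI's [-1, 1] convention) is a plain affine remap; a minimal standalone sketch of the mapping and its inverse (helper names are illustrative, not part of either library):

```python
def to_pm1(x: float) -> float:
    """Map a pixel value from the [0, 1] range to the [-1, 1] range (x * 2 - 1)."""
    return x * 2.0 - 1.0


def from_pm1(x: float) -> float:
    """Inverse mapping: [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0


# Endpoints and midpoint map as expected.
print(to_pm1(0.0), to_pm1(0.5), to_pm1(1.0))  # -1.0 0.0 1.0
```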
def create_original_observation_from_lerobot(lerobot_pi0, batch):
"""Create observation object compatible with original OpenPI using the exact same inputs as LeRobot."""
_batch_size = batch["observation.state"].shape[0]
_device = batch["observation.state"].device
# Extract the exact same processed inputs that LeRobot uses
images, img_masks, lang_tokens, lang_masks, token_ar_mask, token_loss_mask = (
extract_lerobot_processed_inputs(lerobot_pi0, batch)
)
# Convert images list to dict with original OpenPI keys
image_dict = {
"base_0_rgb": images[0],
"left_wrist_0_rgb": images[1],
"right_wrist_0_rgb": images[2],
}
# Convert image masks list to dict with original OpenPI keys
image_masks_dict = {
"base_0_rgb": img_masks[0],
"left_wrist_0_rgb": img_masks[1],
"right_wrist_0_rgb": img_masks[2],
}
return PI0Observation(
state=batch["observation.state"],
images=image_dict,
image_masks=image_masks_dict,
tokenized_prompt=lang_tokens,
tokenized_prompt_mask=lang_masks,
token_ar_mask=token_ar_mask,
token_loss_mask=token_loss_mask,
)
def test_pi0_original_vs_lerobot():
"""Test PI0 original implementation vs LeRobot implementation."""
print("Initializing models...")
lerobot_pi0, lerobot_preprocessor, lerobot_postprocessor = instantiate_lerobot_pi0(
from_pretrained=True
) # Load pretrained LeRobot model
original_pi0 = instantiate_original_pi0(
from_pretrained=True
) # Load pretrained OpenPI model from HuggingFace Hub
print("Creating dummy data...")
batch = create_dummy_data()
batch_lerobot = deepcopy(batch)
# Test each model with its own preprocessing (more realistic end-to-end test)
print("\nTest each model with its own preprocessing")
print("Creating observation for OpenPI using OpenPI's own preprocessing...")
pi0_obs_openpi = create_original_observation_with_openpi_preprocessing(batch)
print(f"Task prompt: '{batch['task'][0]}'")
print(f"OpenPI tokenized prompt shape: {pi0_obs_openpi.tokenized_prompt.shape}")
print(f"OpenPI image shapes: {[img.shape for img in pi0_obs_openpi.images.values()]}")
print(f"OpenPI state shape: {pi0_obs_openpi.state.shape}")
print("Testing OpenPI with own preprocessing...")
original_pi0.eval()
torch.manual_seed(42) # Set seed for reproducibility
batch_size = batch["observation.state"].shape[0]
noise_shape = (batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM)
fixed_noise = torch.randn(noise_shape, dtype=torch.float32, device=DEVICE)
with torch.no_grad():
openpi_actions = original_pi0.sample_actions(
device=DEVICE, observation=pi0_obs_openpi, noise=fixed_noise, num_steps=10
)
openpi_actions_unit = openpi_actions[:, 0, :]
print(f"OpenPI (own preprocessing) Actions shape: {openpi_actions.shape}")
print(f"OpenPI (own preprocessing) Actions unit shape: {openpi_actions_unit.shape}")
print(f"OpenPI (own preprocessing) Actions mean: {openpi_actions.mean().item():.6f}")
print(f"OpenPI (own preprocessing) Actions std: {openpi_actions.std().item():.6f}")
print("Testing LeRobot with own preprocessing...")
lerobot_pi0.eval()
torch.manual_seed(42) # Set the same seed
batch_lerobot_processed = lerobot_preprocessor(batch_lerobot)
with torch.no_grad():
lerobot_actions_own = lerobot_pi0.predict_action_chunk(
batch_lerobot_processed
) # batch_size, n_action_steps, action_dim
lerobot_actions_unit = lerobot_actions_own[:, 0, :]
print(f"LeRobot (own preprocessing) Actions shape: {lerobot_actions_own.shape}")
print(f"LeRobot (own preprocessing) Actions unit shape: {lerobot_actions_unit.shape}")
print(f"LeRobot (own preprocessing) Actions mean: {lerobot_actions_own.mean().item():.6f}")
print(f"LeRobot (own preprocessing) Actions std: {lerobot_actions_own.std().item():.6f}")
print("\nComparing end-to-end implementations:")
print(f"Actions close (atol=1e-4): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-4)}")
print(f"Actions close (atol=1e-2): {torch.allclose(lerobot_actions_own, openpi_actions, atol=1e-2)}")
print(f"Max absolute difference: {torch.abs(lerobot_actions_own - openpi_actions).max().item():.6f}")
max_diff = torch.abs(lerobot_actions_own - openpi_actions).max().item()
assert max_diff < 1e-4, f"Max absolute difference too large: {max_diff:.6f}"


@@ -166,6 +166,226 @@ def test_min_max_normalization(observation_normalizer):
assert torch.allclose(normalized_obs[OBS_STATE], expected_state, atol=1e-6)
def test_quantile_normalization():
"""Test QUANTILES mode using 1st-99th percentiles."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {
"observation.state": {
"q01": np.array([0.1, -0.8]), # 1st percentile
"q99": np.array([0.9, 0.8]), # 99th percentile
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check quantile normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 2 * 0.4 / 0.8 - 1 = 0.0
# For state[1]: 2 * (0.0 - (-0.8)) / (0.8 - (-0.8)) - 1 = 2 * 0.8 / 1.6 - 1 = 0.0
expected_state = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_quantile10_normalization():
"""Test QUANTILE10 mode using 10th-90th percentiles."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILE10,
}
stats = {
"observation.state": {
"q10": np.array([0.2, -0.6]), # 10th percentile
"q90": np.array([0.8, 0.6]), # 90th percentile
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Check quantile normalization to [-1, 1]
# For state[0]: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 2 * 0.3 / 0.6 - 1 = 0.0
# For state[1]: 2 * (0.0 - (-0.6)) / (0.6 - (-0.6)) - 1 = 2 * 0.6 / 1.2 - 1 = 0.0
expected_state = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
def test_quantile_unnormalization():
"""Test that quantile normalization can be reversed properly."""
features = {
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.ACTION: NormalizationMode.QUANTILES,
}
stats = {
"action": {
"q01": np.array([0.1, -0.8]),
"q99": np.array([0.9, 0.8]),
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
# Test round-trip normalization
original_action = torch.tensor([0.5, 0.0])
transition = create_transition(action=original_action)
# Normalize then unnormalize
normalized = normalizer(transition)
unnormalized = unnormalizer(normalized)
# Should recover original values
recovered_action = unnormalized[TransitionKey.ACTION]
assert torch.allclose(recovered_action, original_action, atol=1e-6)
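The round-trip property this test checks holds because unnormalization is the exact algebraic inverse of the quantile mapping; a hedged sketch with illustrative names (not the library's API):

```python
def qnorm(x: float, q_low: float, q_high: float) -> float:
    """Forward quantile mapping to [-1, 1]."""
    return 2.0 * (x - q_low) / (q_high - q_low) - 1.0


def qunnorm(y: float, q_low: float, q_high: float) -> float:
    """Inverse mapping back to the original range."""
    return (y + 1.0) / 2.0 * (q_high - q_low) + q_low


# Round trip recovers the original value up to floating-point error.
x = 0.5
print(qunnorm(qnorm(x, 0.1, 0.9), 0.1, 0.9))
```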
def test_quantile_division_by_zero():
"""Test quantile normalization handles edge case where q01 == q99."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (1,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {
"observation.state": {
"q01": np.array([0.5]), # Same value
"q99": np.array([0.5]), # Same value -> division by zero case
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5]),
}
transition = create_transition(observation=observation)
# Should not crash and should handle gracefully
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# When quantiles are identical, should normalize to 0 (due to epsilon handling)
assert torch.isfinite(normalized_obs["observation.state"]).all()
def test_quantile_partial_stats():
"""Test that quantile normalization handles missing quantile stats by raising."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
# Missing q99 - should raise a ValueError
stats_partial = {
"observation.state": {
"q01": np.array([0.1, -0.8]), # Only q01, missing q99
},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats_partial)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="QUANTILES normalization mode requires q01 and q99 stats"):
_ = normalizer(transition)
def test_quantile_mixed_with_other_modes():
"""Test quantile normalization mixed with other normalization modes."""
features = {
"observation.image": PolicyFeature(FeatureType.VISUAL, (3,)),
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
"action": PolicyFeature(FeatureType.ACTION, (2,)),
}
norm_map = {
FeatureType.VISUAL: NormalizationMode.MEAN_STD, # Standard normalization
FeatureType.STATE: NormalizationMode.QUANTILES, # Quantile normalization
FeatureType.ACTION: NormalizationMode.QUANTILE10, # Different quantile mode
}
stats = {
"observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
"observation.state": {"q01": [0.1, -0.8], "q99": [0.9, 0.8]},
"action": {"q10": [0.2, -0.6], "q90": [0.8, 0.6]},
}
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
"observation.state": torch.tensor([0.5, 0.0]), # Should use QUANTILES
}
action = torch.tensor([0.5, 0.0]) # Should use QUANTILE10
transition = create_transition(observation=observation, action=action)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
normalized_action = normalized_transition[TransitionKey.ACTION]
# Image should be mean/std normalized: (0.7 - 0.5) / 0.2 = 1.0, etc.
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
assert torch.allclose(normalized_obs["observation.image"], expected_image)
# State should be quantile normalized: 2 * (0.5 - 0.1) / (0.9 - 0.1) - 1 = 0.0, etc.
expected_state = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
# Action should be quantile10 normalized: 2 * (0.5 - 0.2) / (0.8 - 0.2) - 1 = 0.0, etc.
expected_action = torch.tensor([0.0, 0.0])
assert torch.allclose(normalized_action, expected_action, atol=1e-6)
def test_quantile_with_missing_stats():
"""Test that quantile normalization handles completely missing stats gracefully."""
features = {
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
}
norm_map = {
FeatureType.STATE: NormalizationMode.QUANTILES,
}
stats = {} # No stats provided
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
observation = {
"observation.state": torch.tensor([0.5, 0.0]),
}
transition = create_transition(observation=observation)
normalized_transition = normalizer(transition)
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
# Should pass through unchanged when no stats available
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
def test_selective_normalization(observation_stats):
features = _create_observation_features()
norm_map = _create_observation_norm_map()
@@ -547,7 +767,7 @@ def test_empty_stats():
def test_partial_stats():
"""If statistics are incomplete, the value should pass through unchanged."""
"""If statistics are incomplete, we should raise."""
stats = {OBS_IMAGE: {"mean": [0.5]}} # Missing std / (min,max)
features = {OBS_IMAGE: PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
@@ -555,8 +775,8 @@ def test_partial_stats():
observation = {OBS_IMAGE: torch.tensor([0.7])}
transition = create_transition(observation=observation)
with pytest.raises(ValueError, match="MEAN_STD normalization mode requires mean and std stats"):
_ = normalizer(transition)[TransitionKey.OBSERVATION]
def test_missing_action_stats_no_error():


@@ -1,424 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI0 policy processor."""
from unittest.mock import patch
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_default_config():
"""Create a default PI0 configuration for testing."""
config = PI0Config()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
}
config.normalization_mapping = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
config.device = "cpu"
config.tokenizer_max_length = 128
return config
def create_default_stats():
"""Create default dataset statistics for testing."""
return {
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
OBS_IMAGE: {}, # No normalization for images
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
}
def test_make_pi0_processor_basic():
"""Test basic creation of PI0 processor."""
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Check processor names
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
# Step 3 would be TokenizerProcessorStep but it's mocked
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_pi0_newline_processor_single_task():
"""Test Pi0NewLineProcessor with single task string."""
processor = Pi0NewLineProcessor()
# Test with task that doesn't have newline
transition = create_transition(complementary_data={"task": "test task"})
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
# Test with task that already has newline
transition = create_transition(complementary_data={"task": "test task\n"})
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
def test_pi0_newline_processor_list_of_tasks():
"""Test Pi0NewLineProcessor with list of task strings."""
processor = Pi0NewLineProcessor()
# Test with list of tasks
tasks = ["task1", "task2\n", "task3"]
transition = create_transition(complementary_data={"task": tasks})
result = processor(transition)
expected = ["task1\n", "task2\n", "task3\n"]
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
def test_pi0_newline_processor_empty_transition():
"""Test Pi0NewLineProcessor with empty transition."""
processor = Pi0NewLineProcessor()
# Test with no complementary_data
transition = create_transition()
result = processor(transition)
assert result == transition
# Test with complementary_data but no task
transition = create_transition(complementary_data={"other": "data"})
result = processor(transition)
assert result == transition
# Test with None task
transition = create_transition(complementary_data={"task": None})
result = processor(transition)
assert result == transition
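The behavior these tests exercise can be summarized by a tiny standalone helper (a sketch of the idea, not the actual Pi0NewLineProcessor implementation): PaliGemma-style prompts are expected to end with a newline, so one is appended only when missing, and non-string inputs pass through untouched:

```python
def ensure_trailing_newline(task):
    """Append '\n' to a prompt (or each prompt in a list) only if it is missing."""
    if isinstance(task, str):
        return task if task.endswith("\n") else task + "\n"
    if isinstance(task, list) and all(isinstance(t, str) for t in task):
        return [t if t.endswith("\n") else t + "\n" for t in task]
    return task  # None or unsupported types are left unchanged


print(ensure_trailing_newline(["task1", "task2\n"]))  # ['task1\n', 'task2\n']
```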
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_cuda():
"""Test PI0 processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Create CPU data
observation = {
OBS_STATE: torch.randn(10),
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation, action, complementary_data={"task": "test task"})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data is on CUDA
assert processed[OBS_STATE].device.type == "cuda"
assert processed[OBS_IMAGE].device.type == "cuda"
assert processed[TransitionKey.ACTION.value].device.type == "cuda"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_accelerate_scenario():
"""Test PI0 processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
observation = {
OBS_STATE: torch.randn(1, 10).to(device),
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data stays on same GPU
assert processed[OBS_STATE].device == device
assert processed[OBS_IMAGE].device == device
assert processed[TransitionKey.ACTION.value].device == device
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_pi0_processor_multi_gpu():
"""Test PI0 processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Simulate data on different GPU
device = torch.device("cuda:1")
observation = {
OBS_STATE: torch.randn(1, 10).to(device),
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data stays on cuda:1
assert processed[OBS_STATE].device == device
assert processed[OBS_IMAGE].device == device
assert processed[TransitionKey.ACTION.value].device == device
def test_pi0_processor_without_stats():
"""Test PI0 processor creation without dataset statistics."""
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
)
# Should still create processors
assert preprocessor is not None
assert postprocessor is not None
def test_pi0_newline_processor_state_dict():
"""Test Pi0NewLineProcessor state dict methods."""
processor = Pi0NewLineProcessor()
# Test state_dict (should be empty)
state = processor.state_dict()
assert state == {}
# Test load_state_dict (should do nothing)
processor.load_state_dict({})
# Test reset (should do nothing)
processor.reset()
# Test get_config
config = processor.get_config()
assert config == {}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
stats = create_default_stats()
config.device = "cuda"
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, _ = make_pi0_pre_post_processors(
config,
stats,
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
norm_step = step # Now type checker knows this is NormalizerProcessorStep
modified_steps.append(
NormalizerProcessorStep(
features=norm_step.features,
norm_map=norm_step.norm_map,
stats=norm_step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
batch = transition_to_batch(transition)
# Process through full pipeline
processed = preprocessor(batch)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[OBS_STATE].dtype == torch.bfloat16
assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check