Add disclaimer to pi0 from_pretrained (#1550)

This commit is contained in:
Michel Aractingi
2025-07-21 10:57:35 +02:00
committed by GitHub
parent 26cb4614c9
commit 17efa2ff8e
2 changed files with 22 additions and 0 deletions

View File

@@ -21,6 +21,7 @@
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.
Install pi0 extra dependencies:
```bash
@@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy):
def get_optim_params(self) -> dict:
return self.parameters()
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""

View File

@@ -21,6 +21,7 @@
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
```bash
@@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
def get_optim_params(self) -> dict:
return self.parameters()