Add disclaimer to pi0 from_pretrained (#1550)
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
Disclaimer: It is not expected to perform as well as the original implementation.
|
||||
|
||||
Install pi0 extra dependencies:
|
||||
```bash
|
||||
@@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy):
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""Override the from_pretrained method to display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
|
||||
" It is not expected to perform as well as the original implementation. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||
|
||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||
Disclaimer: It is not expected to perform as well as the original implementation.
|
||||
|
||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||
```bash
|
||||
@@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
"""Override the from_pretrained method to display important disclaimer."""
|
||||
print(
|
||||
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
|
||||
" It is not expected to perform as well as the original implementation. \n"
|
||||
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||
)
|
||||
return super().from_pretrained(*args, **kwargs)
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user