Add disclaimer to pi0 from_pretrained (#1550)
This commit is contained in:
@@ -21,6 +21,7 @@
|
|||||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||||
|
|
||||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||||
|
Disclaimer: It is not expected to perform as well as the original implementation.
|
||||||
|
|
||||||
Install pi0 extra dependencies:
|
Install pi0 extra dependencies:
|
||||||
```bash
|
```bash
|
||||||
@@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
"""Override the from_pretrained method to display important disclaimer."""
|
||||||
|
print(
|
||||||
|
"⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
|
||||||
|
" It is not expected to perform as well as the original implementation. \n"
|
||||||
|
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||||
|
)
|
||||||
|
return super().from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
|
|||||||
@@ -21,6 +21,7 @@
|
|||||||
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
[Jax code](https://github.com/Physical-Intelligence/openpi)
|
||||||
|
|
||||||
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
|
||||||
|
Disclaimer: It is not expected to perform as well as the original implementation.
|
||||||
|
|
||||||
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`):
|
||||||
```bash
|
```bash
|
||||||
@@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy):
|
|||||||
"""This should be called whenever the environment is reset."""
|
"""This should be called whenever the environment is reset."""
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
"""Override the from_pretrained method to display important disclaimer."""
|
||||||
|
print(
|
||||||
|
"⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n"
|
||||||
|
" It is not expected to perform as well as the original implementation. \n"
|
||||||
|
" Original implementation: https://github.com/Physical-Intelligence/openpi"
|
||||||
|
)
|
||||||
|
return super().from_pretrained(*args, **kwargs)
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user