diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index badfb4b8..11feca96 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Install pi0 extra dependencies: ```bash @@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 0e53bd34..d3903066 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): ```bash @@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + def get_optim_params(self) -> dict: return self.parameters()