From 17efa2ff8e71d3e9097ac32cb183e46919dea015 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 21 Jul 2025 10:57:35 +0200 Subject: [PATCH] Add disclaimer to pi0 from_pretrained (#1550) --- src/lerobot/policies/pi0/modeling_pi0.py | 11 +++++++++++ src/lerobot/policies/pi0fast/modeling_pi0fast.py | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index badfb4b8..11feca96 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Install pi0 extra dependencies: ```bash @@ -260,6 +261,16 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: """Predict a chunk of actions given environment observations.""" diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index 0e53bd34..d3903066 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -21,6 +21,7 @@ [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. +Disclaimer: It is not expected to perform as well as the original implementation. Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): ```bash @@ -162,6 +163,16 @@ class PI0FASTPolicy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def from_pretrained(cls, *args, **kwargs): + """Override the from_pretrained method to display important disclaimer.""" + print( + "⚠️ DISCLAIMER: The PI0FAST model is ported from JAX by the Hugging Face team. \n" + " It is not expected to perform as well as the original implementation. \n" + " Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + return super().from_pretrained(*args, **kwargs) + def get_optim_params(self) -> dict: return self.parameters()