From 459c95197ba1114ac8c3f538786ede005598f5e9 Mon Sep 17 00:00:00 2001 From: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com> Date: Wed, 11 Jun 2025 00:46:41 +0800 Subject: [PATCH] fix: update pi0 dependency version constraint (#1247) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- lerobot/common/policies/pi0/paligemma_with_expert.py | 6 +++++- lerobot/common/policies/pi0fast/modeling_pi0fast.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py index 76e2ce600..49c844c7b 100644 --- a/lerobot/common/policies/pi0/paligemma_with_expert.py +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -216,7 +216,11 @@ class PaliGemmaWithExpertModel(PreTrainedModel): param.data = param.data.to(dtype=torch.bfloat16) def embed_image(self, image: torch.Tensor): - return self.paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.paligemma, "get_image_features"): + return self.paligemma.get_image_features(image) + else: + return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): return self.paligemma.language_model.model.embed_tokens(tokens) diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 4996b1a08..a2df40f26 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -878,7 +878,11 @@ class PI0FAST(nn.Module): return actions def embed_image(self, image: torch.Tensor): - return self.pi0_paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.pi0_paligemma, "get_image_features"): + return self.pi0_paligemma.get_image_features(image) + else: + return self.pi0_paligemma.model.get_image_features(image) def embed_inputs( self,