Handle different transformers versions in embed_image (pi0 and pi0fast).

In both policies, embed_image now calls get_image_features directly on the
PaliGemma object when that attribute exists, and otherwise falls back to the
same method on the object's inner `.model`.

diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py
index 76e2ce600..49c844c7b 100644
--- a/lerobot/common/policies/pi0/paligemma_with_expert.py
+++ b/lerobot/common/policies/pi0/paligemma_with_expert.py
@@ -216,7 +216,11 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
                 param.data = param.data.to(dtype=torch.bfloat16)
 
     def embed_image(self, image: torch.Tensor):
-        return self.paligemma.get_image_features(image)
+        # Handle different transformers versions
+        if hasattr(self.paligemma, "get_image_features"):
+            return self.paligemma.get_image_features(image)
+        else:
+            return self.paligemma.model.get_image_features(image)
 
     def embed_language_tokens(self, tokens: torch.Tensor):
         return self.paligemma.language_model.model.embed_tokens(tokens)
diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
index 4996b1a08..a2df40f26 100644
--- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py
+++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py
@@ -878,7 +878,11 @@ class PI0FAST(nn.Module):
         return actions
 
     def embed_image(self, image: torch.Tensor):
-        return self.pi0_paligemma.get_image_features(image)
+        # Handle different transformers versions
+        if hasattr(self.pi0_paligemma, "get_image_features"):
+            return self.pi0_paligemma.get_image_features(image)
+        else:
+            return self.pi0_paligemma.model.get_image_features(image)
 
     def embed_inputs(
         self,