fix: update pi0 dependency version constraint (#1247)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Yushun Xiang
2025-06-11 00:46:41 +08:00
committed by GitHub
parent 37748c83ca
commit 459c95197b
2 changed files with 10 additions and 2 deletions

View File

@@ -216,7 +216,11 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
return self.paligemma.get_image_features(image)
# Handle different transformers versions
if hasattr(self.paligemma, "get_image_features"):
return self.paligemma.get_image_features(image)
else:
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.model.embed_tokens(tokens)

View File

@@ -878,7 +878,11 @@ class PI0FAST(nn.Module):
return actions
def embed_image(self, image: torch.Tensor):
return self.pi0_paligemma.get_image_features(image)
# Handle different transformers versions
if hasattr(self.pi0_paligemma, "get_image_features"):
return self.pi0_paligemma.get_image_features(image)
else:
return self.pi0_paligemma.model.get_image_features(image)
def embed_inputs(
self,