Fixing PI0 Policy (#1297)

This commit is contained in:
Francesco Capuano
2025-06-14 19:25:50 +02:00
committed by GitHub
parent 697c76f75e
commit ce6a26deeb

View File

@@ -223,7 +223,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
- return self.paligemma.language_model.model.embed_tokens(tokens)
+ return self.paligemma.language_model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
@@ -235,7 +235,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
use_cache: Optional[bool] = None,
fill_kv_cache: Optional[bool] = None,
):
- models = [self.paligemma.language_model.model, self.gemma_expert.model]
+ models = [self.paligemma.language_model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient