Fixing PI0 Policy (#1297)
This commit is contained in:
committed by
GitHub
parent
697c76f75e
commit
ce6a26deeb
@@ -223,7 +223,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
return self.paligemma.model.get_image_features(image)
|
||||
|
||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
||||
return self.paligemma.language_model.embed_tokens(tokens)
|
||||
|
||||
# TODO: break down this huge forward into modules or functions
|
||||
def forward(
|
||||
@@ -235,7 +235,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
use_cache: Optional[bool] = None,
|
||||
fill_kv_cache: Optional[bool] = None,
|
||||
):
|
||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
||||
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||
|
||||
for hidden_states in inputs_embeds:
|
||||
# TODO this is very inefficient
|
||||
|
||||
Reference in New Issue
Block a user