Fixing PI0 Policy (#1297)
This commit is contained in:
committed by
GitHub
parent
697c76f75e
commit
ce6a26deeb
@@ -223,7 +223,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
|||||||
return self.paligemma.model.get_image_features(image)
|
return self.paligemma.model.get_image_features(image)
|
||||||
|
|
||||||
def embed_language_tokens(self, tokens: torch.Tensor):
|
def embed_language_tokens(self, tokens: torch.Tensor):
|
||||||
return self.paligemma.language_model.model.embed_tokens(tokens)
|
return self.paligemma.language_model.embed_tokens(tokens)
|
||||||
|
|
||||||
# TODO: break down this huge forward into modules or functions
|
# TODO: break down this huge forward into modules or functions
|
||||||
def forward(
|
def forward(
|
||||||
@@ -235,7 +235,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
|||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
fill_kv_cache: Optional[bool] = None,
|
fill_kv_cache: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
models = [self.paligemma.language_model.model, self.gemma_expert.model]
|
models = [self.paligemma.language_model, self.gemma_expert.model]
|
||||||
|
|
||||||
for hidden_states in inputs_embeds:
|
for hidden_states in inputs_embeds:
|
||||||
# TODO this is very inefficient
|
# TODO this is very inefficient
|
||||||
|
|||||||
Reference in New Issue
Block a user