Add Gemini Pro 1.5 Support
@@ -1,12 +1,15 @@
 import base64
+import hashlib
 import json
 import logging
 import os
 import re
+import tempfile
 import time
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
+from pathlib import Path
 from typing import Dict, List

 import backoff
@@ -32,6 +35,25 @@ def encode_image(image_content):
     return base64.b64encode(image_content).decode('utf-8')

+
+def encoded_img_to_pil_img(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    return image
+
+
+def save_to_tmp_img_file(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
+    image.save(tmp_img_path)
+
+    return tmp_img_path
+

 def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
     filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
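These two helpers undo the data-URL encoding the agent uses for screenshots; the second also writes the decoded image into a fresh temp directory. A minimal round-trip sketch of their usage (the tiny in-memory PNG is illustrative only, and Image is assumed importable as in the rest of the file):

import base64
from io import BytesIO

from PIL import Image

# Build a small PNG and wrap it the way the agent encodes screenshots.
img = Image.new("RGB", (8, 8), "red")
buf = BytesIO()
img.save(buf, format="PNG")
data_str = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

pil_img = encoded_img_to_pil_img(data_str)  # back to a PIL.Image
tmp_path = save_to_tmp_img_file(data_str)   # e.g. /tmp/<random>/tmp_img.png
print(pil_img.size, tmp_path)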
@@ -695,14 +717,7 @@ class PromptAgent:
                 print("Failed to call LLM: ", response.status_code)
                 return ""

-        elif self.model.startswith("gemini"):
-            def encoded_img_to_pil_img(data_str):
-                base64_str = data_str.replace("data:image/png;base64,", "")
-                image_data = base64.b64decode(base64_str)
-                image = Image.open(BytesIO(image_data))
-
-                return image
-
+        elif self.model in ["gemini-pro", "gemini-pro-vision"]:
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
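The branch guard is narrowed from startswith("gemini") to an explicit list so that gemini-1.5-pro-latest, which would also match the prefix test, can take its own path below; the nested encoded_img_to_pil_img is dropped in favor of the module-level helper added above. A tiny standalone illustration of why the narrowing matters:

# "gemini-1.5-pro-latest" also startswith("gemini"); the explicit list
# keeps it out of the legacy branch and routes it to the new one.
for model in ["gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-latest"]:
    if model in ["gemini-pro", "gemini-pro-vision"]:
        print(model, "-> legacy Gemini path")
    elif model == "gemini-1.5-pro-latest":
        print(model, "-> Gemini 1.5 path")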
@@ -732,7 +747,7 @@ class PromptAgent:

                 gemini_messages.append(gemini_message)

-            # the mistral not support system message in our endpoint, so we concatenate it at the first user message
+            # gemini does not support system messages in our endpoint, so we concatenate them into the first user message
             if gemini_messages[0]['role'] == "system":
                 gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
                 gemini_messages.pop(0)
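Concretely, the fold-in above rewrites a leading system message into a prefix of the first user turn. A small self-contained example of the same transformation (the message text is made up for illustration):

gemini_messages = [
    {"role": "system", "parts": ["You are an agent operating a desktop."]},
    {"role": "user", "parts": ["Open the browser."]},
]

# Same logic as the diff: prepend the system text, then drop the message.
if gemini_messages[0]['role'] == "system":
    gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
    gemini_messages.pop(0)

print(gemini_messages)
# [{'role': 'user', 'parts': ['You are an agent operating a desktop.\nOpen the browser.']}]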
@@ -775,6 +790,93 @@ class PromptAgent:
                 logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
                 logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
                 return ""

+        elif self.model == "gemini-1.5-pro-latest":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            uploaded_files = []
+
+            # def upload_if_needed(pathname: str) -> list[str]:
+            #     path = Path(pathname)
+            #     hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
+            #     try:
+            #         existing_file = genai.get_file(name=hash_id)
+            #         return [existing_file.uri]
+            #     except:
+            #         pass
+            #     uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
+            #     return [uploaded_files[-1].uri]
+
+            gemini_messages = []
+            for i, message in enumerate(messages):
+                role_mapping = {
+                    "assistant": "model",
+                    "user": "user",
+                    "system": "system"
+                }
+                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
+
+                # Gemini only supports the last image as a single image input
+                for part in message["content"]:
+                    gemini_message = {
+                        "role": role_mapping[message["role"]],
+                        "parts": []
+                    }
+                    if part['type'] == "image_url":
+                        gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
+                    elif part['type'] == "text":
+                        gemini_message['parts'].append(part['text'])
+                    else:
+                        raise ValueError("Invalid content type: " + part['type'])
+
+                    gemini_messages.append(gemini_message)
+
+            # the system message of gemini-1.5-pro-latest needs to be passed via the model initialization parameter
+            system_instruction = None
+            if gemini_messages[0]['role'] == "system":
+                system_instruction = gemini_messages[0]['parts'][0]
+                gemini_messages.pop(0)
+
+            api_key = os.environ.get("GENAI_API_KEY")
+            assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
+            genai.configure(api_key=api_key)
+            logger.info("Generating content with Gemini model: %s", self.model)
+            request_options = {"timeout": 120}
+            gemini_model = genai.GenerativeModel(
+                self.model,
+                system_instruction=system_instruction
+            )
+            try:
+                response = gemini_model.generate_content(
+                    gemini_messages,
+                    generation_config={
+                        "candidate_count": 1,
+                        "max_output_tokens": max_tokens,
+                        "top_p": top_p,
+                        "temperature": temperature
+                    },
+                    safety_settings={
+                        "harassment": "block_none",
+                        "hate": "block_none",
+                        "sex": "block_none",
+                        "danger": "block_none"
+                    },
+                    request_options=request_options
+                )
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return response.text
+            except Exception as e:
+                logger.error("Exception when calling Gemini API: " + str(e.__class__.__name__) + str(e))
+                logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
+                logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
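Taken together, the new branch converts the agent's OpenAI-style payload into Gemini content, moves any system message into the model constructor, and makes a single generate_content call. A minimal standalone sketch of the same flow, using only calls that appear in the diff (the prompt text and sampling values are placeholders, and GENAI_API_KEY must be set):

import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GENAI_API_KEY"])

# System text goes through the constructor rather than the message list,
# mirroring the system_instruction handling in the diff.
gemini_model = genai.GenerativeModel(
    "gemini-1.5-pro-latest",
    system_instruction="You are an agent operating a desktop.",
)

response = gemini_model.generate_content(
    [{"role": "user", "parts": ["Describe what you would do next."]}],
    generation_config={
        "candidate_count": 1,
        "max_output_tokens": 1000,
        "top_p": 0.9,
        "temperature": 0.5,
    },
    request_options={"timeout": 120},
)
print(response.text)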