Add Gemini 1.5 Pro Support

Timothyxxx
2024-04-24 18:19:25 +08:00
parent b3acf21333
commit eaceddf917


@@ -1,12 +1,15 @@
 import base64
+import hashlib
 import json
 import logging
 import os
 import re
+import tempfile
 import time
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
+from pathlib import Path
 from typing import Dict, List

 import backoff
@@ -32,6 +35,25 @@ def encode_image(image_content):
     return base64.b64encode(image_content).decode('utf-8')


+def encoded_img_to_pil_img(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    return image
+
+
+def save_to_tmp_img_file(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
+    image.save(tmp_img_path)
+
+    return tmp_img_path
+
+
 def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
     filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
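
For reference, a minimal round-trip through the two helpers promoted to module level above. This is a sketch, assuming Pillow is installed; the 1x1 PNG is fabricated test data, not anything from the repository:

import base64
import os
import tempfile
from io import BytesIO

from PIL import Image

def encoded_img_to_pil_img(data_str):
    # Strip the data-URL prefix and decode the PNG bytes into a PIL image.
    base64_str = data_str.replace("data:image/png;base64,", "")
    return Image.open(BytesIO(base64.b64decode(base64_str)))

def save_to_tmp_img_file(data_str):
    # Decode the same way, then persist into a fresh temporary directory.
    image = encoded_img_to_pil_img(data_str)
    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
    image.save(tmp_img_path)
    return tmp_img_path

# Build a data URL of the same shape the agent receives in its payload.
buf = BytesIO()
Image.new("RGB", (1, 1), "red").save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

print(encoded_img_to_pil_img(data_url).size)  # (1, 1)
print(save_to_tmp_img_file(data_url))         # .../tmp_img.png
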
@@ -695,14 +717,7 @@ class PromptAgent:
                 print("Failed to call LLM: ", response.status_code)
                 return ""

-        elif self.model.startswith("gemini"):
-            def encoded_img_to_pil_img(data_str):
-                base64_str = data_str.replace("data:image/png;base64,", "")
-                image_data = base64.b64decode(base64_str)
-                image = Image.open(BytesIO(image_data))
-
-                return image
-
+        elif self.model in ["gemini-pro", "gemini-pro-vision"]:
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
             top_p = payload["top_p"]
@@ -732,7 +747,7 @@ class PromptAgent:
                 gemini_messages.append(gemini_message)

-            # the mistral not support system message in our endpoint, so we concatenate it at the first user message
+            # our gemini endpoint does not support a system message, so we concatenate it into the first user message
            if gemini_messages[0]['role'] == "system":
                 gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
                 gemini_messages.pop(0)
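
The folding in this hunk is easiest to see on toy data. A minimal sketch, assuming the two-message shape the agent builds; the message contents below are made up:

# The gemini-pro endpoint rejects a "system" role, so the system text is
# merged into the first user turn before sending.
gemini_messages = [
    {"role": "system", "parts": ["You are a GUI agent."]},
    {"role": "user", "parts": ["Click the OK button."]},
]
if gemini_messages[0]['role'] == "system":
    gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
    gemini_messages.pop(0)

print(gemini_messages)
# [{'role': 'user', 'parts': ['You are a GUI agent.\nClick the OK button.']}]
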
@@ -775,6 +790,93 @@ class PromptAgent:
                 logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
                 logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
                 return ""
+
+        elif self.model == "gemini-1.5-pro-latest":
+            messages = payload["messages"]
+            max_tokens = payload["max_tokens"]
+            top_p = payload["top_p"]
+            temperature = payload["temperature"]
+
+            uploaded_files = []
+            # def upload_if_needed(pathname: str) -> list[str]:
+            #     path = Path(pathname)
+            #     hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
+            #     try:
+            #         existing_file = genai.get_file(name=hash_id)
+            #         return [existing_file.uri]
+            #     except:
+            #         pass
+            #     uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
+            #     return [uploaded_files[-1].uri]
+
+            gemini_messages = []
+            for i, message in enumerate(messages):
+                role_mapping = {
+                    "assistant": "model",
+                    "user": "user",
+                    "system": "system"
+                }
+                assert len(message["content"]) in [1, 2], "One text, or one text with one image"
+
+                # gemini accepts only a single image, so only the last image input is used
+                for part in message["content"]:
+                    gemini_message = {
+                        "role": role_mapping[message["role"]],
+                        "parts": []
+                    }
+
+                    if part['type'] == "image_url":
+                        gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
+                    elif part['type'] == "text":
+                        gemini_message['parts'].append(part['text'])
+                    else:
+                        raise ValueError("Invalid content type: " + part['type'])
+
+                    gemini_messages.append(gemini_message)
+
+            # the system message for gemini-1.5-pro-latest must be passed as a model initialization parameter
+            system_instruction = None
+            if gemini_messages[0]['role'] == "system":
+                system_instruction = gemini_messages[0]['parts'][0]
+                gemini_messages.pop(0)
+
+            api_key = os.environ.get("GENAI_API_KEY")
+            assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
+            genai.configure(api_key=api_key)
+
+            logger.info("Generating content with Gemini model: %s", self.model)
+            request_options = {"timeout": 120}
+            gemini_model = genai.GenerativeModel(
+                self.model,
+                system_instruction=system_instruction
+            )
+
+            try:
+                response = gemini_model.generate_content(
+                    gemini_messages,
+                    generation_config={
+                        "candidate_count": 1,
+                        "max_output_tokens": max_tokens,
+                        "top_p": top_p,
+                        "temperature": temperature
+                    },
+                    safety_settings={
+                        "harassment": "block_none",
+                        "hate": "block_none",
+                        "sex": "block_none",
+                        "danger": "block_none"
+                    },
+                    request_options=request_options
+                )
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return response.text
+            except Exception as e:
+                logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
+                logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
+                logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
+                for uploaded_file in uploaded_files:
+                    genai.delete_file(name=uploaded_file.name)
+                return ""
+
         elif self.model.startswith("qwen"):
             messages = payload["messages"]
             max_tokens = payload["max_tokens"]
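
Condensed from the new branch above, a standalone sketch of the gemini-1.5-pro-latest call path. It assumes the google-generativeai package is installed and GENAI_API_KEY is set; the prompt text and sampling values are placeholders, not the agent's real payload:

import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GENAI_API_KEY"])

# Unlike the gemini-pro branch, the system prompt is passed at model
# construction rather than concatenated into the first user message.
model = genai.GenerativeModel(
    "gemini-1.5-pro-latest",
    system_instruction="You are a GUI agent.",
)

response = model.generate_content(
    [{"role": "user", "parts": ["Describe the next action."]}],
    generation_config={
        "candidate_count": 1,
        "max_output_tokens": 1500,
        "top_p": 0.9,
        "temperature": 1.0,
    },
    request_options={"timeout": 120},
)
print(response.text)
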