Add Mistral, Qwen, Gemini support; Fix minor bugs

This commit is contained in:
Timothyxxx
2024-02-01 16:55:38 +08:00
parent 4ef0dd59af
commit 59e2417a08
7 changed files with 156 additions and 287 deletions

View File

@@ -1,12 +1,20 @@
import base64
import json
import logging
import os
import re
import time
import uuid
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai.error import (
APIConnectionError,
APIError,
@@ -22,8 +30,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
import logging
logger = logging.getLogger("desktopenv.agent")
@@ -44,11 +50,13 @@ def linearize_accessibility_tree(accessibility_tree):
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
if node.text:
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(node.text.replace('"', '""'))) + "\t"
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper")\
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(
node.text.replace('"', '""'))) + "\t"
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(text.replace('"', '""'))) + "\t"
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(
text.replace('"', '""'))) + "\t"
else:
linearized_accessibility_tree += '""\t'
linearized_accessibility_tree += node.attrib.get(
@@ -145,10 +153,21 @@ def parse_code_from_som_string(input_string, masks):
x, y, w, h = mask
mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2))))
# reverse the mappings
for mapping in mappings[::-1]:
input_string = input_string.replace(mapping[0], mapping[1])
def replace_tags_with_mappings(text, mappings):
pattern = r'tag#\d+'
matches = re.findall(pattern, text)
for match in matches:
for mapping in mappings:
if match == mapping[0]:
text = text.replace(match, mapping[1])
break
logger.error("Predicting the tag with index {} failed.".format(match))
return ""
return text
input_string = replace_tags_with_mappings(input_string, mappings)
actions = parse_code_from_string(input_string)
return actions
@@ -295,7 +314,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{_screenshot}",
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
}
}
@@ -314,7 +333,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{_screenshot}",
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
}
}
@@ -375,7 +394,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -421,7 +440,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -448,7 +467,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -510,32 +529,130 @@ class GPT4v_Agent:
@backoff.on_exception(
backoff.expo,
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
max_tries=3
max_tries=10
)
def call_llm(self, payload):
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if self.model.startswith("gpt"):
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
return ""
print("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
elif self.model.startswith("mistral"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "test"
response = openai.ChatCompletion.create(
messages=messages,
model="Mixtral-8x7B-Instruct-v0.1",
max_tokens=max_tokens
)
try:
return response['choices'][0]['message']['content']
except Exception as e:
return ""
elif self.model.startswith("gemini"):
api_key = os.environ.get("GENAI_API_KEY")
genai.api_key = api_key
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
return image
messages = payload["messages"]
max_tokens = payload["max_tokens"]
gemini_messages = []
for i, message in enumerate(messages):
gemini_message = {
"role": message["role"],
"parts": []
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
# The gemini only support the last image as single image input
if i == len(messages) - 1:
for part in message["content"]:
gemini_message['parts'].append(part['text']) if part['type'] == "text" \
else gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
else:
for part in message["content"]:
gemini_message['parts'].append(part['text']) if part['type'] == "text" else None
gemini_messages.append(gemini_message)
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={
"max_output_tokens": max_tokens
}
)
try:
return response.text
except Exception as e:
return ""
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
qwen_messages = []
for i, message in enumerate(messages):
qwen_message = {
"role": message["role"],
"content": []
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message)
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
messages=messages)
# The response status_code is HTTPStatus.OK indicate success,
# otherwise indicate request is failed, you can get error code
# and message from code and message.
if response.status_code == HTTPStatus.OK:
try:
return response.json()['output']['choices'][0]['message']['content']
except Exception as e:
return ""
else:
print(response.code) # The error code.
print(response.message) # The error message.
return ""
print("Failed to call LLM: " + response.text)
return ""
else:
return response.json()['choices'][0]['message']['content']
raise ValueError("Invalid model: " + self.model)
def parse_actions(self, response: str, masks=None):