Add Mistral, Qwen, Gemini support; Fix minor bugs

Timothyxxx
2024-02-01 16:55:38 +08:00
parent 4ef0dd59af
commit 59e2417a08
7 changed files with 156 additions and 287 deletions

View File

@@ -30,7 +30,7 @@ def _execute_command(command: List[str]) -> None:
p = subprocess.Popen(command)
p.wait()
else:
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True, encoding="utf-8")
if result.returncode != 0:
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
return result.stdout
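
Passing encoding="utf-8" matters because text=True alone decodes the child's output with the platform's locale codec (for example cp1252 on Windows), which can raise UnicodeDecodeError or mangle non-ASCII characters. A minimal standalone sketch of the hardened call, using a hypothetical command that prints accented text:

import os
import subprocess
import sys

# Hypothetical command; any program that emits non-ASCII output will do.
# PYTHONIOENCODING pins the child's side of the pipe to UTF-8 as well, so the
# round trip behaves the same on every platform.
command = [sys.executable, "-c", "print('résumé')"]
env = {**os.environ, "PYTHONIOENCODING": "utf-8"}

result = subprocess.run(
    command,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    timeout=60,
    text=True,
    encoding="utf-8",  # decode with UTF-8 instead of the locale's preferred codec
    env=env,
)
if result.returncode != 0:
    raise Exception(result.stdout + result.stderr)
print(result.stdout)  # -> résumé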

View File

@@ -328,6 +328,9 @@ def check_structure_sim(src_path, tgt_path):
Check if the structures of the two images are similar
gimp:2a729ded-3296-423d-aec4-7dd55ed5fbb3
"""
if src_path is None or tgt_path is None:
return 0.
img_src = Image.open(src_path)
img_tgt = Image.open(tgt_path)
structure_same = structure_check_by_ssim(img_src, img_tgt)
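
The new guard returns 0. when either screenshot path is missing instead of crashing inside Image.open. structure_check_by_ssim is defined elsewhere in the repo; a plausible sketch of the guarded comparison, assuming that helper wraps skimage's structural_similarity and thresholds the score:

import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim


def structure_check_by_ssim(img_src, img_tgt, threshold=0.95):
    # Compare grayscale structure after resizing both images to a common shape.
    size = (min(img_src.width, img_tgt.width), min(img_src.height, img_tgt.height))
    a = np.asarray(img_src.convert("L").resize(size))
    b = np.asarray(img_tgt.convert("L").resize(size))
    return ssim(a, b) >= threshold


def check_structure_sim(src_path, tgt_path):
    # Missing screenshots now score 0. instead of raising inside Image.open.
    if src_path is None or tgt_path is None:
        return 0.
    img_src = Image.open(src_path)
    img_tgt = Image.open(tgt_path)
    return float(structure_check_by_ssim(img_src, img_tgt))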

View File

@@ -27,7 +27,7 @@
"command": [
"python",
"-c",
"import pyautogui; import time; time.sleep(1); pyautogui.press(\"down\", presses=40, interval=0.1); time.sleep(1); pyautogui.scroll(-2)"
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=40, interval=10); time.sleep(1); pyautogui.scroll(-2)"
]
}
}

View File

@@ -38,7 +38,7 @@
"command": [
"python",
"-c",
"import pyautogui; import time; time.sleep(1); pyautogui.press(\"down\", presses=8); time.sleep(1); pyautogui.scroll(-2)"
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=8, interval=3); time.sleep(1); pyautogui.scroll(-2)"
]
}
}
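
Both task configs above trade speed for reliability: a longer initial sleep and a per-press interval give slow applications time to repaint between key events, so the final scroll lands on a settled page. A standalone reproduction of the second command, assuming a desktop session pyautogui can control:

import time

import pyautogui

time.sleep(5)                                    # let the target window finish loading
pyautogui.press("down", presses=8, interval=3)   # press Down 8 times, 3 s apart
time.sleep(1)
pyautogui.scroll(-2)                             # finish with a small scroll down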

View File

@@ -1,136 +0,0 @@
# todo: needs to be refactored
import time
from typing import Dict, List
import google.generativeai as genai
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
class GeminiPro_Agent:
def __init__(self, api_key, instruction, model='gemini-pro', max_tokens=300, temperature=0.0,
action_space="computer_13"):
genai.configure(api_key=api_key)
self.instruction = instruction
self.model = genai.GenerativeModel(model)
self.max_tokens = max_tokens
self.temperature = temperature
self.action_space = action_space
self.trajectory = [
{
"role": "system",
"parts": [
{
"computer_13": SYS_PROMPT_ACTION,
"pyautogui": SYS_PROMPT_CODE
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
]
}
]
def predict(self, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
Only supports single-round conversation; only the last desktop screenshot is filled in.
"""
accessibility_tree = obs["accessibility_tree"]
leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(leaf_nodes)
linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
# Linearize the accessibility tree nodes into a table format
for node in filtered_nodes:
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
linearized_accessibility_tree += node.attrib.get(
'{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
self.trajectory.append({
"role": "user",
"parts": [
"Given the XML format of accessibility tree (convert and formatted into table) as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)]
})
# todo: Remove this step once Gemini supports multi-round conversation
all_message_str = ""
for i in range(len(self.trajectory)):
if i == 0:
all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n"
elif i % 2 == 1:
all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n"
else:
all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n"
all_message_str += all_message_template.format(self.trajectory[i]["parts"][0])
print("All message: >>>>>>>>>>>>>>>> ")
print(
all_message_str
)
message_for_gemini = {
"role": "user",
"parts": [all_message_str]
}
traj_to_show = []
for i in range(len(self.trajectory)):
traj_to_show.append(self.trajectory[i]["parts"][0])
if len(self.trajectory[i]["parts"]) > 1:
traj_to_show.append("screenshot_obs")
print("Trajectory:", traj_to_show)
while True:
try:
response = self.model.generate_content(
message_for_gemini,
generation_config={
"max_output_tokens": self.max_tokens,
"temperature": self.temperature
}
)
break
except:
print("Failed to generate response, retrying...")
time.sleep(5)
pass
try:
response_text = response.text
except:
return []
try:
actions = self.parse_actions(response_text)
except:
print("Failed to parse action from response:", response_text)
actions = []
return actions
def parse_actions(self, response: str):
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
elif self.action_space == "pyautogui":
actions = parse_code_from_string(response)
else:
raise ValueError("Invalid action space: " + self.action_space)
# add action into the trajectory
self.trajectory.append({
"role": "assistant",
"parts": [response]
})
return actions
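
The agent above flattens the filtered accessibility-tree leaves into a tab-separated table before prompting. A minimal sketch of that linearization in isolation, with a hypothetical helper name and empty-string defaults for nodes that lack an attribute (find_leaf_nodes and filter_nodes are the repo's own helpers and are not re-implemented here):

def linearize_a11y_nodes(filtered_nodes):
    # One row per node: tag, accessible name, screen position, and size.
    rows = ["tag\ttext\tposition\tsize"]
    for node in filtered_nodes:
        rows.append("\t".join([
            node.tag,
            node.attrib.get("name", ""),
            node.attrib.get("{uri:deskat:component.at-spi.gnome.org}screencoord", ""),
            node.attrib.get("{uri:deskat:component.at-spi.gnome.org}size", ""),
        ]))
    return "\n".join(rows)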

View File

@@ -1,115 +0,0 @@
# todo: needs to be refactored
import time
from typing import Dict, List
import PIL.Image
import google.generativeai as genai
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
class GeminiProV_Agent:
def __init__(self, api_key, instruction, model='gemini-pro-vision', max_tokens=300, temperature=0.0,
action_space="computer_13"):
genai.configure(api_key=api_key)
self.instruction = instruction
self.model = genai.GenerativeModel(model)
self.max_tokens = max_tokens
self.temperature = temperature
self.action_space = action_space
self.trajectory = [
{
"role": "system",
"parts": [
{
"computer_13": SYS_PROMPT_ACTION,
"pyautogui": SYS_PROMPT_CODE
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
]
}
]
def predict(self, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
Only supports single-round conversation; only the last desktop screenshot is filled in.
"""
img = PIL.Image.open(obs["screenshot"])
self.trajectory.append({
"role": "user",
"parts": ["What's the next step that you will do to help with the task?", img]
})
# todo: Remove this step once Gemini supports multi-round conversation
all_message_str = ""
for i in range(len(self.trajectory)):
if i == 0:
all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n"
elif i % 2 == 1:
all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n"
else:
all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n"
all_message_str += all_message_template.format(self.trajectory[i]["parts"][0])
message_for_gemini = {
"role": "user",
"parts": [all_message_str, img]
}
traj_to_show = []
for i in range(len(self.trajectory)):
traj_to_show.append(self.trajectory[i]["parts"][0])
if len(self.trajectory[i]["parts"]) > 1:
traj_to_show.append("screenshot_obs")
print("Trajectory:", traj_to_show)
while True:
try:
response = self.model.generate_content(
message_for_gemini,
generation_config={
"max_output_tokens": self.max_tokens,
"temperature": self.temperature
}
)
break
except:
print("Failed to generate response, retrying...")
time.sleep(5)
pass
try:
response_text = response.text
except:
return []
try:
actions = self.parse_actions(response_text)
except:
print("Failed to parse action from response:", response_text)
actions = []
return actions
def parse_actions(self, response: str):
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
elif self.action_space == "pyautogui":
actions = parse_code_from_string(response)
else:
raise ValueError("Invalid action space: " + self.action_space)
# add action into the trajectory
self.trajectory.append({
"role": "assistant",
"parts": [response]
})
return actions
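
Both deleted Gemini agents retry generate_content inside an unbounded while True loop with a bare except. The consolidated agent later in this commit relies on the backoff decorator for retries instead; a minimal sketch of that bounded pattern, with an illustrative function name:

import backoff
import google.generativeai as genai


# Retries with exponential backoff and gives up after five attempts instead of
# looping forever on every exception.
@backoff.on_exception(backoff.expo, Exception, max_tries=5)
def generate_with_retry(model: genai.GenerativeModel, message, max_tokens: int, temperature: float) -> str:
    response = model.generate_content(
        message,
        generation_config={
            "max_output_tokens": max_tokens,
            "temperature": temperature,
        },
    )
    return response.text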

View File

@@ -1,12 +1,20 @@
import base64
import json
import logging
import os
import re
import time
import uuid
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
from PIL import Image
from openai.error import (
APIConnectionError,
APIError,
@@ -22,8 +30,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
import logging
logger = logging.getLogger("desktopenv.agent")
@@ -44,11 +50,13 @@ def linearize_accessibility_tree(accessibility_tree):
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
if node.text:
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(node.text.replace('"', '""'))) + "\t"
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper")\
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(
node.text.replace('"', '""'))) + "\t"
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(text.replace('"', '""'))) + "\t"
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(
text.replace('"', '""'))) + "\t"
else:
linearized_accessibility_tree += '""\t'
linearized_accessibility_tree += node.attrib.get(
@@ -145,10 +153,21 @@ def parse_code_from_som_string(input_string, masks):
x, y, w, h = mask
mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2))))
# reverse the mappings
for mapping in mappings[::-1]:
input_string = input_string.replace(mapping[0], mapping[1])
def replace_tags_with_mappings(text, mappings):
pattern = r'tag#\d+'
matches = re.findall(pattern, text)
for match in matches:
for mapping in mappings:
if match == mapping[0]:
text = text.replace(match, mapping[1])
break
logger.error("Predicting the tag with index {} failed.".format(match))
return ""
return text
input_string = replace_tags_with_mappings(input_string, mappings)
actions = parse_code_from_string(input_string)
return actions
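
The new helper maps Set-of-Mark tags such as tag#3 back to pixel coordinates before the pyautogui code is parsed, and discards the whole prediction when a tag has no known mapping. A plausible standalone reading of that logic (not the verbatim code), with hypothetical names:

import logging
import re

logger = logging.getLogger("desktopenv.agent")


def replace_tags_with_mappings(text, mappings):
    # mappings is a list of ("tag#N", "x, y") pairs built from the SoM masks.
    for match in re.findall(r"tag#\d+", text):
        for tag, coords in mappings:
            if match == tag:
                text = text.replace(match, coords)
                break
        else:
            # No coordinate is known for this tag: treat the prediction as failed.
            logger.error("Predicting the tag with index %s failed.", match)
            return ""
    return text
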
@@ -295,7 +314,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{_screenshot}",
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
}
}
@@ -314,7 +333,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{_screenshot}",
"url": f"data:image/png;base64,{_screenshot}",
"detail": "high"
}
}
@@ -375,7 +394,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -421,7 +440,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -448,7 +467,7 @@ class GPT4v_Agent:
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"url": f"data:image/png;base64,{base64_image}",
"detail": "high"
}
}
@@ -510,32 +529,130 @@ class GPT4v_Agent:
@backoff.on_exception(
backoff.expo,
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
max_tries=3
max_tries=10
)
def call_llm(self, payload):
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if self.model.startswith("gpt"):
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
if response.status_code != 200:
if response.json()['error']['code'] == "context_length_exceeded":
print("Context length exceeded. Retrying with a smaller context.")
payload["messages"] = payload["messages"][-1:]
retry_response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
if retry_response.status_code != 200:
print("Failed to call LLM: " + retry_response.text)
return ""
print("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
elif self.model.startswith("mistral"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
openai.api_base = "http://localhost:8000/v1"
openai.api_key = "test"
response = openai.ChatCompletion.create(
messages=messages,
model="Mixtral-8x7B-Instruct-v0.1",
max_tokens=max_tokens
)
try:
return response['choices'][0]['message']['content']
except Exception as e:
return ""
elif self.model.startswith("gemini"):
api_key = os.environ.get("GENAI_API_KEY")
genai.configure(api_key=api_key)
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
return image
messages = payload["messages"]
max_tokens = payload["max_tokens"]
gemini_messages = []
for i, message in enumerate(messages):
gemini_message = {
"role": message["role"],
"parts": []
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
# Gemini only supports a single image, so only the last message's image is attached
if i == len(messages) - 1:
for part in message["content"]:
gemini_message['parts'].append(part['text']) if part['type'] == "text" \
else gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
else:
for part in message["content"]:
gemini_message['parts'].append(part['text']) if part['type'] == "text" else None
gemini_messages.append(gemini_message)
response = genai.GenerativeModel(self.model).generate_content(
gemini_messages,
generation_config={
"max_output_tokens": max_tokens
}
)
try:
return response.text
except Exception as e:
return ""
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]
qwen_messages = []
for i, message in enumerate(messages):
qwen_message = {
"role": message["role"],
"content": []
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]:
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message)
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
messages=qwen_messages)
# A status_code of HTTPStatus.OK indicates success; otherwise the request
# failed, and the error code and message are available via response.code
# and response.message.
if response.status_code == HTTPStatus.OK:
try:
return response.json()['output']['choices'][0]['message']['content']
except Exception as e:
return ""
else:
print(response.code) # The error code.
print(response.message) # The error message.
return ""
print("Failed to call LLM: " + response.text)
return ""
else:
return response.json()['choices'][0]['message']['content']
raise ValueError("Invalid model: " + self.model)
def parse_actions(self, response: str, masks=None):
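
call_llm now routes a single OpenAI-style payload to four backends by model-name prefix: the OpenAI API for gpt-*, an OpenAI-compatible local server for mistral-*, google.generativeai for gemini-*, and DashScope for qwen-*. A condensed, runnable sketch of that routing order, with strings standing in for the real backend calls above:

def route_by_model_prefix(model: str) -> str:
    # Mirrors the dispatch order in call_llm; each branch above issues the real request.
    if model.startswith("gpt"):
        return "OpenAI chat completions API (with context-length retry)"
    elif model.startswith("mistral"):
        return "OpenAI-compatible local server (e.g. on localhost:8000)"
    elif model.startswith("gemini"):
        return "google.generativeai generate_content (last image only)"
    elif model.startswith("qwen"):
        return "dashscope MultiModalConversation"
    raise ValueError("Invalid model: " + model)


print(route_by_model_prefix("gemini-pro-vision"))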