Add Mistral, Qwen, Gemini support; Fix minor bugs
This commit is contained in:
@@ -30,7 +30,7 @@ def _execute_command(command: List[str]) -> None:
|
||||
p = subprocess.Popen(command)
|
||||
p.wait()
|
||||
else:
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True)
|
||||
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=60, text=True, encoding="utf-8")
|
||||
if result.returncode != 0:
|
||||
raise Exception("\033[91m" + result.stdout + result.stderr + "\033[0m")
|
||||
return result.stdout
|
||||
|
||||
@@ -328,6 +328,9 @@ def check_structure_sim(src_path, tgt_path):
|
||||
Check if the structure of the two images are similar
|
||||
gimp:2a729ded-3296-423d-aec4-7dd55ed5fbb3
|
||||
"""
|
||||
if src_path is None or tgt_path is None:
|
||||
return 0.
|
||||
|
||||
img_src = Image.open(src_path)
|
||||
img_tgt = Image.open(tgt_path)
|
||||
structure_same = structure_check_by_ssim(img_src, img_tgt)
|
||||
|
||||
@@ -27,7 +27,7 @@
|
||||
"command": [
|
||||
"python",
|
||||
"-c",
|
||||
"import pyautogui; import time; time.sleep(1); pyautogui.press(\"down\", presses=40, interval=0.1); time.sleep(1); pyautogui.scroll(-2)"
|
||||
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=40, interval=10); time.sleep(1); pyautogui.scroll(-2)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
"command": [
|
||||
"python",
|
||||
"-c",
|
||||
"import pyautogui; import time; time.sleep(1); pyautogui.press(\"down\", presses=8); time.sleep(1); pyautogui.scroll(-2)"
|
||||
"import pyautogui; import time; time.sleep(5); pyautogui.press(\"down\", presses=8, interval=3); time.sleep(1); pyautogui.scroll(-2)"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
# todo: needs to be refactored
|
||||
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import google.generativeai as genai
|
||||
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
|
||||
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
||||
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
||||
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
|
||||
|
||||
|
||||
class GeminiPro_Agent:
|
||||
def __init__(self, api_key, instruction, model='gemini-pro', max_tokens=300, temperature=0.0,
|
||||
action_space="computer_13"):
|
||||
genai.configure(api_key=api_key)
|
||||
self.instruction = instruction
|
||||
self.model = genai.GenerativeModel(model)
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
|
||||
self.trajectory = [
|
||||
{
|
||||
"role": "system",
|
||||
"parts": [
|
||||
{
|
||||
"computer_13": SYS_PROMPT_ACTION,
|
||||
"pyautogui": SYS_PROMPT_CODE
|
||||
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def predict(self, obs: Dict) -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
Only support single-round conversation, only fill-in the last desktop screenshot.
|
||||
"""
|
||||
accessibility_tree = obs["accessibility_tree"]
|
||||
|
||||
leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
filtered_nodes = filter_nodes(leaf_nodes)
|
||||
|
||||
linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
linearized_accessibility_tree += node.tag + "\t"
|
||||
linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
linearized_accessibility_tree += node.attrib.get(
|
||||
'{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
|
||||
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
|
||||
|
||||
self.trajectory.append({
|
||||
"role": "user",
|
||||
"parts": [
|
||||
"Given the XML format of accessibility tree (convert and formatted into table) as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
linearized_accessibility_tree)]
|
||||
})
|
||||
|
||||
# todo: Remove this step once the Gemini supports multi-round conversation
|
||||
all_message_str = ""
|
||||
for i in range(len(self.trajectory)):
|
||||
if i == 0:
|
||||
all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n"
|
||||
elif i % 2 == 1:
|
||||
all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n"
|
||||
else:
|
||||
all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n"
|
||||
|
||||
all_message_str += all_message_template.format(self.trajectory[i]["parts"][0])
|
||||
|
||||
print("All message: >>>>>>>>>>>>>>>> ")
|
||||
print(
|
||||
all_message_str
|
||||
)
|
||||
|
||||
message_for_gemini = {
|
||||
"role": "user",
|
||||
"parts": [all_message_str]
|
||||
}
|
||||
|
||||
traj_to_show = []
|
||||
for i in range(len(self.trajectory)):
|
||||
traj_to_show.append(self.trajectory[i]["parts"][0])
|
||||
if len(self.trajectory[i]["parts"]) > 1:
|
||||
traj_to_show.append("screenshot_obs")
|
||||
|
||||
print("Trajectory:", traj_to_show)
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
message_for_gemini,
|
||||
generation_config={
|
||||
"max_output_tokens": self.max_tokens,
|
||||
"temperature": self.temperature
|
||||
}
|
||||
)
|
||||
break
|
||||
except:
|
||||
print("Failed to generate response, retrying...")
|
||||
time.sleep(5)
|
||||
pass
|
||||
|
||||
try:
|
||||
response_text = response.text
|
||||
except:
|
||||
return []
|
||||
|
||||
try:
|
||||
actions = self.parse_actions(response_text)
|
||||
except:
|
||||
print("Failed to parse action from response:", response_text)
|
||||
actions = []
|
||||
|
||||
return actions
|
||||
|
||||
def parse_actions(self, response: str):
|
||||
# parse from the response
|
||||
if self.action_space == "computer_13":
|
||||
actions = parse_actions_from_string(response)
|
||||
elif self.action_space == "pyautogui":
|
||||
actions = parse_code_from_string(response)
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + self.action_space)
|
||||
|
||||
# add action into the trajectory
|
||||
self.trajectory.append({
|
||||
"role": "assistant",
|
||||
"parts": [response]
|
||||
})
|
||||
|
||||
return actions
|
||||
@@ -1,115 +0,0 @@
|
||||
# todo: needs to be refactored
|
||||
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import PIL.Image
|
||||
import google.generativeai as genai
|
||||
|
||||
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
|
||||
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
||||
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
||||
|
||||
|
||||
class GeminiProV_Agent:
|
||||
def __init__(self, api_key, instruction, model='gemini-pro-vision', max_tokens=300, temperature=0.0,
|
||||
action_space="computer_13"):
|
||||
genai.configure(api_key=api_key)
|
||||
self.instruction = instruction
|
||||
self.model = genai.GenerativeModel(model)
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
|
||||
self.trajectory = [
|
||||
{
|
||||
"role": "system",
|
||||
"parts": [
|
||||
{
|
||||
"computer_13": SYS_PROMPT_ACTION,
|
||||
"pyautogui": SYS_PROMPT_CODE
|
||||
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
def predict(self, obs: Dict) -> List:
|
||||
"""
|
||||
Predict the next action(s) based on the current observation.
|
||||
Only support single-round conversation, only fill-in the last desktop screenshot.
|
||||
"""
|
||||
img = PIL.Image.open(obs["screenshot"])
|
||||
self.trajectory.append({
|
||||
"role": "user",
|
||||
"parts": ["What's the next step that you will do to help with the task?", img]
|
||||
})
|
||||
|
||||
# todo: Remove this step once the Gemini supports multi-round conversation
|
||||
all_message_str = ""
|
||||
for i in range(len(self.trajectory)):
|
||||
if i == 0:
|
||||
all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n"
|
||||
elif i % 2 == 1:
|
||||
all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n"
|
||||
else:
|
||||
all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n"
|
||||
|
||||
all_message_str += all_message_template.format(self.trajectory[i]["parts"][0])
|
||||
|
||||
message_for_gemini = {
|
||||
"role": "user",
|
||||
"parts": [all_message_str, img]
|
||||
}
|
||||
|
||||
traj_to_show = []
|
||||
for i in range(len(self.trajectory)):
|
||||
traj_to_show.append(self.trajectory[i]["parts"][0])
|
||||
if len(self.trajectory[i]["parts"]) > 1:
|
||||
traj_to_show.append("screenshot_obs")
|
||||
|
||||
print("Trajectory:", traj_to_show)
|
||||
|
||||
while True:
|
||||
try:
|
||||
response = self.model.generate_content(
|
||||
message_for_gemini,
|
||||
generation_config={
|
||||
"max_output_tokens": self.max_tokens,
|
||||
"temperature": self.temperature
|
||||
}
|
||||
)
|
||||
break
|
||||
except:
|
||||
print("Failed to generate response, retrying...")
|
||||
time.sleep(5)
|
||||
pass
|
||||
|
||||
try:
|
||||
response_text = response.text
|
||||
except:
|
||||
return []
|
||||
|
||||
try:
|
||||
actions = self.parse_actions(response_text)
|
||||
except:
|
||||
print("Failed to parse action from response:", response_text)
|
||||
actions = []
|
||||
|
||||
return actions
|
||||
|
||||
def parse_actions(self, response: str):
|
||||
# parse from the response
|
||||
if self.action_space == "computer_13":
|
||||
actions = parse_actions_from_string(response)
|
||||
elif self.action_space == "pyautogui":
|
||||
actions = parse_code_from_string(response)
|
||||
else:
|
||||
raise ValueError("Invalid action space: " + self.action_space)
|
||||
|
||||
# add action into the trajectory
|
||||
self.trajectory.append({
|
||||
"role": "assistant",
|
||||
"parts": [response]
|
||||
})
|
||||
|
||||
return actions
|
||||
@@ -1,12 +1,20 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
|
||||
import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
from PIL import Image
|
||||
from openai.error import (
|
||||
APIConnectionError,
|
||||
APIError,
|
||||
@@ -22,8 +30,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S
|
||||
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
|
||||
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
@@ -44,11 +50,13 @@ def linearize_accessibility_tree(accessibility_tree):
|
||||
linearized_accessibility_tree += node.tag + "\t"
|
||||
linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
if node.text:
|
||||
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(node.text.replace('"', '""'))) + "\t"
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper")\
|
||||
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(
|
||||
node.text.replace('"', '""'))) + "\t"
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(text.replace('"', '""'))) + "\t"
|
||||
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(
|
||||
text.replace('"', '""'))) + "\t"
|
||||
else:
|
||||
linearized_accessibility_tree += '""\t'
|
||||
linearized_accessibility_tree += node.attrib.get(
|
||||
@@ -145,10 +153,21 @@ def parse_code_from_som_string(input_string, masks):
|
||||
x, y, w, h = mask
|
||||
mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2))))
|
||||
|
||||
# reverse the mappings
|
||||
for mapping in mappings[::-1]:
|
||||
input_string = input_string.replace(mapping[0], mapping[1])
|
||||
def replace_tags_with_mappings(text, mappings):
|
||||
pattern = r'tag#\d+'
|
||||
matches = re.findall(pattern, text)
|
||||
|
||||
for match in matches:
|
||||
for mapping in mappings:
|
||||
if match == mapping[0]:
|
||||
text = text.replace(match, mapping[1])
|
||||
break
|
||||
logger.error("Predicting the tag with index {} failed.".format(match))
|
||||
return ""
|
||||
|
||||
return text
|
||||
|
||||
input_string = replace_tags_with_mappings(input_string, mappings)
|
||||
actions = parse_code_from_string(input_string)
|
||||
return actions
|
||||
|
||||
@@ -295,7 +314,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"url": f"data:image/png;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
@@ -314,7 +333,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"url": f"data:image/png;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
@@ -375,7 +394,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
@@ -421,7 +440,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
@@ -448,7 +467,7 @@ class GPT4v_Agent:
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||
"url": f"data:image/png;base64,{base64_image}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
@@ -510,32 +529,130 @@ class GPT4v_Agent:
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
|
||||
max_tries=3
|
||||
max_tries=10
|
||||
)
|
||||
def call_llm(self, payload):
|
||||
response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
)
|
||||
if self.model.startswith("gpt"):
|
||||
response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
if response.json()['error']['code'] == "context_length_exceeded":
|
||||
print("Context length exceeded. Retrying with a smaller context.")
|
||||
payload["messages"] = payload["messages"][-1:]
|
||||
retry_response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
)
|
||||
if retry_response.status_code != 200:
|
||||
print("Failed to call LLM: " + retry_response.text)
|
||||
if response.status_code != 200:
|
||||
if response.json()['error']['code'] == "context_length_exceeded":
|
||||
print("Context length exceeded. Retrying with a smaller context.")
|
||||
payload["messages"] = payload["messages"][-1:]
|
||||
retry_response = requests.post(
|
||||
"https://api.openai.com/v1/chat/completions",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
)
|
||||
if retry_response.status_code != 200:
|
||||
print("Failed to call LLM: " + retry_response.text)
|
||||
return ""
|
||||
|
||||
print("Failed to call LLM: " + response.text)
|
||||
time.sleep(5)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
|
||||
elif self.model.startswith("mistral"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
|
||||
openai.api_base = "http://localhost:8000/v1"
|
||||
openai.api_key = "test"
|
||||
response = openai.ChatCompletion.create(
|
||||
messages=messages,
|
||||
model="Mixtral-8x7B-Instruct-v0.1",
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
try:
|
||||
return response['choices'][0]['message']['content']
|
||||
except Exception as e:
|
||||
return ""
|
||||
|
||||
elif self.model.startswith("gemini"):
|
||||
|
||||
api_key = os.environ.get("GENAI_API_KEY")
|
||||
genai.api_key = api_key
|
||||
def encoded_img_to_pil_img(data_str):
|
||||
base64_str = data_str.replace("data:image/png;base64,", "")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
image = Image.open(BytesIO(image_data))
|
||||
|
||||
return image
|
||||
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
|
||||
gemini_messages = []
|
||||
for i, message in enumerate(messages):
|
||||
gemini_message = {
|
||||
"role": message["role"],
|
||||
"parts": []
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
|
||||
# The gemini only support the last image as single image input
|
||||
if i == len(messages) - 1:
|
||||
for part in message["content"]:
|
||||
gemini_message['parts'].append(part['text']) if part['type'] == "text" \
|
||||
else gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
|
||||
else:
|
||||
for part in message["content"]:
|
||||
gemini_message['parts'].append(part['text']) if part['type'] == "text" else None
|
||||
|
||||
gemini_messages.append(gemini_message)
|
||||
|
||||
response = genai.GenerativeModel(self.model).generate_content(
|
||||
gemini_messages,
|
||||
generation_config={
|
||||
"max_output_tokens": max_tokens
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
return response.text
|
||||
except Exception as e:
|
||||
return ""
|
||||
elif self.model.startswith("qwen"):
|
||||
messages = payload["messages"]
|
||||
max_tokens = payload["max_tokens"]
|
||||
|
||||
qwen_messages = []
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
qwen_message = {
|
||||
"role": message["role"],
|
||||
"content": []
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
for part in message["content"]:
|
||||
qwen_message['content'].append({"image": part['image_url']['url']}) if part['type'] == "image_url" else None
|
||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||
|
||||
qwen_messages.append(qwen_message)
|
||||
|
||||
response = dashscope.MultiModalConversation.call(model='qwen-vl-plus',
|
||||
messages=messages)
|
||||
# The response status_code is HTTPStatus.OK indicate success,
|
||||
# otherwise indicate request is failed, you can get error code
|
||||
# and message from code and message.
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
try:
|
||||
return response.json()['output']['choices'][0]['message']['content']
|
||||
except Exception as e:
|
||||
return ""
|
||||
else:
|
||||
print(response.code) # The error code.
|
||||
print(response.message) # The error message.
|
||||
return ""
|
||||
|
||||
print("Failed to call LLM: " + response.text)
|
||||
return ""
|
||||
else:
|
||||
return response.json()['choices'][0]['message']['content']
|
||||
raise ValueError("Invalid model: " + self.model)
|
||||
|
||||
def parse_actions(self, response: str, masks=None):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user