Files
sci-gui-agent-benchmark/mm_agents/agent.py
Timothyxxx 4db207fc27 Merge remote-tracking branch 'origin/main'
# Conflicts:
#	mm_agents/agent.py
#	run.py
2024-03-15 21:10:32 +08:00

762 lines
31 KiB
Python

import base64
import json
import logging
import os
import re
import time
import uuid
import openai
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
from google.api_core.exceptions import InvalidArgument
import backoff
import dashscope
import google.generativeai as genai
import requests
from PIL import Image
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# todo: cross-check with visualwebarena
logger = logging.getLogger("desktopenv.agent")
# Function to encode the image
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8 so the result can be embedded directly in a JSON payload.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def linearize_accessibility_tree(accessibility_tree):
    """Render an accessibility tree as a tab-separated table.

    The XML string is parsed, passed through ``filter_nodes`` and every
    node that survives becomes one row with the columns
    tag, name, text, position and size.

    :param accessibility_tree: XML text of the accessibility tree.
    :return: the table as a single string; the first line is the header and
        every line (including the last) ends with a newline.
    """
    class_attr = "{uri:deskat:uia.windows.microsoft.org}class"
    value_attr = "{uri:deskat:value.at-spi.gnome.org}value"
    position_attr = '{uri:deskat:component.at-spi.gnome.org}screencoord'
    size_attr = '{uri:deskat:component.at-spi.gnome.org}size'

    def quote(cell):
        # Cells containing a double quote are wrapped in quotes and the inner
        # quotes are doubled; all other cells are emitted verbatim.
        if '"' not in cell:
            return cell
        return '"{:}"'.format(cell.replace('"', '""'))

    # leaf_nodes = find_leaf_nodes(accessibility_tree)
    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))

    rows = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
    for node in filtered_nodes:
        if node.text:
            text = quote(node.text)
        elif node.get(class_attr, "").endswith("EditWrapper") and node.get(value_attr):
            # No element text: fall back to the value attribute for nodes
            # whose class name ends with "EditWrapper".
            text = quote(node.get(value_attr))
        else:
            text = '""'

        rows.append("\t".join((
            node.tag,
            # A node without a "name" attribute used to raise TypeError
            # (None + str); it now yields an empty cell instead.
            node.attrib.get('name', ""),
            text,
            node.attrib.get(position_attr, ""),
            node.attrib.get(size_attr, ""),
        )))

    return "\n".join(rows) + "\n"
def tag_screenshot(screenshot, accessibility_tree):
    """Produce a tagged copy of ``screenshot`` based on the accessibility tree.

    A randomly named PNG path under ``tmp/images`` is reserved (the directory
    is created on demand), the tree's leaf nodes are filtered, and the
    screenshot, nodes and target path are handed to ``draw_bounding_boxes``.

    :return: ``(marks, drew_nodes, tagged_screenshot_file_path)`` where the
        first two values are whatever ``draw_bounding_boxes`` returns.
    """
    output_dir = "tmp/images"
    os.makedirs(output_dir, exist_ok=True)
    # A random file name keeps separate calls from overwriting each other.
    tagged_screenshot_file_path = os.path.join(output_dir, "{}.png".format(uuid.uuid4()))

    leaf_nodes = find_leaf_nodes(accessibility_tree)
    candidate_nodes = filter_nodes(leaf_nodes)

    marks, drew_nodes = draw_bounding_boxes(candidate_nodes, screenshot, tagged_screenshot_file_path)
    return marks, drew_nodes, tagged_screenshot_file_path
def parse_actions_from_string(input_string):
    """Extract a list of JSON actions from a model response.

    Resolution order:

    1. If the whole response is one of ``WAIT`` / ``DONE`` / ``FAIL``
       (ignoring surrounding whitespace) it is returned as a one-item list.
    2. Otherwise every ```` ```json ... ``` ```` fenced block is parsed.
    3. If there is no such block, every bare ```` ``` ... ``` ```` block is parsed.
    4. If there is no fenced block at all, the whole response is parsed as JSON.

    :param input_string: raw text returned by the model.
    :return: list of special-command strings or decoded JSON values.
    :raises ValueError: if the selected text is not valid JSON. Previously a
        malformed fenced block made this function *return* an error string in
        place of the action list; it now fails the same way as the unfenced
        case so callers only have one failure mode to handle.
    """
    stripped = input_string.strip()
    if stripped in ('WAIT', 'DONE', 'FAIL'):
        return [stripped]

    # ```json fences take precedence; bare fences are consulted only when no
    # ```json fence is present.
    for fence_pattern in (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```'):
        snippets = re.findall(fence_pattern, input_string, re.DOTALL)
        if not snippets:
            continue
        try:
            return [json.loads(snippet) for snippet in snippets]
        except json.JSONDecodeError as e:
            raise ValueError(f"Failed to parse JSON: {e}") from e

    try:
        return [json.loads(input_string)]
    except json.JSONDecodeError as e:
        raise ValueError("Invalid response format: " + input_string) from e
def parse_code_from_string(input_string):
    """Pull code snippets out of the triple-backtick blocks of a response.

    Every ``;`` in the response is first turned into a newline. A response
    that is exactly ``WAIT`` / ``DONE`` / ``FAIL`` is returned as a one-item
    list. Otherwise each fenced block (with or without a leading language
    word) yields one snippet; when the final line of a block is one of the
    special commands, that command is emitted as its own entry after the
    preceding code.

    :param input_string: raw text returned by the model.
    :return: list of code strings and/or special-command strings.
    """
    commands = ['WAIT', 'DONE', 'FAIL']  # fixme: updates this part when we have more commands

    text = input_string.replace(";", "\n")
    whole = text.strip()
    if whole in commands:
        return [whole]

    # Matches both ```code``` and ```python code```; the non-greedy group
    # captures the content and DOTALL lets it span several lines.
    fenced_block = r"```(?:\w+\s+)?(.*?)```"

    codes = []
    for raw_snippet in re.findall(fenced_block, text, re.DOTALL):
        snippet = raw_snippet.strip()
        *leading_lines, final_line = snippet.split('\n')
        if final_line not in commands:
            codes.append(snippet)
            continue
        # The block ends with a command: emit the code before it (if any)
        # and then the command itself.
        if leading_lines:
            codes.append("\n".join(leading_lines))
        codes.append(final_line)
    return codes
def parse_code_from_som_string(input_string, masks):
    """Parse code snippets and prefix them with ``tag_N`` coordinate variables.

    For every ``(x, y, w, h)`` entry in ``masks`` a line
    ``tag_<index>=(cx, cy)`` is generated (indices start at 1, ``cx``/``cy``
    are the integer box centres). That preamble is prepended to each snippet
    returned by ``parse_code_from_string``; the special commands
    ``WAIT`` / ``DONE`` / ``FAIL`` are passed through unchanged.

    :param input_string: raw text returned by the model.
    :param masks: iterable of 4-item boxes ``(x, y, w, h)``.
    :return: list of code strings and/or special-command strings.
    """
    tag_lines = []
    for index, (x, y, w, h) in enumerate(masks, start=1):
        centre = "({}, {})".format(int(x + w // 2), int(y + h // 2))
        tag_lines.append("tag_" + str(index) + "=" + centre + "\n")
    preamble = "".join(tag_lines)

    control_words = ('WAIT', 'DONE', 'FAIL')
    return [
        code if code.strip() in control_words else preamble + code
        for code in parse_code_from_string(input_string)
    ]
class PromptAgent:
    """Agent that prompts a multimodal LLM with desktop observations.

    Each call to :meth:`predict` records the observation, the raw model reply
    ("thought") and the parsed actions. The three history lists are kept the
    same length; a bounded tail of them is replayed to the model as context.
    """

    def __init__(
            self,
            model="gpt-4-vision-preview",
            max_tokens=1500,
            top_p=0.9,
            temperature=0.5,
            action_space="computer_13",
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
            max_trajectory_length=3
    ):
        """Store the sampling settings and pick the matching system prompt.

        :param model: model name; its prefix (``gpt`` / ``gemini`` / ``qwen``)
            selects the backend in :meth:`call_llm`.
        :param max_tokens: forwarded to the model as the generation limit.
        :param top_p: forwarded to the model.
        :param temperature: forwarded to the model.
        :param action_space: ``"computer_13"`` or ``"pyautogui"``.
        :param observation_type: one of ``"screenshot"``, ``"a11y_tree"``,
            ``"screenshot_a11y_tree"``, ``"som"``, ``"seeact"``.
        :param max_trajectory_length: number of past steps replayed to the model.
        :raises ValueError: for an unknown observation type, an unknown action
            space, or ``"computer_13"`` combined with ``"som"`` / ``"seeact"``.
        """
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length

        self.thoughts = []
        self.actions = []
        self.observations = []

        if observation_type not in ("screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"):
            raise ValueError("Invalid experiment type: " + observation_type)

        # The tag-based observation types only support generated pyautogui code.
        tag_based = observation_type in ("som", "seeact")
        if action_space not in ("computer_13", "pyautogui") or (tag_based and action_space == "computer_13"):
            raise ValueError("Invalid action space: " + action_space)

        system_prompts = {
            ("screenshot", "computer_13"): SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION,
            ("screenshot", "pyautogui"): SYS_PROMPT_IN_SCREENSHOT_OUT_CODE,
            ("a11y_tree", "computer_13"): SYS_PROMPT_IN_A11Y_OUT_ACTION,
            ("a11y_tree", "pyautogui"): SYS_PROMPT_IN_A11Y_OUT_CODE,
            ("screenshot_a11y_tree", "computer_13"): SYS_PROMPT_IN_BOTH_OUT_ACTION,
            ("screenshot_a11y_tree", "pyautogui"): SYS_PROMPT_IN_BOTH_OUT_CODE,
            ("som", "pyautogui"): SYS_PROMPT_IN_SOM_A11Y_OUT_TAG,
            ("seeact", "pyautogui"): SYS_PROMPT_SEEACT,
        }
        self.system_message = system_prompts[(observation_type, action_space)]

    def predict(self, instruction: str, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.

        :param instruction: natural-language task description.
        :param obs: mapping with a ``"screenshot"`` entry (an image file path)
            and/or an ``"accessibility_tree"`` entry (XML text), depending on
            the configured observation type.
        :return: the parsed actions, or ``None`` when the model reply could
            not be parsed.
        """
        system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)

        masks = None
        messages = [self._chat_message("system", system_message)]

        # Append trajectory
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts) \
            , "The number of observations and actions should be the same."

        # Guard the zero case explicitly: lst[-0:] is the whole list, which
        # would replay the full history instead of none of it.
        if self.max_trajectory_length > 0:
            recent_observations = self.observations[-self.max_trajectory_length:]
            recent_thoughts = self.thoughts[-self.max_trajectory_length:]
        else:
            recent_observations, recent_thoughts = [], []

        for previous_obs, previous_thought in zip(recent_observations, recent_thoughts):
            previous_tree = previous_obs["accessibility_tree"]
            if self.observation_type in ("screenshot_a11y_tree", "som", "seeact"):
                logger.debug("LINEAR AT: %s", previous_tree)
            # Stored observations hold None for the modality that is not in
            # use, so no image part is emitted for "a11y_tree".
            messages.append(self._chat_message(
                "user", self._observation_prompt(previous_tree), previous_obs["screenshot"]))
            messages.append(self._chat_message(
                "assistant", previous_thought.strip() or "No valid action"))

        # Encode the current observation.
        if self.observation_type in ("screenshot", "screenshot_a11y_tree"):
            base64_image = encode_image(obs["screenshot"])
            # The tree is only linearized when it is actually sent to the model.
            linearized_accessibility_tree = (
                linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
                if self.observation_type == "screenshot_a11y_tree" else None
            )
        elif self.observation_type == "a11y_tree":
            base64_image = None
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
        elif self.observation_type in ("som", "seeact"):
            # Add som to the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
        else:
            raise ValueError("Invalid observation_type type: " + self.observation_type)

        self.observations.append({
            "screenshot": base64_image,
            "accessibility_tree": linearized_accessibility_tree
        })

        if self.observation_type == "seeact":
            prompt = ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
        else:
            prompt = self._observation_prompt(linearized_accessibility_tree)
        messages.append(self._chat_message("user", prompt, base64_image))

        logger.info("Generating content with GPT model: %s", self.model)
        response = self.call_llm(self._completion_payload(messages))
        logger.info("RESPONSE: %s", response)

        if self.observation_type == "seeact":
            # Second round: feed the description back and ask for grounding.
            messages.append(self._chat_message("assistant", response))
            messages.append(self._chat_message(
                "user",
                "{}\n\nWhat's the next step that you will do to help with the task?".format(
                    ACTION_GROUNDING_PROMPT_SEEACT)))

            logger.info("Generating content with GPT model: %s", self.model)
            response = self.call_llm(self._completion_payload(messages))
            logger.info("RESPONSE: %s", response)

        try:
            actions = self.parse_actions(response, masks)
            self.thoughts.append(response)
        except ValueError as e:
            logger.error("Failed to parse action from response: %s", e)
            actions = None
            self.thoughts.append("")
            # parse_actions did not record anything; add a placeholder so the
            # three history lists stay the same length for the next call.
            self.actions.append(None)

        return actions

    @staticmethod
    def _chat_message(role, text, base64_image=None):
        """Build one chat message with a text part and an optional PNG part."""
        content = [{"type": "text", "text": text}]
        if base64_image is not None:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{base64_image}",
                    "detail": "high"
                }
            })
        return {"role": role, "content": content}

    def _observation_prompt(self, linearized_accessibility_tree):
        """Return the user prompt text for the configured observation type."""
        question = "What's the next step that you will do to help with the task?"
        if self.observation_type == "screenshot":
            return "Given the screenshot as below. " + question

        if self.observation_type == "screenshot_a11y_tree":
            intro = "Given the screenshot and info from accessibility tree as below:"
        elif self.observation_type in ("som", "seeact"):
            intro = "Given the tagged screenshot and info from accessibility tree as below:"
        elif self.observation_type == "a11y_tree":
            intro = "Given the info from accessibility tree as below:"
        else:
            raise ValueError("Invalid observation_type type: " + self.observation_type)
        return "{}\n{}\n{}".format(intro, linearized_accessibility_tree, question)

    def _completion_payload(self, messages):
        """Bundle ``messages`` with the agent's sampling settings."""
        return {
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature
        }

    @backoff.on_exception(
        backoff.expo,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
        (openai.RateLimitError,
         openai.BadRequestError,
         openai.InternalServerError,
         InvalidArgument),
        max_tries=5
    )
    def call_llm(self, payload):
        """Send ``payload`` to the backend selected by the model-name prefix.

        :param payload: dict with ``model``, ``messages``, ``max_tokens``,
            ``top_p`` and ``temperature`` entries.
        :return: the text of the model reply, or ``""`` when the request
            failed or the reply carried no text.
        :raises ValueError: if the model name matches no known backend.
        """
        if self.model.startswith("gpt"):
            return self._call_openai(payload)
        if self.model.startswith("gemini"):
            return self._call_gemini(payload)
        if self.model.startswith("qwen"):
            return self._call_qwen(payload)
        raise ValueError("Invalid model: " + self.model)

    @staticmethod
    def _openai_error_code(response):
        """Return ``error.code`` from an error response, or None if absent."""
        try:
            body = response.json()
        except ValueError:
            # Body is not JSON (e.g. a gateway error page).
            return None
        error = body.get("error") if isinstance(body, dict) else None
        return error.get("code") if isinstance(error, dict) else None

    def _call_openai(self, payload):
        """POST ``payload`` to the OpenAI chat-completions endpoint."""
        url = "https://api.openai.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
        }

        response = requests.post(url, headers=headers, json=payload)
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']

        if self._openai_error_code(response) == "context_length_exceeded":
            logger.error("Context length exceeded. Retrying with a smaller context.")
            # Keep only the first (system) message and the latest message.
            shortened = dict(payload, messages=[payload["messages"][0]] + payload["messages"][-1:])
            retry_response = requests.post(url, headers=headers, json=shortened)
            if retry_response.status_code == 200:
                # The successful retry used to be discarded; return it.
                return retry_response.json()['choices'][0]['message']['content']
            logger.error(
                "Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
            return ""

        logger.error("Failed to call LLM: " + response.text)
        time.sleep(5)
        return ""

    @staticmethod
    def _data_url_to_image(data_url):
        """Decode a ``data:image/png;base64,...`` URL into a PIL image."""
        base64_str = data_url.replace("data:image/png;base64,", "")
        return Image.open(BytesIO(base64.b64decode(base64_str)))

    def _call_gemini(self, payload):
        """Convert the chat messages to Gemini's format and generate a reply."""
        messages = payload["messages"]
        max_tokens = payload["max_tokens"]
        top_p = payload["top_p"]
        temperature = payload["temperature"]

        role_mapping = {
            "assistant": "model",
            "user": "user",
            "system": "system"
        }
        last_index = len(messages) - 1

        gemini_messages = []
        for i, message in enumerate(messages):
            assert len(message["content"]) in [1, 2], "One text, or one text with one image"
            parts = []
            for part in message["content"]:
                if part['type'] == "text":
                    parts.append(part['text'])
                elif i == last_index:
                    # The gemini only support the last image as single image input
                    parts.append(self._data_url_to_image(part['image_url']['url']))
            gemini_messages.append({"role": role_mapping[message["role"]], "parts": parts})

        # The system message is folded into the first following message
        # rather than being sent with its own role.
        if len(gemini_messages) > 1 and gemini_messages[0]['role'] == "system":
            gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
            gemini_messages.pop(0)

        # since the gemini-pro-vision donnot support multi-turn message
        if self.model == "gemini-pro-vision":
            message_history_str = "".join(
                "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
                for message in gemini_messages
            )
            # Carry over the image of the last message only if it has one
            # (text-only observations used to raise IndexError here).
            gemini_messages = [{
                "role": "user",
                "parts": [message_history_str] + gemini_messages[-1]['parts'][1:]
            }]

        api_key = os.environ.get("GENAI_API_KEY")
        assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
        genai.configure(api_key=api_key)

        logger.info("Generating content with Gemini model: %s", self.model)
        response = genai.GenerativeModel(self.model).generate_content(
            gemini_messages,
            generation_config={
                "candidate_count": 1,
                "max_output_tokens": max_tokens,
                "top_p": top_p,
                "temperature": temperature
            },
            safety_settings={
                "harassment": "block_none",
                "hate": "block_none",
                "sex": "block_none",
                "danger": "block_none"
            }
        )

        try:
            return response.text
        except Exception as e:
            # Best effort: a reply without usable text is reported as empty.
            logger.error("Failed to read text from Gemini response: %s", e)
            return ""

    def _call_qwen(self, payload):
        """Convert the chat messages to DashScope's format and generate a reply."""
        messages = payload["messages"]

        qwen_messages = []
        for message in messages:
            assert len(message["content"]) in [1, 2], "One text, or one text with one image"
            content = []
            for part in message["content"]:
                if part['type'] == "image_url":
                    content.append({"image": part['image_url']['url']})
                elif part['type'] == "text":
                    content.append({"text": part['text']})
            qwen_messages.append({"role": message["role"], "content": content})

        # NOTE(review): the model name is hard-coded rather than taken from
        # self.model — confirm whether that is intended.
        response = dashscope.MultiModalConversation.call(
            model='qwen-vl-plus',
            # Send the converted messages; the unconverted list was passed before.
            messages=qwen_messages,  # todo: add the hyperparameters
        )

        # The response status_code is HTTPStatus.OK indicate success,
        # otherwise indicate request is failed, you can get error code
        # and message from code and message.
        if response.status_code == HTTPStatus.OK:
            try:
                return response.json()['output']['choices'][0]['message']['content']
            except Exception as e:
                logger.error("Failed to read content from Qwen response: %s", e)
                return ""

        logger.error("Qwen request failed: %s %s", response.code, response.message)
        return ""

    def parse_actions(self, response: str, masks=None):
        """Parse ``response`` according to the observation type and action space.

        The parsed actions are appended to ``self.actions`` and returned.

        :param response: raw text returned by the model.
        :param masks: bounding boxes used for ``"som"`` / ``"seeact"``.
        :raises ValueError: if the action space or observation type is not
            supported, or if the underlying parser rejects the response.
        """
        if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_string(response)
            else:
                raise ValueError("Invalid action space: " + self.action_space)
        elif self.observation_type in ["som", "seeact"]:
            # Tag-based observations only support generated pyautogui code.
            if self.action_space != "pyautogui":
                raise ValueError("Invalid action space: " + self.action_space)
            actions = parse_code_from_som_string(response, masks)
        else:
            raise ValueError("Invalid observation_type type: " + self.observation_type)

        self.actions.append(actions)
        return actions

    def reset(self):
        """Forget all recorded thoughts, actions and observations."""
        self.thoughts = []
        self.actions = []
        self.observations = []