import base64
import json
import logging
import os
import re
import uuid
from typing import Dict, List

import backoff
import requests
from openai.error import (
    APIConnectionError,
    APIError,
    InvalidRequestError,
    RateLimitError,
    ServiceUnavailableError,
)

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
    SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
    SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
    SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
    SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT

logger = logging.getLogger("desktopenv.agent")


# Encode an image file as a base64 string for the OpenAI vision API
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def linearize_accessibility_tree(accessibility_tree):
    leaf_nodes = find_leaf_nodes(accessibility_tree)
    filtered_nodes = filter_nodes(leaf_nodes)

    # Linearize the accessibility tree nodes into a tab-separated table
    linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n"

    for node in filtered_nodes:
        linearized_accessibility_tree += node.tag + "\t"
        # attrib.get may return None for missing attributes; fall back to ""
        linearized_accessibility_tree += (node.attrib.get('name') or "") + "\t"
        if node.text:
            linearized_accessibility_tree += (node.text if '"' not in node.text
                                              else '"{:}"'.format(node.text.replace('"', '""'))) + "\t"
        elif (node.get("{uri:deskat:uia.windows.microsoft.org}class") or "").endswith("EditWrapper") \
                and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
            text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
            linearized_accessibility_tree += (text if '"' not in text
                                              else '"{:}"'.format(text.replace('"', '""'))) + "\t"
        else:
            linearized_accessibility_tree += '""\t'
        linearized_accessibility_tree += (node.attrib.get(
            '{uri:deskat:component.at-spi.gnome.org}screencoord') or "") + "\t"
        linearized_accessibility_tree += (node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') or "") + "\n"

    return linearized_accessibility_tree


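# A hypothetical sketch of the linearized, tab-separated output (names,
# coordinates and sizes below are invented; real values come from the AT-SPI
# attributes of the filtered nodes):
#
#     tag          name    text     position    size
#     push-button  OK      ""       (512, 384)  (80, 24)
#     text         Search  "query"  (100, 40)   (200, 30)

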
def tag_screenshot(screenshot, accessibility_tree):
    # Create a temporary file with a random name to store the tagged screenshot
    uuid_str = str(uuid.uuid4())
    os.makedirs("tmp/images", exist_ok=True)
    tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")

    nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
    # Draw Set-of-Mark style bounding boxes over the filtered nodes and save
    # the tagged image; marks holds the box geometry for each drawn node
    marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)

    return marks, drew_nodes, tagged_screenshot_file_path


def parse_actions_from_string(input_string):
    if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
        return [input_string.strip()]

    # Search for fenced JSON blocks within the input string, preferring
    # explicit ```json fences over bare ``` fences
    actions = []
    for pattern in (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```'):
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            # Parse every fenced block into a dictionary
            try:
                for match in matches:
                    action_dict = json.loads(match)
                    actions.append(action_dict)
                return actions
            except json.JSONDecodeError as e:
                return f"Failed to parse JSON: {e}"

    # No fenced block found: try to parse the whole string as a JSON object
    try:
        action_dict = json.loads(input_string)
        return [action_dict]
    except json.JSONDecodeError:
        raise ValueError("Invalid response format: " + input_string)


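# A minimal sketch of the expected behaviour (the model outputs below are
# hypothetical examples of well-formed responses):
#
#     >>> parse_actions_from_string('```json\n{"action_type": "CLICK", "x": 10, "y": 20}\n```')
#     [{'action_type': 'CLICK', 'x': 10, 'y': 20}]
#     >>> parse_actions_from_string('WAIT')
#     ['WAIT']

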
def parse_code_from_string(input_string):
    if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
        return [input_string.strip()]

    # This regular expression matches both ```code``` and ```python\ncode```
    # blocks and captures the `code` part; the non-greedy match together with
    # re.DOTALL lets the code inside the backticks span multiple lines
    pattern = r"```(?:\w+\s+)?(.*?)```"
    matches = re.findall(pattern, input_string, re.DOTALL)

    codes = []

    for match in matches:
        match = match.strip()
        commands = ['WAIT', 'DONE', 'FAIL']  # FIXME: update this list when more commands are added

        if match in commands:
            codes.append(match.strip())
        elif match.split('\n')[-1] in commands:
            # The block ends with a bare command: emit the code and the command
            # as separate entries
            if len(match.split('\n')) > 1:
                codes.append("\n".join(match.split('\n')[:-1]))
                codes.append(match.split('\n')[-1])
        else:
            codes.append(match)

    return codes


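# A minimal sketch of the expected behaviour (hypothetical model outputs):
#
#     >>> parse_code_from_string("```python\npyautogui.click(100, 200)\n```")
#     ['pyautogui.click(100, 200)']
#     >>> parse_code_from_string("```\npyautogui.press('enter')\nDONE\n```")
#     ["pyautogui.press('enter')", 'DONE']

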
def parse_code_from_som_string(input_string, masks):
    # Map each Set-of-Mark tag to the pixel coordinates of its box centre
    mappings = []
    for i, mask in enumerate(masks):
        x, y, w, h = mask
        mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2))))

    # Substitute higher-numbered tags first so that e.g. "tag#12" is not
    # partially clobbered by the replacement for "tag#1"
    for mapping in mappings[::-1]:
        input_string = input_string.replace(mapping[0], mapping[1])

    actions = parse_code_from_string(input_string)
    return actions


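# A minimal sketch (hypothetical mask and model output): with one mask
# (x=10, y=20, w=100, h=40), the box centre is (60, 40), so "tag#1" is
# rewritten to coordinates before the code is parsed:
#
#     >>> parse_code_from_som_string("```python\npyautogui.click(tag#1)\n```", [(10, 20, 100, 40)])
#     ['pyautogui.click(60, 40)']

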
class GPT4v_Agent:
    def __init__(
            self,
            api_key,
            instruction,
            model="gpt-4-vision-preview",
            max_tokens=500,
            action_space="computer_13",
            exp="both"
            # exp can be one of ["screenshot", "a11y_tree", "both", "som", "seeact"]
    ):

        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens
        self.action_space = action_space
        self.exp = exp
        self.max_trajectory_length = 3

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        self.thoughts = []
        self.actions = []
        self.observations = []

        # Select the system prompt matching the observation type and action space
        if exp == "screenshot":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "both":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "som":
            # Set-of-Mark prompting emits code with tag references, so it only
            # supports the pyautogui action space
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif exp == "seeact":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_SEEACT
            else:
                raise ValueError("Invalid action space: " + action_space)
        else:
            raise ValueError("Invalid experiment type: " + exp)

        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
            self.instruction)

    def predict(self, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.
        """

        # Prepare the payload for the API call
        messages = []
        masks = None

        messages.append({
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": self.system_message
                },
            ]
        })

        # Append the trajectory of previous steps, truncated to the most recent
        # max_trajectory_length entries
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), \
            "The numbers of observations, actions and thoughts should be equal."

        if len(self.observations) > self.max_trajectory_length:
            _observations = self.observations[-self.max_trajectory_length:]
            _actions = self.actions[-self.max_trajectory_length:]
            _thoughts = self.thoughts[-self.max_trajectory_length:]
        else:
            _observations = self.observations
            _actions = self.actions
            _thoughts = self.thoughts

        for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):

            # {{{1
            if self.exp == "both":
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.exp in ["som", "seeact"]:
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
                logger.debug("LINEAR AT: %s", _linearized_accessibility_tree)

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.exp == "screenshot":
                _screenshot = previous_obs["screenshot"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.exp == "a11y_tree":
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                        }
                    ]
                })
            else:
                raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}

            messages.append({
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": previous_thought.strip() if len(previous_thought) > 0 else "No valid action"
                    },
                ]
            })

        # {{{1
        if self.exp in ["screenshot", "both"]:
            base64_image = encode_image(obs["screenshot"])
            # The accessibility tree is only needed for the "both" experiment;
            # plain "screenshot" runs on the image alone
            linearized_accessibility_tree = linearize_accessibility_tree(
                accessibility_tree=obs["accessibility_tree"]) if self.exp == "both" else None

            if self.exp == "both":
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": linearized_accessibility_tree
                })
            else:
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": None
                })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
                        if self.exp == "screenshot"
                        else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        elif self.exp == "a11y_tree":
            linearized_accessibility_tree = linearize_accessibility_tree(
                accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
                "screenshot": None,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    }
                ]
            })
        elif self.exp == "som":
            # Draw Set-of-Mark tags on the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
            linearized_accessibility_tree = linearize_accessibility_tree(
                accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        elif self.exp == "seeact":
            # Draw Set-of-Mark tags on the screenshot
            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
            base64_image = encode_image(tagged_screenshot)
            linearized_accessibility_tree = linearize_accessibility_tree(
                accessibility_tree=obs["accessibility_tree"])

            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        else:
            raise ValueError("Invalid experiment type: " + self.exp)  # 1}}}

with open("messages.json", "w") as f:
|
|
f.write(json.dumps(messages, indent=4))
|
|
|
|
response = self.call_llm({
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"max_tokens": self.max_tokens
|
|
})
|
|
|
|
logger.debug("RESPONSE: %s", response)
|
|
|
|
if self.exp == "seeact":
|
|
messages.append({
|
|
"role": "assistant",
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": response
|
|
}
|
|
]
|
|
})
|
|
|
|
messages.append({
|
|
"role": "user",
|
|
"content": [
|
|
{
|
|
"type": "text",
|
|
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
|
|
ACTION_GROUNDING_PROMPT_SEEACT)
|
|
}
|
|
]
|
|
})
|
|
|
|
response = self.call_llm({
|
|
"model": self.model,
|
|
"messages": messages,
|
|
"max_tokens": self.max_tokens
|
|
})
|
|
print(response)
|
|
|
|
        try:
            actions = self.parse_actions(response, masks)
            self.thoughts.append(response)
        except Exception as e:
            logger.error("Failed to parse action from response: %s", e)
            actions = None
            self.thoughts.append("")

        return actions

    @backoff.on_exception(
        backoff.expo,
        (APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
        max_tries=3
    )
    def call_llm(self, payload):
        """
        POST the chat-completions payload ({"model": ..., "messages": ...,
        "max_tokens": ...}) to the OpenAI API and return the message content,
        or "" on failure.
        """
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=self.headers,
            json=payload
        )

        if response.status_code != 200:
            if response.json()['error']['code'] == "context_length_exceeded":
                logger.warning("Context length exceeded. Retrying with a smaller context.")
                # Keep only the most recent message (note: this also drops the
                # system prompt)
                payload["messages"] = payload["messages"][-1:]
                retry_response = requests.post(
                    "https://api.openai.com/v1/chat/completions",
                    headers=self.headers,
                    json=payload
                )
                if retry_response.status_code != 200:
                    logger.error("Failed to call LLM: %s", retry_response.text)
                    return ""
                # Return the retried completion instead of falling through to
                # the failure path below
                return retry_response.json()['choices'][0]['message']['content']

            logger.error("Failed to call LLM: %s", response.text)
            return ""
        else:
            return response.json()['choices'][0]['message']['content']

    def parse_actions(self, response: str, masks=None):
        if self.exp in ["screenshot", "a11y_tree", "both"]:
            # Parse actions directly from the response
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_string(response)
            else:
                raise ValueError("Invalid action space: " + self.action_space)
        elif self.exp in ["som", "seeact"]:
            # Resolve Set-of-Mark tags to coordinates before parsing
            if self.action_space == "computer_13":
                raise ValueError("Invalid action space: " + self.action_space)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_som_string(response, masks)
            else:
                raise ValueError("Invalid action space: " + self.action_space)
        else:
            raise ValueError("Invalid experiment type: " + self.exp)

        self.actions.append(actions)
        return actions
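

if __name__ == "__main__":
    # A minimal smoke test of the pure parsing helpers (no API calls; the model
    # outputs below are hypothetical examples of well-formed responses)
    print(parse_code_from_string("```python\npyautogui.click(100, 200)\n```"))
    print(parse_actions_from_string('```json\n{"action_type": "CLICK", "x": 10, "y": 20}\n```'))
    print(parse_code_from_som_string("```python\npyautogui.click(tag#1)\n```", [(10, 20, 100, 40)]))

    # Driving the full agent requires an OpenAI API key and a desktop
    # environment that supplies observations, along the lines of (hypothetical
    # paths and task):
    #
    #     agent = GPT4v_Agent(
    #         api_key=os.environ["OPENAI_API_KEY"],
    #         instruction="Open a terminal and list the home directory",
    #         action_space="pyautogui",
    #         exp="screenshot",
    #     )
    #     actions = agent.predict({"screenshot": "/path/to/screenshot.png"})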