Initialize all baselines: screenshot, a11y tree, both, SoM, SeeAct

This commit is contained in:
Timothyxxx
2024-01-20 00:13:46 +08:00
parent 46bd3386dd
commit 09f3e776ae
14 changed files with 2588 additions and 1208 deletions

View File

@@ -1,14 +1,27 @@
import base64
import json
import os
import re
import time
import uuid
from typing import Dict, List
import backoff
import requests
from openai.error import (
APIConnectionError,
APIError,
RateLimitError,
ServiceUnavailableError,
InvalidRequestError
)
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \
SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT
# Function to encode the image
@@ -17,6 +30,35 @@ def encode_image(image_path):
return base64.b64encode(image_file.read()).decode('utf-8')
def linearize_accessibility_tree(accessibility_tree):
leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(leaf_nodes)
linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
# Linearize the accessibility tree nodes into a table format
for node in filtered_nodes:
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
linearized_accessibility_tree += node.attrib.get(
'{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
return linearized_accessibility_tree
def tag_screenshot(screenshot, accessibility_tree):
# Creat a tmp file to store the screenshot in random name
uuid_str = str(uuid.uuid4())
os.makedirs("tmp/images", exist_ok=True)
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
# Make tag screenshot
marks = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
return marks, tagged_screenshot_file_path
def parse_actions_from_string(input_string):
# Search for a JSON string within the input string
actions = []
@@ -61,124 +103,295 @@ def parse_code_from_string(input_string):
# so the code inside backticks can span multiple lines.
# matches now contains all the captured code snippets
return matches
codes = []
for match in matches:
match = match.strip()
commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands
if match in commands:
codes.append(match.strip())
elif match.split('\n')[-1] in commands:
if len(match.split('\n')) > 1:
codes.append("\n".join(match.split('\n')[:-1]))
codes.append(match.split('\n')[-1])
else:
codes.append(match)
return codes
def parse_code_from_som_string(input_string, masks):
for i, mask in enumerate(masks):
x, y, w, h = mask
input_string = input_string.replace("tag#" + str(i), "{}, {}".format(int(x + w // 2), int(y + h // 2)))
return parse_code_from_string(input_string)
class GPT4v_Agent:
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13", add_a11y_tree=False):
def __init__(
self,
api_key,
instruction,
model="gpt-4-vision-preview",
max_tokens=300,
action_space="computer_13",
exp="screenshot_a11y_tree"
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
):
self.instruction = instruction
self.model = model
self.max_tokens = max_tokens
self.action_space = action_space
self.add_a11y_tree = add_a11y_tree
self.exp = exp
self.headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
self.trajectory = [
{
"role": "system",
"content": [
{
"type": "text",
"text": {
"computer_13": SYS_PROMPT_ACTION,
"pyautogui": SYS_PROMPT_CODE
}[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
},
]
}
]
self.actions = []
self.observations = []
if exp == "screenshot":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "both":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "som":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
elif exp == "seeact":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_SEEACT
else:
raise ValueError("Invalid action space: " + action_space)
else:
raise ValueError("Invalid experiment type: " + exp)
self.system_message = (self.system_message +
"\nHere is the instruction for the task: {}".format(self.instruction))
def predict(self, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
base64_image = encode_image(obs["screenshot"])
accessibility_tree = obs["accessibility_tree"]
leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(leaf_nodes)
# Prepare the payload for the API call
messages = []
linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
# Linearize the accessibility tree nodes into a table format
if len(self.actions) > 0:
system_message = self.system_message + "\nHere are the actions you have done so far:\n" + "\n->\n".join(
self.actions)
else:
system_message = self.system_message
for node in filtered_nodes:
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
linearized_accessibility_tree += node.attrib.get(
'{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
self.trajectory.append({
"role": "user",
messages.append({
"role": "system",
"content": [
{
"type": "text",
"text": "What's the next step that you will do to help with the task?" if not self.add_a11y_tree
else "And given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(linearized_accessibility_tree)
"text": system_message
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
}
]
})
traj_to_show = []
for i in range(len(self.trajectory)):
traj_to_show.append(self.trajectory[i]["content"][0]["text"])
if len(self.trajectory[i]["content"]) > 1:
traj_to_show.append("screenshot_obs")
masks = None
print("Trajectory:", traj_to_show)
if self.exp in ["screenshot", "both"]:
base64_image = encode_image(obs["screenshot"])
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
if self.exp == "screenshot"
else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
}
]
})
elif self.exp == "a11y_tree":
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
}
]
})
elif self.exp == "som":
# Add som to the screenshot
masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
payload = {
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
}
]
})
elif self.exp == "seeact":
# Add som to the screenshot
masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree)
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}",
"detail": "high"
}
}
]
})
else:
raise ValueError("Invalid experiment type: " + self.exp)
response = self.call_llm({
"model": self.model,
"messages": self.trajectory,
"messages": messages,
"max_tokens": self.max_tokens
}
})
if self.exp == "seeact":
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": response
}
]
})
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": "{}\n\nWhat's the next step that you will do to help with the task?".format(
ACTION_GROUNDING_PROMPT_SEEACT)
}
]
})
response = self.call_llm({
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens
})
try:
actions = self.parse_actions(response, masks)
except Exception as e:
print("Failed to parse action from response", e)
actions = None
return actions
@backoff.on_exception(
backoff.expo,
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
)
def call_llm(self, payload):
while True:
try:
response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers,
json=payload)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
headers=self.headers,
json=payload
)
break
except:
print("Failed to generate response, retrying...")
time.sleep(5)
pass
try:
actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
except:
print("Failed to parse action from response:", response.json())
actions = None
return actions
return response.json()['choices'][0]['message']['content']
def parse_actions(self, response: str):
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
elif self.action_space == "pyautogui":
actions = parse_code_from_string(response)
else:
raise ValueError("Invalid action space: " + self.action_space)
def parse_actions(self, response: str, masks=None):
# add action into the trajectory
self.trajectory.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": response
},
]
})
if self.exp in ["screenshot", "a11y_tree", "both"]:
# parse from the response
if self.action_space == "computer_13":
actions = parse_actions_from_string(response)
elif self.action_space == "pyautogui":
actions = parse_code_from_string(response)
else:
raise ValueError("Invalid action space: " + self.action_space)
return actions
self.actions.append(actions)
return actions
elif self.exp in ["som", "seeact"]:
# parse from the response
if self.action_space == "computer_13":
raise ValueError("Invalid action space: " + self.action_space)
elif self.action_space == "pyautogui":
actions = parse_code_from_som_string(response, masks)
else:
raise ValueError("Invalid action space: " + self.action_space)
self.actions.append(actions)
return actions