Refactor baselines code implementations

Timothyxxx
2024-01-20 18:55:21 +08:00
parent 09f3e776ae
commit f88331416c
7 changed files with 204 additions and 65 deletions


@@ -2,7 +2,6 @@ import base64
 import json
 import os
 import re
-import time
 import uuid

 from typing import Dict, List
@@ -54,9 +53,9 @@ def tag_screenshot(screenshot, accessibility_tree):
     tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
     nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
     # Make tag screenshot
-    marks = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
+    marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)

-    return marks, tagged_screenshot_file_path
+    return marks, drew_nodes, tagged_screenshot_file_path


 def parse_actions_from_string(input_string):
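tag_screenshot now propagates a second value, drew_nodes, out of draw_bounding_boxes. That function is not shown in this commit, but a plausible reading is that drawing can skip candidate nodes, so the caller needs the list of nodes that were actually tagged in order to keep tag indices and nodes aligned. A hypothetical sketch (extract_box and the skip condition are illustrative, not from this diff):

# Hypothetical sketch of why draw_bounding_boxes would return the drawn nodes:
# if any candidate node is skipped while drawing (e.g., it has no usable box),
# the indices of `marks` only line up with the nodes that survived, so both
# lists must travel together.
def draw_bounding_boxes(nodes, screenshot, save_path):
    marks, drew_nodes = [], []
    for node in nodes:
        box = extract_box(node)  # hypothetical helper returning (x, y, w, h) or None
        if box is None:
            continue  # a skipped node would otherwise desynchronize tag indices
        marks.append(box)
        drew_nodes.append(node)
    # ... actual drawing onto `screenshot` and saving to `save_path` omitted ...
    return marks, drew_nodes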
@@ -123,11 +122,18 @@ def parse_code_from_string(input_string):

 def parse_code_from_som_string(input_string, masks):
     # parse the output string by masks
+    mappings = []
     for i, mask in enumerate(masks):
         x, y, w, h = mask
-        input_string = input_string.replace("tag#" + str(i), "{}, {}".format(int(x + w // 2), int(y + h // 2)))
+        mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2))))
-    return parse_code_from_string(input_string)
+
+    # reverse the mappings
+    for mapping in mappings[::-1]:
+        input_string = input_string.replace(mapping[0], mapping[1])
+
+    actions = parse_code_from_string(input_string)
+    return actions


 class GPT4v_Agent:
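The rewritten parse_code_from_som_string fixes the tag index (the drawn tags evidently start at 1, not 0) and defers the substitutions so they can be applied from the highest tag number down. Applying them in ascending order corrupts multi-digit tags; a small standalone demo (dummy masks, illustrative string) shows the hazard:

# Standalone demo of the substring hazard that iterating mappings[::-1] avoids.
masks = [(10, 10, 20, 20)] * 12  # twelve dummy (x, y, w, h) boxes, all centered at (20, 20)

def naive_substitute(s, masks):
    # ascending order: "tag#1" is replaced before "tag#12"
    for i, (x, y, w, h) in enumerate(masks):
        s = s.replace("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2)))
    return s

print(naive_substitute("click(tag#12)", masks))
# -> "click(20, 202)": "tag#1" matched inside "tag#12".
# Replacing from "tag#12" downward, as the new code does, yields "click(20, 20)".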
@@ -136,7 +142,7 @@ class GPT4v_Agent:
             api_key,
             instruction,
             model="gpt-4-vision-preview",
-            max_tokens=300,
+            max_tokens=500,
             action_space="computer_13",
             exp="screenshot_a11y_tree"
             # exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
@@ -147,6 +153,7 @@ class GPT4v_Agent:
         self.max_tokens = max_tokens
         self.action_space = action_space
         self.exp = exp
+        self.max_trajectory_length = 3

         self.headers = {
             "Content-Type": "application/json",
@@ -194,8 +201,8 @@ class GPT4v_Agent:
         else:
             raise ValueError("Invalid experiment type: " + exp)

-        self.system_message = (self.system_message +
-                               "\nHere is the instruction for the task: {}".format(self.instruction))
+        self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
+            self.instruction)

     def predict(self, obs: Dict) -> List:
         """
@@ -204,28 +211,111 @@ class GPT4v_Agent:
         # Prepare the payload for the API call
         messages = []

-        if len(self.actions) > 0:
-            system_message = self.system_message + "\nHere are the actions you have done so far:\n" + "\n->\n".join(
-                self.actions)
-        else:
-            system_message = self.system_message
-
-        masks = None
         messages.append({
             "role": "system",
             "content": [
                 {
                     "type": "text",
-                    "text": system_message
+                    "text": self.system_message
                 },
             ]
         })

+        masks = None
+
+        # Append trajectory
+        assert len(self.observations) == len(self.actions), "The number of observations and actions should be the same."
+        if len(self.observations) > self.max_trajectory_length:
+            _observations = self.observations[-self.max_trajectory_length:]
+            _actions = self.actions[-self.max_trajectory_length:]
+        else:
+            _observations = self.observations
+            _actions = self.actions
+
+        for previous_obs, previous_action in zip(_observations, _actions):
+            if self.exp in ["both", "som", "seeact"]:
+                _screenshot = previous_obs["screenshot"]
+                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
+
+                messages.append({
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
+                                _linearized_accessibility_tree)
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{_screenshot}",
+                                "detail": "high"
+                            }
+                        }
+                    ]
+                })
+            elif self.exp == "screenshot":
+                _screenshot = previous_obs["screenshot"]
+
+                messages.append({
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
+                        },
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{_screenshot}",
+                                "detail": "high"
+                            }
+                        }
+                    ]
+                })
+            elif self.exp == "a11y_tree":
+                _linearized_accessibility_tree = previous_obs["accessibility_tree"]
+
+                messages.append({
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
+                                _linearized_accessibility_tree)
+                        }
+                    ]
+                })
+            else:
+                raise ValueError("Invalid experiment type: " + self.exp)
+
+            messages.append({
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "text",
+                        "text": "\n".join(previous_action) if len(previous_action) > 0 else "No valid action"
+                    },
+                ]
+            })

         if self.exp in ["screenshot", "both"]:
             base64_image = encode_image(obs["screenshot"])
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

+            if self.exp == "both":
+                self.observations.append({
+                    "screenshot": base64_image,
+                    "accessibility_tree": linearized_accessibility_tree
+                })
+            else:
+                self.observations.append({
+                    "screenshot": base64_image,
+                    "accessibility_tree": None
+                })

             messages.append({
                 "role": "user",
                 "content": [
@@ -247,6 +337,12 @@ class GPT4v_Agent:
             })
         elif self.exp == "a11y_tree":
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

+            self.observations.append({
+                "screenshot": None,
+                "accessibility_tree": linearized_accessibility_tree
+            })

             messages.append({
                 "role": "user",
                 "content": [
@@ -259,11 +355,15 @@ class GPT4v_Agent:
             })
         elif self.exp == "som":
             # Add som to the screenshot
-            masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
+            masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
             base64_image = encode_image(tagged_screenshot)
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

+            self.observations.append({
+                "screenshot": base64_image,
+                "accessibility_tree": linearized_accessibility_tree
+            })

             messages.append({
                 "role": "user",
                 "content": [
@@ -288,6 +388,11 @@ class GPT4v_Agent:
             base64_image = encode_image(tagged_screenshot)
             linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])

+            self.observations.append({
+                "screenshot": base64_image,
+                "accessibility_tree": linearized_accessibility_tree
+            })

             messages.append({
                 "role": "user",
                 "content": [
@@ -307,6 +412,9 @@ class GPT4v_Agent:
         else:
             raise ValueError("Invalid experiment type: " + self.exp)

+        with open("messages.json", "w") as f:
+            f.write(json.dumps(messages, indent=4))
+
         response = self.call_llm({
             "model": self.model,
             "messages": messages,
@@ -354,20 +462,17 @@ class GPT4v_Agent:
         (APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
     )
     def call_llm(self, payload):
-        while True:
-            try:
-                response = requests.post(
-                    "https://api.openai.com/v1/chat/completions",
-                    headers=self.headers,
-                    json=payload
-                )
-                break
-            except:
-                print("Failed to generate response, retrying...")
-                time.sleep(5)
-                pass
+        response = requests.post(
+            "https://api.openai.com/v1/chat/completions",
+            headers=self.headers,
+            json=payload
+        )

-        return response.json()['choices'][0]['message']['content']
+        if response.status_code != 200:
+            print("Failed to call LLM: " + response.text)
+            return ""
+        else:
+            return response.json()['choices'][0]['message']['content']

     def parse_actions(self, response: str, masks=None):
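The manual while/try loop (and the import time it needed) becomes redundant because call_llm now sits under a retrying decorator built over the OpenAI error classes visible above; the decorator line itself is outside this hunk. A hedged sketch of the same pattern with the backoff library (decorator name, wait strategy, and max_tries are assumptions; only the exception tuple appears in the diff):

# Hedged sketch of decorator-based retrying; only the exception tuple is
# confirmed by the diff above, the rest is assumed.
import backoff
import requests

@backoff.on_exception(
    backoff.expo,  # exponential wait between attempts
    (requests.exceptions.ConnectionError, requests.exceptions.Timeout),
    max_tries=5,
)
def call_llm(payload, headers):
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
    )
    if response.status_code != 200:
        # Non-200 responses are surfaced as an empty string instead of raising,
        # mirroring the behavior introduced in this commit.
        print("Failed to call LLM: " + response.text)
        return ""
    return response.json()["choices"][0]["message"]["content"]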