Refactor baselines code implementations
This commit is contained in:
@@ -2,7 +2,6 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
@@ -54,9 +53,9 @@ def tag_screenshot(screenshot, accessibility_tree):
|
||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||
nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
# Make tag screenshot
|
||||
marks = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
|
||||
return marks, tagged_screenshot_file_path
|
||||
return marks, drew_nodes, tagged_screenshot_file_path
|
||||
|
||||
|
||||
def parse_actions_from_string(input_string):
|
||||
@@ -123,11 +122,18 @@ def parse_code_from_string(input_string):
|
||||
|
||||
|
||||
def parse_code_from_som_string(input_string, masks):
    """Replace set-of-mark tags in a response with coordinates, then parse it.

    Each entry of ``masks`` is unpacked as ``(x, y, w, h)``; the tag that
    refers to it is rewritten to the centre point ``"<cx>, <cy>"`` of that box
    before the text is handed to ``parse_code_from_string``.

    Args:
        input_string: Raw response text that may contain ``tag#<n>`` markers.
        masks: Iterable of ``(x, y, w, h)`` boxes, in tag order.

    Returns:
        Whatever ``parse_code_from_string`` returns for the rewritten text.
    """
    # Build all (tag, coordinates) pairs up front so substitution order can be
    # controlled independently of the order of ``masks``.
    # NOTE(review): tags are numbered from 1 here (mask 0 -> "tag#1"), which
    # presumably matches the labels drawn on the screenshot — confirm against
    # the bounding-box drawing helper.
    mappings = []
    for i, (x, y, w, h) in enumerate(masks):
        center = "{}, {}".format(int(x + w // 2), int(y + h // 2))
        mappings.append(("tag#" + str(i + 1), center))

    # Substitute the highest-numbered tags first: replacing "tag#1" before
    # "tag#10" would corrupt the longer tag, because plain str.replace matches
    # the shorter tag as a prefix of the longer one.
    for tag, center in reversed(mappings):
        input_string = input_string.replace(tag, center)

    return parse_code_from_string(input_string)
|
||||
|
||||
|
||||
class GPT4v_Agent:
|
||||
@@ -136,7 +142,7 @@ class GPT4v_Agent:
|
||||
api_key,
|
||||
instruction,
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=300,
|
||||
max_tokens=500,
|
||||
action_space="computer_13",
|
||||
exp="screenshot_a11y_tree"
|
||||
# exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"]
|
||||
@@ -147,6 +153,7 @@ class GPT4v_Agent:
|
||||
self.max_tokens = max_tokens
|
||||
self.action_space = action_space
|
||||
self.exp = exp
|
||||
self.max_trajectory_length = 3
|
||||
|
||||
self.headers = {
|
||||
"Content-Type": "application/json",
|
||||
@@ -194,8 +201,8 @@ class GPT4v_Agent:
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + exp)
|
||||
|
||||
self.system_message = (self.system_message +
|
||||
"\nHere is the instruction for the task: {}".format(self.instruction))
|
||||
self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(
|
||||
self.instruction)
|
||||
|
||||
def predict(self, obs: Dict) -> List:
|
||||
"""
|
||||
@@ -204,28 +211,111 @@ class GPT4v_Agent:
|
||||
|
||||
# Prepare the payload for the API call
|
||||
messages = []
|
||||
|
||||
if len(self.actions) > 0:
|
||||
system_message = self.system_message + "\nHere are the actions you have done so far:\n" + "\n->\n".join(
|
||||
self.actions)
|
||||
else:
|
||||
system_message = self.system_message
|
||||
masks = None
|
||||
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": system_message
|
||||
"text": self.system_message
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
masks = None
|
||||
# Append trajectory
|
||||
assert len(self.observations) == len(self.actions), "The number of observations and actions should be the same."
|
||||
|
||||
if len(self.observations) > self.max_trajectory_length:
|
||||
_observations = self.observations[-self.max_trajectory_length:]
|
||||
_actions = self.actions[-self.max_trajectory_length:]
|
||||
else:
|
||||
_observations = self.observations
|
||||
_actions = self.actions
|
||||
|
||||
for previous_obs, previous_action in zip(_observations, _actions):
|
||||
|
||||
if self.exp in ["both", "som", "seeact"]:
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the info from the tagged screenshot as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
_linearized_accessibility_tree)
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "screenshot":
|
||||
_screenshot = previous_obs["screenshot"]
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{_screenshot}",
|
||||
"detail": "high"
|
||||
}
|
||||
}
|
||||
]
|
||||
})
|
||||
elif self.exp == "a11y_tree":
|
||||
_linearized_accessibility_tree = previous_obs["accessibility_tree"]
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
|
||||
_linearized_accessibility_tree)
|
||||
}
|
||||
]
|
||||
})
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + self.exp)
|
||||
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "\n".join(previous_action) if len(previous_action) > 0 else "No valid action"
|
||||
},
|
||||
]
|
||||
})
|
||||
|
||||
if self.exp in ["screenshot", "both"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
if self.exp == "both":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
})
|
||||
else:
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": None
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -247,6 +337,12 @@ class GPT4v_Agent:
|
||||
})
|
||||
elif self.exp == "a11y_tree":
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": None,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -259,11 +355,15 @@ class GPT4v_Agent:
|
||||
})
|
||||
elif self.exp == "som":
|
||||
# Add som to the screenshot
|
||||
masks, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
|
||||
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -288,6 +388,11 @@ class GPT4v_Agent:
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
})
|
||||
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
@@ -307,6 +412,9 @@ class GPT4v_Agent:
|
||||
else:
|
||||
raise ValueError("Invalid experiment type: " + self.exp)
|
||||
|
||||
with open("messages.json", "w") as f:
|
||||
f.write(json.dumps(messages, indent=4))
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
@@ -354,20 +462,17 @@ class GPT4v_Agent:
|
||||
(APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError),
|
||||
)
|
||||
def call_llm(self, payload):
    """Post a chat-completions request and return the reply text.

    Args:
        payload: JSON-serialisable request body sent to the endpoint.

    Returns:
        The ``content`` of the first choice's message, or an empty string
        when the HTTP status is not 200.
    """
    # Exactly one request per call. There is deliberately no blanket
    # try/except or inner retry loop: a bare ``except`` would also swallow
    # KeyboardInterrupt and spin forever, and it would hide errors from any
    # retry decorator wrapping this method (presumably the one whose
    # exception tuple precedes this definition — confirm).
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=self.headers,
        json=payload
    )

    # Check the status before touching the body: error responses do not
    # carry a ``choices`` list, so indexing into them would raise.
    if response.status_code != 200:
        print("Failed to call LLM: " + response.text)
        return ""

    return response.json()['choices'][0]['message']['content']
|
||||
|
||||
def parse_actions(self, response: str, masks=None):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user