196 lines
6.9 KiB
Python
196 lines
6.9 KiB
Python
import base64
|
|
import json
|
|
import re
|
|
import time
|
|
from typing import Dict, List
|
|
|
|
import requests
|
|
|
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
|
|
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
|
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
|
|
|
|
|
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as a base64 text string.

    The file is opened in binary mode; the base64 output is decoded as UTF-8
    so the result can be embedded directly in JSON payloads.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
|
|
|
|
|
|
def parse_actions_from_string(input_string):
    """Extract JSON action objects from a model response.

    The response is searched in this order, stopping at the first pattern that
    finds anything:

      1. fenced blocks tagged ``json`` (```json ... ```),
      2. untagged fenced blocks (``` ... ```),
      3. otherwise the whole string is decoded as a single JSON document.

    Every block found by the winning pattern is decoded with ``json.loads``.

    Args:
        input_string: raw text returned by the model.

    Returns:
        A list of decoded JSON values, one per fenced block (or a one-element
        list when the whole string is JSON).

    Raises:
        ValueError: if any fenced block, or the bare string, is not valid JSON.
            (Previously a fenced-block failure was *returned* as an error
            string, which callers then mistook for a list of actions.)
    """
    # Tagged fences win over untagged ones; both require whitespace between
    # the fence markers and the payload.
    fence_patterns = (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```')
    for pattern in fence_patterns:
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            try:
                return [json.loads(match) for match in matches]
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse JSON: {e}") from e

    # No fenced block at all: the response itself may be bare JSON.
    try:
        return [json.loads(input_string)]
    except json.JSONDecodeError as e:
        raise ValueError("Invalid response format: " + input_string) from e
|
|
|
|
|
|
def parse_code_from_string(input_string, commands=('WAIT', 'DONE', 'FAIL')):
    """Extract code snippets from triple-backtick fences in a model response.

    Both ```code``` and ```python code``` fences are recognized: a word
    followed by whitespace directly after the opening fence is treated as a
    language tag and dropped.

    Control commands are emitted as separate entries:

    * a fence containing only a command yields just that command;
    * a fence whose last line is a command yields the preceding code
      (with trailing whitespace removed) followed by the command.

    Args:
        input_string: raw text returned by the model.
        commands: control keywords to split out. The default preserves the
            original WAIT/DONE/FAIL set; pass a larger collection as new
            commands are introduced.

    Returns:
        List of code strings and command keywords, in order of appearance.
        Fences that contain only whitespace are skipped.
    """
    # Non-greedy capture so consecutive fences are matched separately;
    # DOTALL lets a snippet span multiple lines.
    pattern = r"```(?:\w+\s+)?(.*?)```"
    command_set = frozenset(commands)

    codes = []
    for raw_snippet in re.findall(pattern, input_string, re.DOTALL):
        snippet = raw_snippet.strip()
        if not snippet:
            # An empty fence carries no code to execute.
            continue
        if snippet in command_set:
            codes.append(snippet)
            continue

        # Split off the final line only; `sep` is empty for one-line snippets.
        body, sep, last_line = snippet.rpartition('\n')
        if sep and last_line in command_set:
            codes.append(body.rstrip())
            codes.append(last_line)
        else:
            codes.append(snippet)

    return codes
|
|
|
|
|
|
class GPT4_Agent:
    """Agent that asks an OpenAI chat-completions model for the next desktop action.

    The whole conversation is kept in ``self.trajectory`` and sent with every
    request. It starts with a system prompt (chosen by ``action_space``) that
    includes the task instruction. It then gets one user turn per observation
    passed to ``predict`` and one assistant turn per reply handled by
    ``parse_actions``.
    """

    def __init__(self, api_key, instruction, model="gpt-4-1106-preview", max_tokens=600, action_space="computer_13"):
        """Set up request headers and seed the trajectory with the system prompt.

        Args:
            api_key: API key sent as ``Authorization: Bearer <api_key>``.
            instruction: task description appended to the system prompt.
            model: model name sent in every request payload.
            max_tokens: ``max_tokens`` value sent in every request payload.
            action_space: ``"computer_13"`` or ``"pyautogui"``; selects the
                system prompt here and the reply parser in ``parse_actions``.

        Raises:
            KeyError: if ``action_space`` is not one of the two values above.
        """
        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens
        self.action_space = action_space

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        # Unknown action spaces fail fast here with a KeyError.
        system_prompt = {
            "computer_13": SYS_PROMPT_ACTION,
            "pyautogui": SYS_PROMPT_CODE
        }[action_space]

        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": system_prompt + "\nHere is the instruction for the task: {}".format(self.instruction)
                    },
                ]
            }
        ]

    @staticmethod
    def _linearize_accessibility_tree(nodes):
        """Render accessibility-tree nodes as a tab-separated table.

        Output is a header row ``tag, text, position, size`` followed by one row
        per node: its tag, its ``name`` attribute, and its AT-SPI
        ``screencoord`` and ``size`` attributes. A missing attribute becomes an
        empty cell instead of raising ``TypeError``.

        Args:
            nodes: iterable of objects exposing ``.tag`` (str) and ``.attrib``
                (mapping). These are presumably ElementTree elements; confirm
                against ``filter_nodes``.
        """
        component_ns = "{uri:deskat:component.at-spi.gnome.org}"
        rows = ["tag\ttext\tposition\tsize\n"]
        for node in nodes:
            cells = (
                node.tag,
                node.attrib.get('name', ''),
                node.attrib.get(component_ns + 'screencoord', ''),
                node.attrib.get(component_ns + 'size', ''),
            )
            rows.append("\t".join(cells) + "\n")
        # join() avoids repeated string concatenation over large trees.
        return "".join(rows)

    def predict(self, obs: Dict) -> List:
        """Predict the next action(s) based on the current observation.

        Steps:

        1. Linearize the leaf nodes of ``obs["accessibility_tree"]``.
        2. Append the result to the trajectory as a user turn.
        3. Post the full trajectory to the chat-completions endpoint, retrying
           every 5 s on network-level failures.
        4. Parse the reply with ``parse_actions``.

        Args:
            obs: observation dict; only ``obs["accessibility_tree"]`` is read.

        Returns:
            The parsed actions, or ``None`` when the response is not JSON, has
            no ``choices`` (e.g. an API error payload), or cannot be parsed.
        """
        leaf_nodes = find_leaf_nodes(obs["accessibility_tree"])
        filtered_nodes = filter_nodes(leaf_nodes)
        linearized_accessibility_tree = self._linearize_accessibility_tree(filtered_nodes)

        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                        linearized_accessibility_tree)
                }
            ]
        })

        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }

        # Retry only transport failures (connection errors, timeouts, ...).
        # A bare except here would also trap KeyboardInterrupt and genuine
        # bugs, turning them into an endless retry loop.
        while True:
            try:
                response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers,
                                         json=payload)
                break
            except requests.exceptions.RequestException:
                print("Failed to generate response, retrying...")
                time.sleep(5)

        try:
            response_json = response.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML error page). Report the raw text
            # rather than calling .json() again inside the handler.
            print("Failed to parse action from response:", response.text)
            return None

        try:
            actions = self.parse_actions(response_json['choices'][0]['message']['content'])
        except Exception:
            # Best-effort boundary: API error payloads (no 'choices') and
            # unparseable replies are reported and mapped to None.
            print("Failed to parse action from response:", response_json)
            actions = None

        return actions

    def parse_actions(self, response: str):
        """Record the assistant reply in the trajectory and parse it into actions.

        The reply is appended before parsing. This keeps user/assistant turns
        alternating in the trajectory even when the reply cannot be parsed.

        Args:
            response: assistant message text.

        Returns:
            For ``computer_13``, the result of ``parse_actions_from_string``;
            for ``pyautogui``, the result of ``parse_code_from_string``.

        Raises:
            ValueError: if ``self.action_space`` is unsupported (the trajectory
                is left unchanged in that case). Errors raised by the selected
                parser propagate to the caller.
        """
        if self.action_space == "computer_13":
            parser = parse_actions_from_string
        elif self.action_space == "pyautogui":
            parser = parse_code_from_string
        else:
            raise ValueError("Invalid action space: " + self.action_space)

        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        return parser(response)
|