import base64
import json
import re
import time
from typing import Dict, List

import requests

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE


def encode_image(image_path):
    """Encode an image file as a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def parse_actions_from_string(input_string):
    """Extract action dictionaries from a model response.

    Searches for JSON inside ```json ...``` fences first, then plain ``` ... ```
    fences, and finally tries to parse the whole string as JSON.
    """
    actions = []
    matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
    if matches:
        try:
            for match in matches:
                action_dict = json.loads(match)
                actions.append(action_dict)
            return actions
        except json.JSONDecodeError as e:
            return f"Failed to parse JSON: {e}"

    matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
    if matches:
        try:
            for match in matches:
                action_dict = json.loads(match)
                actions.append(action_dict)
            return actions
        except json.JSONDecodeError as e:
            return f"Failed to parse JSON: {e}"

    try:
        action_dict = json.loads(input_string)
        return [action_dict]
    except json.JSONDecodeError:
        raise ValueError("Invalid response format: " + input_string)


def parse_code_from_string(input_string):
    """Extract code snippets from a model response.

    The pattern matches both ```code``` and ```python code``` blocks and
    captures the code part with a non-greedy match; re.DOTALL lets the
    snippet span multiple lines.
    """
    pattern = r"```(?:\w+\s+)?(.*?)```"
    matches = re.findall(pattern, input_string, re.DOTALL)

    codes = []
    commands = ['WAIT', 'DONE', 'FAIL']  # FIXME: update this list when more commands are added
    for match in matches:
        match = match.strip()
        if match in commands:
            # The whole block is a single control command.
            codes.append(match)
        elif match.split('\n')[-1] in commands:
            # The block ends with a control command; return the code and the
            # command as separate items.
            lines = match.split('\n')
            if len(lines) > 1:
                codes.append("\n".join(lines[:-1]))
            codes.append(lines[-1])
        else:
            codes.append(match)

    return codes


class GPT4_Agent:
    def __init__(self, api_key, instruction, model="gpt-4-1106-preview", max_tokens=600,
                 action_space="computer_13"):
        self.instruction = instruction
        self.model = model
        self.max_tokens = max_tokens
        self.action_space = action_space

        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}"
        }

        self.trajectory = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": {
                            "computer_13": SYS_PROMPT_ACTION,
                            "pyautogui": SYS_PROMPT_CODE
                        }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
                    },
                ]
            }
        ]

    def predict(self, obs: Dict) -> List:
        """Predict the next action(s) based on the current observation."""
        accessibility_tree = obs["accessibility_tree"]
        leaf_nodes = find_leaf_nodes(accessibility_tree)
        filtered_nodes = filter_nodes(leaf_nodes)

        # Linearize the accessibility tree nodes into a tab-separated table.
        linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
        for node in filtered_nodes:
            linearized_accessibility_tree += node.tag + "\t"
            linearized_accessibility_tree += (node.attrib.get('name') or "") + "\t"
            linearized_accessibility_tree += (node.attrib.get(
                '{uri:deskat:component.at-spi.gnome.org}screencoord') or "") + "\t"
            linearized_accessibility_tree += (node.attrib.get(
                '{uri:deskat:component.at-spi.gnome.org}size') or "") + "\n"

        self.trajectory.append({
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                        linearized_accessibility_tree)
                }
            ]
        })

        # Compact view of the trajectory, useful for debugging and logging.
        traj_to_show = []
        for i in range(len(self.trajectory)):
            traj_to_show.append(self.trajectory[i]["content"][0]["text"])
            if len(self.trajectory[i]["content"]) > 1:
                traj_to_show.append("screenshot_obs")

        payload = {
            "model": self.model,
            "messages": self.trajectory,
            "max_tokens": self.max_tokens
        }

        while True:
            try:
                response = requests.post("https://api.openai.com/v1/chat/completions",
                                         headers=self.headers, json=payload)
                break
            except requests.RequestException:
                print("Failed to generate response, retrying...")
                time.sleep(5)

        try:
            actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
        except Exception:
            print("Failed to parse action from response:", response.json())
            actions = None

        return actions

    def parse_actions(self, response: str):
        # Parse actions from the raw model response according to the action space.
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            raise ValueError("Invalid action space: " + self.action_space)

        # Record the assistant response in the trajectory.
        self.trajectory.append({
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": response
                },
            ]
        })

        return actions