import base64 import json import os import re import uuid from typing import Dict, List import backoff import requests from openai.error import ( APIConnectionError, APIError, RateLimitError, ServiceUnavailableError, InvalidRequestError ) from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT import logging logger = logging.getLogger("desktopenv.agent") # Function to encode the image def encode_image(image_path): with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') def linearize_accessibility_tree(accessibility_tree): leaf_nodes = find_leaf_nodes(accessibility_tree) filtered_nodes = filter_nodes(leaf_nodes) linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" # Linearize the accessibility tree nodes into a table format for node in filtered_nodes: linearized_accessibility_tree += node.tag + "\t" linearized_accessibility_tree += node.attrib.get('name') + "\t" if node.text: linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(node.text.replace('"', '""'))) + "\t" elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper")\ and node.get("{uri:deskat:value.at-spi.gnome.org}value"): text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value") linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(text.replace('"', '""'))) + "\t" else: linearized_accessibility_tree += '""\t' linearized_accessibility_tree += node.attrib.get( '{uri:deskat:component.at-spi.gnome.org}screencoord', "") + "\t" linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n" return linearized_accessibility_tree def tag_screenshot(screenshot, accessibility_tree): # Creat a tmp file to store the screenshot in random name uuid_str = str(uuid.uuid4()) os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) # Make tag screenshot marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) return marks, drew_nodes, tagged_screenshot_file_path def parse_actions_from_string(input_string): if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: return [input_string.strip()] # Search for a JSON string within the input string actions = [] matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL) if matches: # Assuming there's only one match, parse the JSON string into a dictionary try: for match in matches: action_dict = json.loads(match) actions.append(action_dict) return actions except json.JSONDecodeError as e: return f"Failed to parse JSON: {e}" else: matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL) if matches: # Assuming there's only one match, parse the JSON string into a dictionary try: for match in matches: action_dict = json.loads(match) actions.append(action_dict) return actions except json.JSONDecodeError as e: return f"Failed to parse JSON: {e}" else: try: action_dict = json.loads(input_string) return [action_dict] except json.JSONDecodeError: raise ValueError("Invalid response format: " + input_string) def parse_code_from_string(input_string): if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: return [input_string.strip()] # This regular expression will match both ```code``` and ```python code``` # and capture the `code` part. It uses a non-greedy match for the content inside. pattern = r"```(?:\w+\s+)?(.*?)```" # Find all non-overlapping matches in the string matches = re.findall(pattern, input_string, re.DOTALL) # The regex above captures the content inside the triple backticks. # The `re.DOTALL` flag allows the dot `.` to match newline characters as well, # so the code inside backticks can span multiple lines. # matches now contains all the captured code snippets codes = [] for match in matches: match = match.strip() commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands if match in commands: codes.append(match.strip()) elif match.split('\n')[-1] in commands: if len(match.split('\n')) > 1: codes.append("\n".join(match.split('\n')[:-1])) codes.append(match.split('\n')[-1]) else: codes.append(match) return codes def parse_code_from_som_string(input_string, masks): # parse the output string by masks mappings = [] for i, mask in enumerate(masks): x, y, w, h = mask mappings.append(("tag#" + str(i + 1), "{}, {}".format(int(x + w // 2), int(y + h // 2)))) # reverse the mappings for mapping in mappings[::-1]: input_string = input_string.replace(mapping[0], mapping[1]) actions = parse_code_from_string(input_string) return actions class GPT4v_Agent: def __init__( self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=500, action_space="computer_13", exp="screenshot_a11y_tree" # exp can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] ): self.instruction = instruction self.model = model self.max_tokens = max_tokens self.action_space = action_space self.exp = exp self.max_trajectory_length = 3 self.headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" } self.thoughts = [] self.actions = [] self.observations = [] if exp == "screenshot": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif exp == "a11y_tree": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif exp == "both": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif exp == "som": if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG else: raise ValueError("Invalid action space: " + action_space) elif exp == "seeact": if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": self.system_message = SYS_PROMPT_SEEACT else: raise ValueError("Invalid action space: " + action_space) else: raise ValueError("Invalid experiment type: " + exp) self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( self.instruction) def predict(self, obs: Dict) -> List: """ Predict the next action(s) based on the current observation. """ # Prepare the payload for the API call messages = [] masks = None messages.append({ "role": "system", "content": [ { "type": "text", "text": self.system_message }, ] }) # Append trajectory assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts) \ , "The number of observations and actions should be the same." if len(self.observations) > self.max_trajectory_length: _observations = self.observations[-self.max_trajectory_length:] _actions = self.actions[-self.max_trajectory_length:] _thoughts = self.thoughts[-self.max_trajectory_length:] else: _observations = self.observations _actions = self.actions _thoughts = self.thoughts for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): # {{{1 if self.exp == "both": _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{_screenshot}", "detail": "high" } } ] }) elif self.exp in ["som", "seeact"]: _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{_screenshot}", "detail": "high" } } ] }) elif self.exp == "screenshot": _screenshot = previous_obs["screenshot"] messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{_screenshot}", "detail": "high" } } ] }) elif self.exp == "a11y_tree": _linearized_accessibility_tree = previous_obs["accessibility_tree"] messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) } ] }) else: raise ValueError("Invalid experiment type: " + self.exp) # 1}}} messages.append({ "role": "assistant", "content": [ { "type": "text", "text": previous_thought.strip() if len(previous_thought) > 0 else "No valid action" }, ] }) # {{{1 if self.exp in ["screenshot", "both"]: base64_image = encode_image(obs["screenshot"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.exp == "both": self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) else: self.observations.append({ "screenshot": base64_image, "accessibility_tree": None }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot as below. What's the next step that you will do to help with the task?" if self.exp == "screenshot" else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{base64_image}", "detail": "high" } } ] }) elif self.exp == "a11y_tree": linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": None, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) } ] }) elif self.exp == "som": # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{base64_image}", "detail": "high" } } ] }) elif self.exp == "seeact": # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/jpeg;base64,{base64_image}", "detail": "high" } } ] }) else: raise ValueError("Invalid experiment type: " + self.exp) # 1}}} with open("messages.json", "w") as f: f.write(json.dumps(messages, indent=4)) try: response = self.call_llm({ "model": self.model, "messages": messages, "max_tokens": self.max_tokens }) except: response = "" logger.debug("RESPONSE: %s", response) # {{{ if self.exp == "seeact": messages.append({ "role": "assistant", "content": [ { "type": "text", "text": response } ] }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "{}\n\nWhat's the next step that you will do to help with the task?".format( ACTION_GROUNDING_PROMPT_SEEACT) } ] }) response = self.call_llm({ "model": self.model, "messages": messages, "max_tokens": self.max_tokens }) print(response) try: actions = self.parse_actions(response, masks) self.thoughts.append(response) except Exception as e: print("Failed to parse action from response", e) actions = None self.thoughts.append("") # }}} return actions @backoff.on_exception( backoff.expo, (APIError, RateLimitError, APIConnectionError, ServiceUnavailableError, InvalidRequestError), max_tries=3 ) def call_llm(self, payload): response = requests.post( "https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload, timeout=20 ) if response.status_code != 200: if response.json()['error']['code'] == "context_length_exceeded": print("Context length exceeded. Retrying with a smaller context.") payload["messages"] = payload["messages"][-1:] retry_response = requests.post( "https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload ) if retry_response.status_code != 200: print("Failed to call LLM: " + retry_response.text) return "" print("Failed to call LLM: " + response.text) return "" else: return response.json()['choices'][0]['message']['content'] def parse_actions(self, response: str, masks=None): if self.exp in ["screenshot", "a11y_tree", "both"]: # parse from the response if self.action_space == "computer_13": actions = parse_actions_from_string(response) elif self.action_space == "pyautogui": actions = parse_code_from_string(response) else: raise ValueError("Invalid action space: " + self.action_space) self.actions.append(actions) return actions elif self.exp in ["som", "seeact"]: # parse from the response if self.action_space == "computer_13": raise ValueError("Invalid action space: " + self.action_space) elif self.action_space == "pyautogui": actions = parse_code_from_som_string(response, masks) else: raise ValueError("Invalid action space: " + self.action_space) self.actions.append(actions) return actions