import base64 import json import logging import os import re import time import uuid import xml.etree.ElementTree as ET from http import HTTPStatus from io import BytesIO from typing import Dict, List import backoff import dashscope import google.generativeai as genai import requests from PIL import Image from vertexai.preview.generative_models import ( HarmBlockThreshold, HarmCategory, Image, ) from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT # todo: cross-check with visualwebarena logger = logging.getLogger("desktopenv.agent") # Function to encode the image def encode_image(image_path): with open(image_path, "rb") as image_file: return base64.b64encode(image_file.read()).decode('utf-8') def linearize_accessibility_tree(accessibility_tree): # leaf_nodes = find_leaf_nodes(accessibility_tree) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" # Linearize the accessibility tree nodes into a table format for node in filtered_nodes: linearized_accessibility_tree += node.tag + "\t" linearized_accessibility_tree += node.attrib.get('name') + "\t" if node.text: linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format( node.text.replace('"', '""'))) + "\t" elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \ and node.get("{uri:deskat:value.at-spi.gnome.org}value"): text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value") linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format( text.replace('"', '""'))) + "\t" else: linearized_accessibility_tree += '""\t' linearized_accessibility_tree += node.attrib.get( '{uri:deskat:component.at-spi.gnome.org}screencoord', "") + "\t" linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n" return linearized_accessibility_tree def tag_screenshot(screenshot, accessibility_tree): # Creat a tmp file to store the screenshot in random name uuid_str = str(uuid.uuid4()) os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) # Make tag screenshot marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) return marks, drew_nodes, tagged_screenshot_file_path def parse_actions_from_string(input_string): if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: return [input_string.strip()] # Search for a JSON string within the input string actions = [] matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL) if matches: # Assuming there's only one match, parse the JSON string into a dictionary try: for match in matches: action_dict = json.loads(match) actions.append(action_dict) return actions except json.JSONDecodeError as e: return f"Failed to parse JSON: {e}" else: matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL) if matches: # Assuming there's only one match, parse the JSON string into a dictionary try: for match in matches: action_dict = json.loads(match) actions.append(action_dict) return actions except json.JSONDecodeError as e: return f"Failed to parse JSON: {e}" else: try: action_dict = json.loads(input_string) return [action_dict] except json.JSONDecodeError: raise ValueError("Invalid response format: " + input_string) def parse_code_from_string(input_string): input_string = input_string.replace(";", "\n") if input_string.strip() in ['WAIT', 'DONE', 'FAIL']: return [input_string.strip()] # This regular expression will match both ```code``` and ```python code``` # and capture the `code` part. It uses a non-greedy match for the content inside. pattern = r"```(?:\w+\s+)?(.*?)```" # Find all non-overlapping matches in the string matches = re.findall(pattern, input_string, re.DOTALL) # The regex above captures the content inside the triple backticks. # The `re.DOTALL` flag allows the dot `.` to match newline characters as well, # so the code inside backticks can span multiple lines. # matches now contains all the captured code snippets codes = [] for match in matches: match = match.strip() commands = ['WAIT', 'DONE', 'FAIL'] # fixme: updates this part when we have more commands if match in commands: codes.append(match.strip()) elif match.split('\n')[-1] in commands: if len(match.split('\n')) > 1: codes.append("\n".join(match.split('\n')[:-1])) codes.append(match.split('\n')[-1]) else: codes.append(match) return codes def parse_code_from_som_string(input_string, masks): # parse the output string by masks tag_vars = "" for i, mask in enumerate(masks): x, y, w, h = mask tag_vars += "tag_" + str(i + 1) + "=" + "({}, {})".format(int(x + w // 2), int(y + h // 2)) tag_vars += "\n" actions = parse_code_from_string(input_string) for i, action in enumerate(actions): if action.strip() in ['WAIT', 'DONE', 'FAIL']: pass else: action = tag_vars + action actions[i] = action return actions class PromptAgent: def __init__( self, model="gpt-4-vision-preview", max_tokens=1500, top_p=0.9, temperature=0.5, action_space="computer_13", observation_type="screenshot_a11y_tree", # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] max_trajectory_length=3 ): self.model = model self.max_tokens = max_tokens self.top_p = top_p self.temperature = temperature self.action_space = action_space self.observation_type = observation_type self.max_trajectory_length = max_trajectory_length self.thoughts = [] self.actions = [] self.observations = [] if observation_type == "screenshot": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif observation_type == "a11y_tree": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif observation_type == "both": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) elif observation_type == "som": if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG else: raise ValueError("Invalid action space: " + action_space) elif observation_type == "seeact": if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": self.system_message = SYS_PROMPT_SEEACT else: raise ValueError("Invalid action space: " + action_space) else: raise ValueError("Invalid experiment type: " + observation_type) def predict(self, instruction: str, obs: Dict) -> List: """ Predict the next action(s) based on the current observation. """ self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( instruction) # Prepare the payload for the API call messages = [] masks = None messages.append({ "role": "system", "content": [ { "type": "text", "text": self.system_message }, ] }) # Append trajectory assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts) \ , "The number of observations and actions should be the same." if len(self.observations) > self.max_trajectory_length: _observations = self.observations[-self.max_trajectory_length:] _actions = self.actions[-self.max_trajectory_length:] _thoughts = self.thoughts[-self.max_trajectory_length:] else: _observations = self.observations _actions = self.actions _thoughts = self.thoughts for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): # {{{1 if self.observation_type == "both": _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{_screenshot}", "detail": "high" } } ] }) elif self.observation_type in ["som", "seeact"]: _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{_screenshot}", "detail": "high" } } ] }) elif self.observation_type == "screenshot": _screenshot = previous_obs["screenshot"] messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{_screenshot}", "detail": "high" } } ] }) elif self.observation_type == "a11y_tree": _linearized_accessibility_tree = previous_obs["accessibility_tree"] messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( _linearized_accessibility_tree) } ] }) else: raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}} messages.append({ "role": "assistant", "content": [ { "type": "text", "text": previous_thought.strip() if len(previous_thought) > 0 else "No valid action" }, ] }) # {{{1 if self.observation_type in ["screenshot", "both"]: base64_image = encode_image(obs["screenshot"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "both": self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) else: self.observations.append({ "screenshot": base64_image, "accessibility_tree": None }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the screenshot as below. What's the next step that you will do to help with the task?" if self.observation_type == "screenshot" else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{base64_image}", "detail": "high" } } ] }) elif self.observation_type == "a11y_tree": linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": None, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) } ] }) elif self.observation_type == "som": # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{base64_image}", "detail": "high" } } ] }) elif self.observation_type == "seeact": # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree }) messages.append({ "role": "user", "content": [ { "type": "text", "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree) }, { "type": "image_url", "image_url": { "url": f"data:image/png;base64,{base64_image}", "detail": "high" } } ] }) else: raise ValueError("Invalid observation_type type: " + self.observation_type) # 1}}} # with open("messages.json", "w") as f: # f.write(json.dumps(messages, indent=4)) logger.info("Generating content with GPT model: %s", self.model) response = self.call_llm({ "model": self.model, "messages": messages, "max_tokens": self.max_tokens }) logger.info("RESPONSE: %s", response) if self.observation_type == "seeact": messages.append({ "role": "assistant", "content": [ { "type": "text", "text": response } ] }) messages.append({ "role": "user", "content": [ { "type": "text", "text": "{}\n\nWhat's the next step that you will do to help with the task?".format( ACTION_GROUNDING_PROMPT_SEEACT) } ] }) logger.info("Generating content with GPT model: %s", self.model) response = self.call_llm({ "model": self.model, "messages": messages, "max_tokens": self.max_tokens, "top_p": self.top_p, "temperature": self.temperature }) logger.info("RESPONSE: %s", response) try: actions = self.parse_actions(response, masks) self.thoughts.append(response) except Exception as e: print("Failed to parse action from response", e) actions = None self.thoughts.append("") return actions @backoff.on_exception( backoff.expo, (Exception), max_tries=5 ) def call_llm(self, payload): if self.model.startswith("gpt"): headers = { "Content-Type": "application/json", "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" } logger.info("Generating content with GPT model: %s", self.model) response = requests.post( "https://api.openai.com/v1/chat/completions", headers=headers, json=payload ) if response.status_code != 200: if response.json()['error']['code'] == "context_length_exceeded": logger.error("Context length exceeded. Retrying with a smaller context.") payload["messages"] = payload["messages"][-1:] retry_response = requests.post( "https://api.openai.com/v1/chat/completions", headers=headers, json=payload ) if retry_response.status_code != 200: logger.error("Failed to call LLM: " + retry_response.text) return "" logger.error("Failed to call LLM: " + response.text) time.sleep(5) return "" else: return response.json()['choices'][0]['message']['content'] # elif self.model.startswith("mistral"): # print("Call mistral") # messages = payload["messages"] # max_tokens = payload["max_tokens"] # # misrtal_messages = [] # # for i, message in enumerate(messages): # mistral_message = { # "role": message["role"], # "content": [] # } # # for part in message["content"]: # mistral_message['content'] = part['text'] if part['type'] == "text" else None # # misrtal_messages.append(mistral_message) # # # the mistral not support system message in our endpoint, so we concatenate it at the first user message # if misrtal_messages[0]['role'] == "system": # misrtal_messages[1]['content'] = misrtal_messages[0]['content'] + "\n" + misrtal_messages[1]['content'] # misrtal_messages.pop(0) # # # openai.api_base = "http://localhost:8000/v1" # # openai.api_key = "test" # # response = openai.ChatCompletion.create( # # messages=misrtal_messages, # # model="Mixtral-8x7B-Instruct-v0.1" # # ) # # from openai import OpenAI # TOGETHER_API_KEY = "d011650e7537797148fb6170ec1e0be7ae75160375686fae02277136078e90d2" # # client = OpenAI(api_key=TOGETHER_API_KEY, # base_url='https://api.together.xyz', # ) # logger.info("Generating content with Mistral model: %s", self.model) # response = client.chat.completions.create( # messages=misrtal_messages, # model="mistralai/Mixtral-8x7B-Instruct-v0.1", # max_tokens=1024 # ) # # try: # # return response['choices'][0]['message']['content'] # return response.choices[0].message.content # except Exception as e: # print("Failed to call LLM: " + str(e)) # return "" elif self.model.startswith("gemini"): def encoded_img_to_pil_img(data_str): base64_str = data_str.replace("data:image/png;base64,", "") image_data = base64.b64decode(base64_str) image = Image.open(BytesIO(image_data)) return image messages = payload["messages"] max_tokens = payload["max_tokens"] top_p = payload["top_p"] temperature = payload["temperature"] gemini_messages = [] for i, message in enumerate(messages): role_mapping = { "assistant": "model", "user": "user", "system": "system" } gemini_message = { "role": role_mapping[message["role"]], "parts": [] } assert len(message["content"]) in [1, 2], "One text, or one text with one image" # The gemini only support the last image as single image input if i == len(messages) - 1: for part in message["content"]: gemini_message['parts'].append(part['text']) if part['type'] == "text" \ else gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url'])) else: for part in message["content"]: gemini_message['parts'].append(part['text']) if part['type'] == "text" else None gemini_messages.append(gemini_message) # the mistral not support system message in our endpoint, so we concatenate it at the first user message if gemini_messages[0]['role'] == "system": gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0] gemini_messages.pop(0) # since the gemini-pro-vision donnot support multi-turn message if self.model == "gemini-pro-vision": message_history_str = "" for message in gemini_messages: message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n" gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}] print(gemini_messages) api_key = os.environ.get("GENAI_API_KEY") assert api_key is not None, "Please set the GENAI_API_KEY environment variable" genai.configure(api_key=api_key) logger.info("Generating content with Gemini model: %s", self.model) response = genai.GenerativeModel(self.model).generate_content( gemini_messages, generation_config={ "candidate_count": 1, "max_output_tokens": max_tokens, "top_p": top_p, "temperature": temperature }, safety_settings={ HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, } ) try: return response.text except Exception as e: return "" elif self.model.startswith("qwen"): messages = payload["messages"] max_tokens = payload["max_tokens"] top_p = payload["top_p"] temperature = payload["temperature"] qwen_messages = [] for i, message in enumerate(messages): qwen_message = { "role": message["role"], "content": [] } assert len(message["content"]) in [1, 2], "One text, or one text with one image" for part in message["content"]: qwen_message['content'].append({"image": part['image_url']['url']}) if part[ 'type'] == "image_url" else None qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None qwen_messages.append(qwen_message) response = dashscope.MultiModalConversation.call( model='qwen-vl-plus', messages=messages, # todo: add the hyperparameters ) # The response status_code is HTTPStatus.OK indicate success, # otherwise indicate request is failed, you can get error code # and message from code and message. if response.status_code == HTTPStatus.OK: try: return response.json()['output']['choices'][0]['message']['content'] except Exception as e: return "" else: print(response.code) # The error code. print(response.message) # The error message. return "" else: raise ValueError("Invalid model: " + self.model) def parse_actions(self, response: str, masks=None): if self.observation_type in ["screenshot", "a11y_tree", "both"]: # parse from the response if self.action_space == "computer_13": actions = parse_actions_from_string(response) elif self.action_space == "pyautogui": actions = parse_code_from_string(response) else: raise ValueError("Invalid action space: " + self.action_space) self.actions.append(actions) return actions elif self.observation_type in ["som", "seeact"]: # parse from the response if self.action_space == "computer_13": raise ValueError("Invalid action space: " + self.action_space) elif self.action_space == "pyautogui": actions = parse_code_from_som_string(response, masks) else: raise ValueError("Invalid action space: " + self.action_space) self.actions.append(actions) return actions def reset(self): self.thoughts = [] self.actions = [] self.observations = []