From 172123ab2c229a57638d8c3b4b4397f5ce3f05fa Mon Sep 17 00:00:00 2001
From: Timothyxxx <384084775@qq.com>
Date: Mon, 25 Mar 2024 18:02:48 +0800
Subject: [PATCH] Support downsampling; Fix bugs in Windows a11y tree; Add a11y_tree trim

---
 desktop_env/envs/desktop_env.py |  8 +-
 desktop_env/server/main.py      | 40 ++++++--
 .../heuristic_retrieve.py       |  6 +-
 mm_agents/agent.py              | 93 ++++++++++++-------
 4 files changed, 104 insertions(+), 43 deletions(-)

diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py
index b443a4a..5fd972d 100644
--- a/desktop_env/envs/desktop_env.py
+++ b/desktop_env/envs/desktop_env.py
@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
         image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
 
         # Get the screenshot and save to the image_path
-        screenshot = self.controller.get_screenshot()
+        max_retries = 20
+        for _ in range(max_retries):
+            screenshot = self.controller.get_screenshot()
+            if screenshot is not None:
+                break
+            time.sleep(1)
+
         with open(image_path, "wb") as f:
             f.write(screenshot)
 

diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py
index 8e900a3..cd6998c 100644
--- a/desktop_env/server/main.py
+++ b/desktop_env/server/main.py
@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
 
     # Value {{{ #
     if hasattr(node, "get_step"):
-        attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        try:
+            attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        except:
+            pass
     if hasattr(node, "value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        except:
+            pass
     if hasattr(node, "get_value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        except:
+            pass
     elif hasattr(node, "get_position"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        except:
+            pass
     if hasattr(node, "min_value"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        except:
+            pass
     elif hasattr(node, "get_range_min"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        except:
+            pass
     if hasattr(node, "max_value"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        except:
+            pass
     elif hasattr(node, "get_range_max"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        except:
+            pass
     # }}} Value #
 
     attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
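Two notes on the hunks above. In the desktop_env.py change, if all 20 attempts return None the loop falls through and f.write(screenshot) raises a TypeError, so a caller may want to fail explicitly; the sketch below is one way to do that (get_screenshot, the 20-attempt budget, and the 1-second pause come from the hunk; the helper name and the RuntimeError are assumptions, not part of the patch). In the main.py change, the bare except: clauses also swallow KeyboardInterrupt and SystemExit; except Exception: would be the narrower idiom.

    import time

    def fetch_screenshot_or_fail(controller, max_retries=20, delay=1.0):
        # Poll controller.get_screenshot() until it returns bytes; the server
        # may briefly return None while the VM is still settling.
        for _ in range(max_retries):
            screenshot = controller.get_screenshot()
            if screenshot is not None:
                return screenshot
            time.sleep(delay)
        # Fail loudly rather than letting a None reach f.write()
        # (assumption: the caller treats this as a fatal environment error).
        raise RuntimeError("get_screenshot() still returned None after %d retries" % max_retries)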
diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py
index e2845f3..5c7b830 100644
--- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py
+++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py
@@ -110,6 +110,10 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
         coords = tuple(map(int, coords_str.strip('()').split(', ')))
         size = tuple(map(int, size_str.strip('()').split(', ')))
 
+        import copy
+        original_coords = copy.deepcopy(coords)
+        original_size = copy.deepcopy(size)
+
         if float(down_sampling_ratio) != 1.0:
             # Downsample the coordinates and size
             coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
@@ -145,7 +149,7 @@
             draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
 
             # each mark is an x, y, w, h tuple
-            marks.append([coords[0], coords[1], size[0], size[1]])
+            marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
             drew_nodes.append(_node)
 
             if _node.text:
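The heuristic_retrieve.py change keeps two coordinate spaces: the boxes are drawn on the (possibly downsampled) screenshot, while marks keeps the original screen-space values, because the agent's click targets must be expressed in unscaled coordinates even when the model sees a smaller image. Note that copy.deepcopy on tuples of ints is redundant, since tuples are immutable a plain assignment would suffice. A minimal sketch of the bookkeeping (scale_box is a hypothetical helper, not in the patch):

    def scale_box(coords, size, down_sampling_ratio=1.0):
        # Returns (draw_box, mark_box): scaled values for drawing on the
        # downsampled image, original screen-space values for the marks list.
        mark_box = [coords[0], coords[1], size[0], size[1]]
        if float(down_sampling_ratio) != 1.0:
            coords = tuple(int(c * down_sampling_ratio) for c in coords)
            size = tuple(int(s * down_sampling_ratio) for s in size)
        draw_box = [coords[0], coords[1], size[0], size[1]]
        return draw_box, mark_box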
diff --git a/mm_agents/agent.py b/mm_agents/agent.py
index 4b27968..5ed3d27 100644
--- a/mm_agents/agent.py
+++ b/mm_agents/agent.py
@@ -8,12 +8,14 @@ import uuid
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List
+
 import backoff
 import dashscope
 import google.generativeai as genai
 import openai
 import requests
+import tiktoken
 from PIL import Image
 from google.api_core.exceptions import InvalidArgument
 
@@ -32,49 +34,49 @@ def encode_image(image_path):
         return base64.b64encode(image_file.read()).decode('utf-8')
 
 
-def linearize_accessibility_tree(accessibility_tree):
+def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
-    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
+    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
 
     linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
     # Linearize the accessibility tree nodes into a table format
 
     for node in filtered_nodes:
-        #linearized_accessibility_tree += node.tag + "\t"
-        #linearized_accessibility_tree += node.attrib.get('name') + "\t"
+        # linearized_accessibility_tree += node.tag + "\t"
+        # linearized_accessibility_tree += node.attrib.get('name') + "\t"
         if node.text:
-            text = ( node.text if '"' not in node.text\
-                else '"{:}"'.format(node.text.replace('"', '""'))
-            )
+            text = (node.text if '"' not in node.text \
+                        else '"{:}"'.format(node.text.replace('"', '""'))
+                    )
         elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
                 and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
             text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
-            text = (text if '"' not in text\
-                else '"{:}"'.format(text.replace('"', '""'))
-            )
+            text = (text if '"' not in text \
+                        else '"{:}"'.format(text.replace('"', '""'))
+                    )
         else:
             text = '""'
-        #linearized_accessibility_tree += node.attrib.get(
-        #, "") + "\t"
-        #linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
+        # linearized_accessibility_tree += node.attrib.get(
+        # , "") + "\t"
+        # linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
         linearized_accessibility_tree.append(
-            "{:}\t{:}\t{:}\t{:}\t{:}".format(
-                node.tag, node.get("name", ""), text
-                , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
-                , node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
-            )
-        )
+            "{:}\t{:}\t{:}\t{:}\t{:}".format(
+                node.tag, node.get("name", ""), text
+                , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
+                , node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
+            )
+        )
 
     return "\n".join(linearized_accessibility_tree)
 
 
-def tag_screenshot(screenshot, accessibility_tree):
+def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
     # Creat a tmp file to store the screenshot in random name
     uuid_str = str(uuid.uuid4())
     os.makedirs("tmp/images", exist_ok=True)
     tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
     # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
-    nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
+    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
     # Make tag screenshot
     marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
 
@@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
     return actions
 
 
+def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
+    enc = tiktoken.encoding_for_model("gpt-4")
+    tokens = enc.encode(linearized_accessibility_tree)
+    if len(tokens) > max_tokens:
+        linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
+        linearized_accessibility_tree += "[...]\n"
+    return linearized_accessibility_tree
+
 class PromptAgent:
     def __init__(
             self,
+            platform="ubuntu",
             model="gpt-4-vision-preview",
             max_tokens=1500,
             top_p=0.9,
@@ -180,8 +191,10 @@
             action_space="computer_13",
             observation_type="screenshot_a11y_tree",
             # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
-            max_trajectory_length=3
+            max_trajectory_length=3,
+            a11y_tree_max_tokens=10000
     ):
+        self.platform = platform
         self.model = model
         self.max_tokens = max_tokens
         self.top_p = top_p
@@ -189,6 +202,7 @@
         self.action_space = action_space
         self.observation_type = observation_type
         self.max_trajectory_length = max_trajectory_length
+        self.a11y_tree_max_tokens = a11y_tree_max_tokens
 
         self.thoughts = []
         self.actions = []
@@ -349,9 +363,14 @@
        # {{{1
        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = encode_image(obs["screenshot"])
-           linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
+           linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                        platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
 
+           if linearized_accessibility_tree:
+               linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                       self.a11y_tree_max_tokens)
+
            if self.observation_type == "screenshot_a11y_tree":
                self.observations.append({
                    "screenshot": base64_image,
@@ -383,9 +402,14 @@
                ]
            })
        elif self.observation_type == "a11y_tree":
-           linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
+           linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                        platform=self.platform)
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
 
+           if linearized_accessibility_tree:
+               linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                       self.a11y_tree_max_tokens)
+
            self.observations.append({
                "screenshot": None,
                "accessibility_tree": linearized_accessibility_tree
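The trim_accessibility_tree function added above cuts the linearized table at a token boundary, so the last row may be split mid-line; the appended "[...]\n" marks the truncation for the model. A usage sketch under the patch's defaults (the sample rows are made up; tiktoken's encoding_for_model/encode/decode are the real API):

    import tiktoken

    enc = tiktoken.encoding_for_model("gpt-4")  # cl100k_base tokenizer
    tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
    tree += 'push-button\tOK\t""\t(100, 200)\t(80, 30)\n' * 5000  # made-up rows
    max_tokens = 10000  # a11y_tree_max_tokens default in the patch

    tokens = enc.encode(tree)
    if len(tokens) > max_tokens:
        tree = enc.decode(tokens[:max_tokens]) + "[...]\n"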
@@ -403,10 +427,15 @@
            })
        elif self.observation_type == "som":
            # Add som to the screenshot
-           masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
+           masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
+               "accessibility_tree"], self.platform)
            base64_image = encode_image(tagged_screenshot)
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
 
+           if linearized_accessibility_tree:
+               linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                       self.a11y_tree_max_tokens)
+
            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
@@ -435,7 +464,7 @@
 
        # with open("messages.json", "w") as f:
        #     f.write(json.dumps(messages, indent=4))
-       #logger.info("PROMPT: %s", messages)
+       # logger.info("PROMPT: %s", messages)
 
        response = self.call_llm({
            "model": self.model,
@@ -556,8 +585,6 @@
            "Content-Type": "application/json"
        }
 
-
-
        payload = {
            "model": self.model,
            "max_tokens": max_tokens,
@@ -570,7 +597,8 @@
        attempt = 0
        while attempt < max_attempts:
            # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
-           response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
+           response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
+                                    json=payload)
            if response.status_code == 200:
                result = response.json()['choices'][0]['message']['content']
                break
@@ -581,7 +609,7 @@
            else:
                print("Exceeded maximum attempts to call LLM.")
                result = ""
-
+
        return result
 
 
@@ -605,14 +633,13 @@
 
        mistral_messages.append(mistral_message)
 
-
        from openai import OpenAI
        client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                        base_url='https://api.together.xyz',
                        )
 
        logger.info("Generating content with Mistral model: %s", self.model)
-
+
        flag = 0
        while True:
            try:
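The patch text ends mid-hunk above. For context, the bounded retry around the chat/completions endpoint earlier in call_llm follows a simple pattern that can be sketched standalone; the 200-check and empty-string fallback mirror the hunk, while the helper name, the url parameter, and the fixed delay are assumptions:

    import time
    import requests

    def post_chat_completion(url, headers, payload, max_attempts=10, delay=5.0):
        # Bounded retry: return the first message content on HTTP 200,
        # or "" after max_attempts failures, as in PromptAgent.call_llm.
        for attempt in range(max_attempts):
            response = requests.post(url, headers=headers, json=payload)
            if response.status_code == 200:
                return response.json()["choices"][0]["message"]["content"]
            time.sleep(delay)  # fixed pause between attempts (assumption)
        print("Exceeded maximum attempts to call LLM.")
        return ""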