From 5cbf1b28ca69d2fd5a6c1035615ce56aae19d0ef Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 21:06:50 +0800 Subject: [PATCH 1/8] Fix bugs --- mm_agents/agent.py | 55 ++++++++++++++++++++-------------------------- run.py | 27 +++++++++++++---------- 2 files changed, 39 insertions(+), 43 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 85db78b..4314c63 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -10,16 +10,10 @@ from http import HTTPStatus from io import BytesIO from typing import Dict, List -import backoff import dashscope import google.generativeai as genai import requests from PIL import Image -from vertexai.preview.generative_models import ( - HarmBlockThreshold, - HarmCategory, - Image, -) from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes, draw_bounding_boxes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ @@ -28,8 +22,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT -# todo: cross-check with visualwebarena - logger = logging.getLogger("desktopenv.agent") @@ -43,7 +35,7 @@ def linearize_accessibility_tree(accessibility_tree): # leaf_nodes = find_leaf_nodes(accessibility_tree) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) - linearized_accessibility_tree = "tag\tname\ttext\tposition\tsize\n" + linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n" # Linearize the accessibility tree nodes into a table format for node in filtered_nodes: @@ -205,7 +197,7 @@ class PromptAgent: self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE else: raise ValueError("Invalid action space: " + action_space) - elif observation_type == "both": + elif observation_type == "screenshot_a11y_tree": if action_space == "computer_13": self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION elif action_space == "pyautogui": @@ -233,8 +225,7 @@ class PromptAgent: """ Predict the next action(s) based on the current observation. """ - self.system_message = self.system_message + "\nYou are asked to complete the following task: {}".format( - instruction) + system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction) # Prepare the payload for the API call messages = [] @@ -245,7 +236,7 @@ class PromptAgent: "content": [ { "type": "text", - "text": self.system_message + "text": system_message }, ] }) @@ -266,7 +257,7 @@ class PromptAgent: for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts): # {{{1 - if self.observation_type == "both": + if self.observation_type == "screenshot_a11y_tree": _screenshot = previous_obs["screenshot"] _linearized_accessibility_tree = previous_obs["accessibility_tree"] logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) @@ -356,11 +347,11 @@ class PromptAgent: }) # {{{1 - if self.observation_type in ["screenshot", "both"]: + if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: base64_image = encode_image(obs["screenshot"]) linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) - if self.observation_type == "both": + if self.observation_type == "screenshot_a11y_tree": self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree @@ -473,7 +464,9 @@ class PromptAgent: response = self.call_llm({ "model": self.model, "messages": messages, - "max_tokens": self.max_tokens + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature }) logger.info("RESPONSE: %s", response) @@ -520,11 +513,11 @@ class PromptAgent: return actions - @backoff.on_exception( - backoff.expo, - (Exception), - max_tries=5 - ) + # @backoff.on_exception( + # backoff.expo, + # (Exception), + # max_tries=5 + # ) def call_llm(self, payload): if self.model.startswith("gpt"): @@ -542,14 +535,14 @@ class PromptAgent: if response.status_code != 200: if response.json()['error']['code'] == "context_length_exceeded": logger.error("Context length exceeded. Retrying with a smaller context.") - payload["messages"] = payload["messages"][-1:] + payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:] retry_response = requests.post( "https://api.openai.com/v1/chat/completions", headers=headers, json=payload ) if retry_response.status_code != 200: - logger.error("Failed to call LLM: " + retry_response.text) + logger.error("Failed to call LLM even after attempt on shortening the history: " + retry_response.text) return "" logger.error("Failed to call LLM: " + response.text) @@ -656,8 +649,9 @@ class PromptAgent: for message in gemini_messages: message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n" gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}] + # gemini_messages[-1]['parts'][1].save("output.png", "PNG") - print(gemini_messages) + # print(gemini_messages) api_key = os.environ.get("GENAI_API_KEY") assert api_key is not None, "Please set the GENAI_API_KEY environment variable" genai.configure(api_key=api_key) @@ -671,11 +665,10 @@ class PromptAgent: "temperature": temperature }, safety_settings={ - HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, + "harassment": "block_none", + "hate": "block_none", + "sex": "block_none", + "danger": "block_none" } ) @@ -726,7 +719,7 @@ class PromptAgent: def parse_actions(self, response: str, masks=None): - if self.observation_type in ["screenshot", "a11y_tree", "both"]: + if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]: # parse from the response if self.action_space == "computer_13": actions = parse_actions_from_string(response) diff --git a/run.py b/run.py index 908d479..04aec2c 100644 --- a/run.py +++ b/run.py @@ -66,7 +66,7 @@ def config() -> argparse.Namespace: "screenshot_a11y_tree", "som" ], - default="a11y_tree", + default="som", help="Observation type", ) parser.add_argument("--screen_width", type=int, default=1920) @@ -146,6 +146,7 @@ def test( step_idx = 0 env.controller.start_recording() + # todo: update max running time for each example, @xiaochuan while not done and step_idx < max_steps: actions = agent.predict( instruction, @@ -158,7 +159,7 @@ def test( action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") logger.info("Step %d: %s", step_idx + 1, action) - observation, reward, done, info = env.step(action, args.sleep_after_execution) + obs, reward, done, info = env.step(action, args.sleep_after_execution) logger.info("Reward: %.2f", reward) logger.info("Done: %s", done) @@ -167,7 +168,7 @@ def test( # Save screenshot and trajectory information with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), "wb") as _f: - with open(observation['screenshot'], "rb") as __f: + with open(obs['screenshot'], "rb") as __f: screenshot = __f.read() _f.write(screenshot) @@ -186,22 +187,24 @@ def test( if done: logger.info("The episode is done.") break - - result = env.evaluate() + try: + result = env.evaluate() + except Exception as e: + logger.error(f"Error in evaluating the example {example_id}: {e}") + result = 0.0 logger.info("Result: %.2f", result) - scores.append(result) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") -def get_unfinished(test_file_list, result_dir): - finished = [] - for domain in os.listdir(result_dir): - for example_id in os.listdir(os.path.join(result_dir, domain)): - finished.append(f"{domain}/{example_id}") - return [x for x in test_file_list if x not in finished] +def get_unfinished(test, result_dir): + # todo @xiaochuan + pass if __name__ == '__main__': From 51d644c88bcf0a3102bd4a0b79bb09144e8f3aea Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 21:12:18 +0800 Subject: [PATCH 2/8] Merge --- run.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/run.py b/run.py index 953c6b7..c56e142 100644 --- a/run.py +++ b/run.py @@ -6,8 +6,8 @@ import datetime import json import logging import os -import sys import signal +import sys from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent @@ -46,11 +46,15 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") + # make sure each example won't exceed the time limit def handler(signo, frame): raise RuntimeError("Time limit exceeded!") + + signal.signal(signal.SIGALRM, handler) + def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" @@ -175,7 +179,7 @@ def test( # Save screenshot and trajectory information with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), - "wb") as _f: + "wb") as _f: with open(obs['screenshot'], "rb") as __f: screenshot = __f.read() _f.write(screenshot) @@ -245,6 +249,7 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ return total_file_json + if __name__ == '__main__': ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -253,7 +258,13 @@ if __name__ == '__main__': with open("evaluation_examples/test_all.json", "r", encoding="utf-8") as f: test_all_meta = json.load(f) - test_file_list = get_unfinished(args.action_space, args.model, args.observation_type, args.result_dir, test_all_meta) + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta + ) left_info = "" for domain in test_file_list: left_info += f"{domain}: {len(test_file_list[domain])}\n" From cfa9aaf3a7e3a94b58f1f6a7a72d26aae4c87f08 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 21:16:27 +0800 Subject: [PATCH 3/8] Update README.md --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 8eb867f..6262044 100644 --- a/README.md +++ b/README.md @@ -21,10 +21,12 @@ Please refer to [guidance](https://docs.google.com/document/d/1KBdeZwmZs2Vi_Wsnngb3Wf1-RiwMMpXTftwMqP2Ztak/edit#heading=h.uh0x0tkl7fuw) 2. Install the environment package, download the examples and the virtual machine image. +For x86_64 Linux or Windows, you can install the environment package and download the examples and the virtual machine image by running the following commands: ```bash pip install desktop-env gdown xxxx -gdown xxxx +vmrun -T ws start "Ubuntu/Ubuntu.vmx" nogui +vmrun -T ws snapshot "Ubuntu/Ubuntu.vmx" "init_state" ``` ## Quick Start From 81580a1bbce9e23684fafb18297c44e4eccff115 Mon Sep 17 00:00:00 2001 From: rhythmcao Date: Fri, 15 Mar 2024 22:09:24 +0800 Subject: [PATCH 4/8] fix incompatible errors in main.py (temporarily fixup, will be dropped in future after snapshot download is ok) --- main.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/main.py b/main.py index 93282ec..06debec 100644 --- a/main.py +++ b/main.py @@ -4,7 +4,7 @@ import logging import os import sys import time - +import argparse from desktop_env.envs.desktop_env import DesktopEnv # Logger Configs {{{ # @@ -46,19 +46,29 @@ def human_agent(): """ Runs the Gym environment with human input. """ + parser = argparse.ArgumentParser() + parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.") + parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.") + parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") + args = parser.parse_args(sys.argv[1:]) - with open("evaluation_examples/examples/multi_apps/4c26e3f3-3a14-4d86-b44a-d3cedebbb487.json", "r", encoding="utf-8") as f: + example_path = args.example if args.example is not None and os.path.exists(args.example) else \ + 'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json' + with open(example_path, "r", encoding="utf-8") as f: example = json.load(f) - example["snapshot"] = "exp_v5" + if args.snapshot is not None: + example['snapshot'] = args.snapshot + assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." env = DesktopEnv( - path_to_vm=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", - action_space="computer_13", - task_config=example + path_to_vm=args.path, + snapshot_name=args.snapshot, + action_space="computer_13" ) # reset the environment to certain snapshot - observation = env.reset() + observation = env.reset(task_config=example) done = False + logger.info('\x1b[32m[TASK INSTRUCTION]: \x1b[32;3m%s\x1b[0m', example["instruction"]) trajectory = [ { From 1ad4527e8bc0dca55a27e21f3c869906994f306c Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 22:10:35 +0800 Subject: [PATCH 5/8] Change SoM input and output --- mm_agents/agent.py | 104 +++++++------------------------------------ mm_agents/prompts.py | 4 +- 2 files changed, 17 insertions(+), 91 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 7c7b756..f8fc1c8 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -21,10 +21,7 @@ from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \ SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \ - SYS_PROMPT_IN_SOM_A11Y_OUT_TAG, \ - SYS_PROMPT_SEEACT, ACTION_DESCRIPTION_PROMPT_SEEACT, ACTION_GROUNDING_PROMPT_SEEACT - -# todo: cross-check with visualwebarena + SYS_PROMPT_IN_SOM_OUT_TAG logger = logging.getLogger("desktopenv.agent") @@ -67,7 +64,8 @@ def tag_screenshot(screenshot, accessibility_tree): uuid_str = str(uuid.uuid4()) os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") - nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) + # nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) + nodes = filter_nodes(ET.fromstring(accessibility_tree)) # Make tag screenshot marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) @@ -172,7 +170,7 @@ class PromptAgent: temperature=0.5, action_space="computer_13", observation_type="screenshot_a11y_tree", - # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som", "seeact"] + # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"] max_trajectory_length=3 ): self.model = model @@ -212,14 +210,7 @@ class PromptAgent: if action_space == "computer_13": raise ValueError("Invalid action space: " + action_space) elif action_space == "pyautogui": - self.system_message = SYS_PROMPT_IN_SOM_A11Y_OUT_TAG - else: - raise ValueError("Invalid action space: " + action_space) - elif observation_type == "seeact": - if action_space == "computer_13": - raise ValueError("Invalid action space: " + action_space) - elif action_space == "pyautogui": - self.system_message = SYS_PROMPT_SEEACT + self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG else: raise ValueError("Invalid action space: " + action_space) else: @@ -283,18 +274,15 @@ class PromptAgent: } ] }) - elif self.observation_type in ["som", "seeact"]: + elif self.observation_type in ["som"]: _screenshot = previous_obs["screenshot"] - _linearized_accessibility_tree = previous_obs["accessibility_tree"] - logger.debug("LINEAR AT: %s", _linearized_accessibility_tree) messages.append({ "role": "user", "content": [ { "type": "text", - "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - _linearized_accessibility_tree) + "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", @@ -407,11 +395,9 @@ class PromptAgent: # Add som to the screenshot masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) base64_image = encode_image(tagged_screenshot) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) self.observations.append({ - "screenshot": base64_image, - "accessibility_tree": linearized_accessibility_tree + "screenshot": base64_image }) messages.append({ @@ -419,35 +405,7 @@ class PromptAgent: "content": [ { "type": "text", - "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format( - linearized_accessibility_tree) - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{base64_image}", - "detail": "high" - } - } - ] - }) - elif self.observation_type == "seeact": - # Add som to the screenshot - masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) - base64_image = encode_image(tagged_screenshot) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) - - self.observations.append({ - "screenshot": base64_image, - "accessibility_tree": linearized_accessibility_tree - }) - - messages.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": ACTION_DESCRIPTION_PROMPT_SEEACT.format(linearized_accessibility_tree) + "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?" }, { "type": "image_url", @@ -475,38 +433,6 @@ class PromptAgent: logger.info("RESPONSE: %s", response) - if self.observation_type == "seeact": - messages.append({ - "role": "assistant", - "content": [ - { - "type": "text", - "text": response - } - ] - }) - - messages.append({ - "role": "user", - "content": [ - { - "type": "text", - "text": "{}\n\nWhat's the next step that you will do to help with the task?".format( - ACTION_GROUNDING_PROMPT_SEEACT) - } - ] - }) - - logger.info("Generating content with GPT model: %s", self.model) - response = self.call_llm({ - "model": self.model, - "messages": messages, - "max_tokens": self.max_tokens, - "top_p": self.top_p, - "temperature": self.temperature - }) - logger.info("RESPONSE: %s", response) - try: actions = self.parse_actions(response, masks) self.thoughts.append(response) @@ -523,12 +449,11 @@ class PromptAgent: # but you are forbidden to add "Exception", that is, a common type of exception # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit (openai.RateLimitError, - openai.BadRequestError, - openai.InternalServerError, - InvalidArgument), + openai.BadRequestError, + openai.InternalServerError, + InvalidArgument), max_tries=5 ) - def call_llm(self, payload): if self.model.startswith("gpt"): @@ -553,7 +478,8 @@ class PromptAgent: json=payload ) if retry_response.status_code != 200: - logger.error("Failed to call LLM even after attempt on shortening the history: " + retry_response.text) + logger.error( + "Failed to call LLM even after attempt on shortening the history: " + retry_response.text) return "" logger.error("Failed to call LLM: " + response.text) @@ -742,7 +668,7 @@ class PromptAgent: self.actions.append(actions) return actions - elif self.observation_type in ["som", "seeact"]: + elif self.observation_type in ["som"]: # parse from the response if self.action_space == "computer_13": raise ValueError("Invalid action space: " + self.action_space) diff --git a/mm_agents/prompts.py b/mm_agents/prompts.py index 15aefeb..462aac7 100644 --- a/mm_agents/prompts.py +++ b/mm_agents/prompts.py @@ -798,10 +798,10 @@ You MUST choose and ONLY CHOOSE from the action space above, otherwise your acti You CAN predict multiple actions at one step, but you should only return one action for each step. """.strip() -SYS_PROMPT_IN_SOM_A11Y_OUT_TAG = """ +SYS_PROMPT_IN_SOM_OUT_TAG = """ You are an agent which follow my instruction and perform desktop computer tasks as instructed. You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. -For each step, you will get an observation of the desktop by 1) a screenshot; and 2) accessibility tree, which is based on AT-SPI library. +For each step, you will get an observation of the desktop by a screenshot with interact-able elements marked with numerical tags. And you will predict the action of the computer based on the image. You are required to use `pyautogui` to perform the action grounded to the observation, but DONOT use the `pyautogui.locateCenterOnScreen` function to locate the element you want to operate with since we have no image of the element you want to operate with. DONOT USE `pyautogui.screenshot()` to make screenshot. You can replace x, y in the code with the tag of the element you want to operate with. such as: From e166106b6ad13251fcb61d328dae67288649e7df Mon Sep 17 00:00:00 2001 From: David Chang Date: Fri, 15 Mar 2024 22:46:14 +0800 Subject: [PATCH 6/8] ver Mar15th added an option to keep buttons without text information but with an image for SoM setting --- desktop_env/server/main.py | 9 +++++++++ .../accessibility_tree_wrap/heuristic_retrieve.py | 14 ++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py index 57649ac..efa62c7 100644 --- a/desktop_env/server/main.py +++ b/desktop_env/server/main.py @@ -284,6 +284,15 @@ def _create_atspi_node(node: Accessible, depth: int = 0, flag: Optional[str] = N text = text.replace("\ufffc", "").replace("\ufffd", "") # }}} Text # + # Image {{{ # + try: + node.queryImage() + except NotImplementedError: + pass + else: + attribute_dict["image"] = "true" + # }}} Image # + # Selection {{{ # try: node.querySelection() diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index 34a1d76..191eaa7 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -26,7 +26,7 @@ def find_leaf_nodes(xlm_file_str): state_ns = "uri:deskat:state.at-spi.gnome.org" component_ns = "uri:deskat:component.at-spi.gnome.org" -def judge_node(node: ET, platform="ubuntu") -> bool: +def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: keeps: bool = node.tag.startswith("document")\ or node.tag.endswith("item")\ or node.tag.endswith("button")\ @@ -60,7 +60,9 @@ def judge_node(node: ET, platform="ubuntu") -> bool: or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\ or node.get("{{{:}}}checkable".format(state_ns), "false")=="true" )\ - and (node.get("name", "") != "" or node.text is not None and len(node.text)>0) + and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\ + or check_image and node.get("image", "false")=="true" + ) coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)")) sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)")) @@ -155,12 +157,12 @@ def print_nodes_with_indent(nodes, indent=0): if __name__ == '__main__': import json - with open('4.json', 'r', encoding='utf-8') as f: - xml_file_str = json.load(f)["AT"] + with open('selection_sorted(imaged).xml', 'r', encoding='utf-8') as f: + xml_file_str = f.read() filtered_nodes = filter_nodes(ET.fromstring(xml_file_str)) print(len(filtered_nodes)) - masks = draw_bounding_boxes( filtered_nodes, '4.png' - , '4.a.png' + masks = draw_bounding_boxes( filtered_nodes, 'selection_sorted(imaged).png' + , 'selection_sorted(imaged).ai.png' ) # print(masks) From 57f2257254924f56c2ef85789deba81778ffb69a Mon Sep 17 00:00:00 2001 From: David Chang Date: Fri, 15 Mar 2024 22:49:35 +0800 Subject: [PATCH 7/8] ver Mar15th fixed bugs about infeasible task evaluation --- mm_agents/accessibility_tree_wrap/heuristic_retrieve.py | 4 ++-- mm_agents/agent.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index 191eaa7..e37f614 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -69,11 +69,11 @@ def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: keeps = keeps and coordinates[0]>0 and coordinates[1]>0 and sizes[0]>0 and sizes[1]>0 return keeps -def filter_nodes(root: ET, platform="ubuntu"): +def filter_nodes(root: ET, platform="ubuntu", check_image=False): filtered_nodes = [] for node in root.iter(): - if judge_node(node, platform): + if judge_node(node, platform, check_image): filtered_nodes.append(node) #print(ET.tostring(node, encoding="unicode")) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index f8fc1c8..039eda8 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -65,7 +65,7 @@ def tag_screenshot(screenshot, accessibility_tree): os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") # nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) - nodes = filter_nodes(ET.fromstring(accessibility_tree)) + nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True) # Make tag screenshot marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) From 99e86a2cd4c23f181813adcbd453239da4281e95 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 15 Mar 2024 23:12:18 +0800 Subject: [PATCH 8/8] Update unfinished function and error catching --- run.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/run.py b/run.py index c56e142..7118d5b 100644 --- a/run.py +++ b/run.py @@ -6,9 +6,10 @@ import datetime import json import logging import os -import signal import sys +from tqdm import tqdm + from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.agent import PromptAgent @@ -52,7 +53,8 @@ def handler(signo, frame): raise RuntimeError("Time limit exceeded!") -signal.signal(signal.SIGALRM, handler) +# fixme: windows doesn't support signal +# signal.signal(signal.SIGALRM, handler) def config() -> argparse.Namespace: @@ -128,8 +130,8 @@ def test( headless=args.headless, ) - for domain in test_all_meta: - for example_id in test_all_meta[domain]: + for domain in tqdm(test_all_meta, desc="Domain"): + for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: @@ -154,7 +156,7 @@ def test( # example start running try: - signal.alarm(time_limit) + # signal.alarm(time_limit) fixme: windows doesn't support signal agent.reset() obs = env.reset(task_config=example) done = False @@ -204,6 +206,8 @@ def test( result = env.evaluate() logger.info("Result: %.2f", result) scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) except RuntimeError as e: logger.error(f"Error in example {domain}/{example_id}: {e}") @@ -224,6 +228,10 @@ def test( })) f.write("\n") continue + except Exception as e: + logger.error(f"Error in example {domain}/{example_id}: {e}") + continue + env.close() logger.info(f"Average score: {sum(scores) / len(scores)}") @@ -236,9 +244,13 @@ def get_unfinished(action_space, use_model, observation_type, result_dir, total_ finished = {} for domain in os.listdir(target_dir): + finished[domain] = [] domain_path = os.path.join(target_dir, domain) if os.path.isdir(domain_path): - finished[domain] = os.listdir(domain_path) + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path) and "result.txt" in os.listdir(example_path): + finished[domain].append(example_id) if not finished: return total_file_json