From 3ce7636abd110d1d23af10cc1d440fc085379e70 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Thu, 21 Mar 2024 22:05:16 +0800 Subject: [PATCH 1/8] Fix one multi_app example; remove some broken examples; Support downsampling --- .../7f35355e-02a6-45b5-b140-f0be698bcf85.json | 4 +- evaluation_examples/test_all.json | 6 --- evaluation_examples/test_small.json | 43 +------------------ .../heuristic_retrieve.py | 9 +++- mm_agents/agent.py | 11 ----- 5 files changed, 11 insertions(+), 62 deletions(-) diff --git a/evaluation_examples/examples/multi_apps/7f35355e-02a6-45b5-b140-f0be698bcf85.json b/evaluation_examples/examples/multi_apps/7f35355e-02a6-45b5-b140-f0be698bcf85.json index c33b042..f9161d7 100644 --- a/evaluation_examples/examples/multi_apps/7f35355e-02a6-45b5-b140-f0be698bcf85.json +++ b/evaluation_examples/examples/multi_apps/7f35355e-02a6-45b5-b140-f0be698bcf85.json @@ -9,7 +9,7 @@ "parameters": { "files": [ { - "url": "https://docs.google.com/spreadsheets/d/13YL-KC__pav2qp3sFDs1BT2wZnpWGp7s/export?format=xlsx", + "url": "https://drive.google.com/uc?export=download&id=1B5GmhdVD07UeYj9Ox20DHsA_gaxEFQ6Z", "path": "/home/user/Desktop/stock.xlsx" } ] @@ -36,7 +36,7 @@ }, "expected": { "type": "cloud_file", - "path": "https://drive.google.com/uc?export=download&id=1oPPW_dozWGII5MRmdXdKKoEK5iBkd_8Q", + "path": "https://drive.google.com/uc?export=download&id=1wzlUL1gktA0d_j9W3WSSAAUcuKr5gw-n", "dest": "result_gold.txt" } } diff --git a/evaluation_examples/test_all.json b/evaluation_examples/test_all.json index e530435..798e858 100644 --- a/evaluation_examples/test_all.json +++ b/evaluation_examples/test_all.json @@ -304,15 +304,12 @@ "os": [ "94d95f96-9699-4208-98ba-3c3119edf9c2", "bedcedc4-4d72-425e-ad62-21960b11fe0d", - "43c2d64c-bab5-4dcb-a30c-b888321c319a", - "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82", "ec4e3f68-9ea4-4c18-a5c9-69f89d1178b3", "a462a795-fdc7-4b23-b689-e8b6df786b78", "f9be0997-4b7c-45c5-b05c-4612b44a6118", "28cc3b7e-b194-4bc9-8353-d04c0f4d56d2", "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", "e0df059f-28a6-4169-924f-b9623e7184cc", - "ddc75b62-7311-4af8-bfb3-859558542b36", "b6781586-6346-41cd-935a-a6b1487918fc", "b3d4a89c-53f2-4d6b-8b6a-541fb5d205fa", "3ce045a0-877b-42aa-8d2c-b4a863336ab8", @@ -322,8 +319,6 @@ "23393935-50c7-4a86-aeea-2b78fd089c5c", "5812b315-e7bd-4265-b51f-863c02174c28", "c288e301-e626-4b98-a1ab-159dcb162af5", - "cc9d4f34-1ca0-4a1b-8ff2-09302696acb9", - "c56de254-a3ec-414e-81a6-83d2ce8c41fa", "4783cc41-c03c-4e1b-89b4-50658f642bd5", "5c1075ca-bb34-46a3-a7a0-029bd7463e79", "5ced85fc-fa1a-4217-95fd-0fb530545ce2", @@ -376,7 +371,6 @@ "4e60007a-f5be-4bfc-9723-c39affa0a6d3", "e2b5e914-ffe1-44d2-8e92-58f8c5d92bb2", "9439a27b-18ae-42d8-9778-5f68f891805e", - "ae506c68-352c-4094-9caa-ee9d42052317", "ea98c5d7-3cf9-4f9b-8ad3-366b58e0fcae", "930fdb3b-11a8-46fe-9bac-577332e2640e", "276cc624-87ea-4f08-ab93-f770e3790175", diff --git a/evaluation_examples/test_small.json b/evaluation_examples/test_small.json index aec99fc..7a072f5 100644 --- a/evaluation_examples/test_small.json +++ b/evaluation_examples/test_small.json @@ -37,48 +37,7 @@ "eb303e01-261e-4972-8c07-c9b4e7a4922a", "d1acdb87-bb67-4f30-84aa-990e56a09c92", "deec51c9-3b1e-4b9e-993c-4776f20e8bb2", - "8e116af7-7db7-4e35-a68b-b0939c066c78", - "185f29bd-5da0-40a6-b69c-ba7f4e0324ef", - "2c1ebcd7-9c6d-4c9a-afad-900e381ecd5e", - "3a93cae4-ad3e-403e-8c12-65303b271818", - "1f18aa87-af6f-41ef-9853-cdb8f32ebdea", - "26150609-0da3-4a7d-8868-0faf9c5f01bb", - "7e287123-70ca-47b9-8521-47db09b69b14", - "e2392362-125e-4f76-a2ee-524b183a3412", - "26660ad1-6ebb-4f59-8cba-a8432dfe8d38", - "a82b78bb-7fde-4cb3-94a4-035baf10bcf0", - "36037439-2044-4b50-b9d1-875b5a332143", - "716a6079-22da-47f1-ba73-c9d58f986a38", - "a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a", - "6f4073b8-d8ea-4ade-8a18-c5d1d5d5aa9a", - "da922383-bfa4-4cd3-bbad-6bebab3d7742", - "2373b66a-092d-44cb-bfd7-82e86e7a3b4d", - "81c425f5-78f3-4771-afd6-3d2973825947", - "227d2f97-562b-4ccb-ae47-a5ec9e142fbb", - "20236825-b5df-46e7-89bf-62e1d640a897", - "02ce9a50-7af2-47ed-8596-af0c230501f8", - "4c26e3f3-3a14-4d86-b44a-d3cedebbb487", - "09a37c51-e625-49f4-a514-20a773797a8a", - "3e3fc409-bff3-4905-bf16-c968eee3f807", - "415ef462-bed3-493a-ac36-ca8c6d23bf1b", - "9f3bb592-209d-43bc-bb47-d77d9df56504", - "dd60633f-2c72-42ba-8547-6f2c8cb0fdb0", - "3f05f3b9-29ba-4b6b-95aa-2204697ffc06", - "f8369178-fafe-40c2-adc4-b9b08a125456", - "778efd0a-153f-4842-9214-f05fc176b877", - "47f7c0ce-a5fb-4100-a5e6-65cd0e7429e5", - "c2751594-0cd5-4088-be1b-b5f2f9ec97c4", - "48c46dc7-fe04-4505-ade7-723cba1aa6f6", - "42d25c08-fb87-4927-8b65-93631280a26f", - "3c8f201a-009d-4bbe-8b65-a6f8b35bb57f", - "d68204bf-11c1-4b13-b48b-d303c73d4bf6", - "91190194-f406-4cd6-b3f9-c43fac942b22", - "7f35355e-02a6-45b5-b140-f0be698bcf85", - "98e8e339-5f91-4ed2-b2b2-12647cb134f4", - "df67aebb-fb3a-44fd-b75b-51b6012df509", - "5df7b33a-9f77-4101-823e-02f863e1c1ae", - "22a4636f-8179-4357-8e87-d1743ece1f81", - "236833a3-5704-47fc-888c-4f298f09f799" + "8e116af7-7db7-4e35-a68b-b0939c066c78" ], "os": [ "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index 934d8fd..e2845f3 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -80,9 +80,11 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False): return filtered_nodes -def draw_bounding_boxes(nodes, image_file_path, output_image_file_path): +def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0): # Load the screenshot image image = Image.open(image_file_path) + if float(down_sampling_ratio) != 1.0: + image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio))) draw = ImageDraw.Draw(image) marks = [] drew_nodes = [] @@ -108,6 +110,11 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path): coords = tuple(map(int, coords_str.strip('()').split(', '))) size = tuple(map(int, size_str.strip('()').split(', '))) + if float(down_sampling_ratio) != 1.0: + # Downsample the coordinates and size + coords = tuple(int(coord * down_sampling_ratio) for coord in coords) + size = tuple(int(s * down_sampling_ratio) for s in size) + # Check for negative sizes if size[0] <= 0 or size[1] <= 0: raise ValueError(f"Size must be positive, got: {size}") diff --git a/mm_agents/agent.py b/mm_agents/agent.py index d7a5586..4600628 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -6,17 +6,14 @@ import re import time import uuid import xml.etree.ElementTree as ET -import numpy as np from http import HTTPStatus from io import BytesIO from typing import Dict, List, Tuple, Union - import backoff import dashscope import google.generativeai as genai import openai import requests -import cv2 from PIL import Image from google.api_core.exceptions import InvalidArgument @@ -28,14 +25,6 @@ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_S logger = logging.getLogger("desktopenv.agent") -def downsample_image(img: Union[str, np.ndarray], ratio: Tuple[float, float]): - fx, fy = ratio - if isinstance(img, str): - img = cv2.imread(img) - - resized = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA) - return resized - # Function to encode the image def encode_image(image_path): From c34d1b37a5024172d5e5bea546261032c605dd02 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Thu, 21 Mar 2024 22:38:02 +0800 Subject: [PATCH 2/8] Update small_test set --- evaluation_examples/test_small.json | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/evaluation_examples/test_small.json b/evaluation_examples/test_small.json index 7a072f5..6598912 100644 --- a/evaluation_examples/test_small.json +++ b/evaluation_examples/test_small.json @@ -41,9 +41,7 @@ ], "os": [ "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", - "5812b315-e7bd-4265-b51f-863c02174c28", - "43c2d64c-bab5-4dcb-a30c-b888321c319a", - "7688b85f-87a4-4e4a-b2f8-f3d6c3f29b82" + "5812b315-e7bd-4265-b51f-863c02174c28" ], "thunderbird": [ "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", From d4e81afae717a166ccfc96ff69761244800f8958 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Thu, 21 Mar 2024 23:06:27 +0800 Subject: [PATCH 3/8] Update small_test set --- evaluation_examples/test_small.json | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/evaluation_examples/test_small.json b/evaluation_examples/test_small.json index 6598912..b1047a2 100644 --- a/evaluation_examples/test_small.json +++ b/evaluation_examples/test_small.json @@ -1,7 +1,9 @@ { "chrome": [ "bb5e4c0d-f964-439c-97b6-bdb9747de3f4", - "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3" + "7b6c7e24-c58a-49fc-a5bb-d57b80e5b4c3", + "35253b65-1c19-4304-8aa4-6884b8218fc0", + "a96b564e-dbe9-42c3-9ccf-b4498073438a" ], "gimp": [ "7a4deb26-d57d-4ea9-9a73-630f66a7b568", @@ -9,7 +11,8 @@ ], "libreoffice_calc": [ "357ef137-7eeb-4c80-a3bb-0951f26a8aff", - "42e0a640-4f19-4b28-973d-729602b5a4a7" + "42e0a640-4f19-4b28-973d-729602b5a4a7", + "abed40dc-063f-4598-8ba5-9fe749c0615d" ], "libreoffice_impress": [ "5d901039-a89c-4bfb-967b-bf66f4df075e", @@ -20,24 +23,23 @@ "0a0faba3-5580-44df-965d-f562a99b291c" ], "multi_apps": [ + "a74b607e-6bb5-4ea8-8a7c-5d97c7bbcd2a", + "5990457f-2adb-467b-a4af-5c857c92d762", "2b9493d7-49b8-493a-a71b-56cd1f4d6908", "46407397-a7d5-4c6b-92c6-dbe038b1457b", "4e9f0faf-2ecc-4ae8-a804-28c9a75d1ddc", "510f64c8-9bcc-4be1-8d30-638705850618", "897e3b53-5d4d-444b-85cb-2cdc8a97d903", "c867c42d-a52d-4a24-8ae3-f75d256b5618", - "e135df7c-7687-4ac0-a5f0-76b74438b53e", - "f7dfbef3-7697-431c-883a-db8583a4e4f9", - "6d72aad6-187a-4392-a4c4-ed87269c51cf", - "f918266a-b3e0-4914-865d-4faa564f1aef", - "da52d699-e8d2-4dc5-9191-a2199e0b6a9b", "74d5859f-ed66-4d3e-aa0e-93d7a592ce41", "b5062e3e-641c-4e3a-907b-ac864d2e7652", "48d05431-6cd5-4e76-82eb-12b60d823f7d", "eb303e01-261e-4972-8c07-c9b4e7a4922a", "d1acdb87-bb67-4f30-84aa-990e56a09c92", "deec51c9-3b1e-4b9e-993c-4776f20e8bb2", - "8e116af7-7db7-4e35-a68b-b0939c066c78" + "8e116af7-7db7-4e35-a68b-b0939c066c78", + "716a6079-22da-47f1-ba73-c9d58f986a38", + "2373b66a-092d-44cb-bfd7-82e86e7a3b4d" ], "os": [ "5ea617a3-0e86-4ba6-aab2-dac9aa2e8d57", @@ -53,6 +55,7 @@ ], "vs_code": [ "0ed39f63-6049-43d4-ba4d-5fa2fe04a951", - "53ad5833-3455-407b-bbc6-45b4c79ab8fb" + "53ad5833-3455-407b-bbc6-45b4c79ab8fb", + "276cc624-87ea-4f08-ab93-f770e3790175" ] } \ No newline at end of file From 5f2802292acae04f3e0b14578e0d578dc5b77166 Mon Sep 17 00:00:00 2001 From: Yiheng Xu Date: Fri, 22 Mar 2024 12:54:22 +0800 Subject: [PATCH 4/8] Update agent.py --- mm_agents/agent.py | 38 +++++++++++++++++++++----------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 4600628..4b27968 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -741,26 +741,30 @@ class PromptAgent: assert api_key is not None, "Please set the GENAI_API_KEY environment variable" genai.configure(api_key=api_key) logger.info("Generating content with Gemini model: %s", self.model) - response = genai.GenerativeModel(self.model).generate_content( - gemini_messages, - generation_config={ - "candidate_count": 1, - "max_output_tokens": max_tokens, - "top_p": top_p, - "temperature": temperature - }, - safety_settings={ - "harassment": "block_none", - "hate": "block_none", - "sex": "block_none", - "danger": "block_none" - } - ) - + request_options = {"timeout": 120} + gemini_model = genai.GenerativeModel(self.model) try: + response = gemini_model.generate_content( + gemini_messages, + generation_config={ + "candidate_count": 1, + "max_output_tokens": max_tokens, + "top_p": top_p, + "temperature": temperature + }, + safety_settings={ + "harassment": "block_none", + "hate": "block_none", + "sex": "block_none", + "danger": "block_none" + }, + request_options=request_options + ) return response.text except Exception as e: - logger.error("Meet exception when calling Gemini API, " + str(e)) + logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e)) + logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}") + logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}") return "" elif self.model.startswith("qwen"): messages = payload["messages"] From 0aab0e3745c8f7f2b29857d501fa0307e98f0ff8 Mon Sep 17 00:00:00 2001 From: Fangyu Lei <55661995+lfy79001@users.noreply.github.com> Date: Mon, 25 Mar 2024 15:11:25 +0800 Subject: [PATCH 5/8] Update requirements.txt --- desktop_env/server/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/desktop_env/server/requirements.txt b/desktop_env/server/requirements.txt index 546da44..1b5ecda 100644 --- a/desktop_env/server/requirements.txt +++ b/desktop_env/server/requirements.txt @@ -6,3 +6,4 @@ requests flask numpy lxml +pygame From 635b6717b34a2a33835e7599dccb177a45a31c9a Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Mon, 25 Mar 2024 17:55:28 +0800 Subject: [PATCH 6/8] Fix a key error in multiapps --- .../multi_apps/510f64c8-9bcc-4be1-8d30-638705850618.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation_examples/examples/multi_apps/510f64c8-9bcc-4be1-8d30-638705850618.json b/evaluation_examples/examples/multi_apps/510f64c8-9bcc-4be1-8d30-638705850618.json index de6e9f4..cb4a198 100644 --- a/evaluation_examples/examples/multi_apps/510f64c8-9bcc-4be1-8d30-638705850618.json +++ b/evaluation_examples/examples/multi_apps/510f64c8-9bcc-4be1-8d30-638705850618.json @@ -108,7 +108,7 @@ { "type": "rule", "rules": { - "expect": "project" + "expected": "project" } } ] From 172123ab2c229a57638d8c3b4b4397f5ce3f05fa Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Mon, 25 Mar 2024 18:02:48 +0800 Subject: [PATCH 7/8] Support downsampling; Fix bugs in windows a11y tree; Add a11y_tree trim --- desktop_env/envs/desktop_env.py | 8 +- desktop_env/server/main.py | 40 ++++++-- .../heuristic_retrieve.py | 6 +- mm_agents/agent.py | 93 ++++++++++++------- 4 files changed, 104 insertions(+), 43 deletions(-) diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index b443a4a..5fd972d 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -146,7 +146,13 @@ class DesktopEnv(gym.Env): image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no)) # Get the screenshot and save to the image_path - screenshot = self.controller.get_screenshot() + max_retries = 20 + for _ in range(max_retries): + screenshot = self.controller.get_screenshot() + if screenshot is not None: + break + time.sleep(1) + with open(image_path, "wb") as f: f.write(screenshot) diff --git a/desktop_env/server/main.py b/desktop_env/server/main.py index 8e900a3..cd6998c 100644 --- a/desktop_env/server/main.py +++ b/desktop_env/server/main.py @@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str # Value {{{ # if hasattr(node, "get_step"): - attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step()) + try: + attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step()) + except: + pass if hasattr(node, "value"): - attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value()) + try: + attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value()) + except: + pass if hasattr(node, "get_value"): - attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value()) + try: + attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value()) + except: + pass elif hasattr(node, "get_position"): - attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position()) + try: + attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position()) + except: + pass if hasattr(node, "min_value"): - attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value()) + try: + attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value()) + except: + pass elif hasattr(node, "get_range_min"): - attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min()) + try: + attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min()) + except: + pass if hasattr(node, "max_value"): - attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value()) + try: + attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value()) + except: + pass elif hasattr(node, "get_range_max"): - attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max()) + try: + attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max()) + except: + pass # }}} Value # attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node)) diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index e2845f3..5c7b830 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -110,6 +110,10 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam coords = tuple(map(int, coords_str.strip('()').split(', '))) size = tuple(map(int, size_str.strip('()').split(', '))) + import copy + original_coords = copy.deepcopy(coords) + original_size = copy.deepcopy(size) + if float(down_sampling_ratio) != 1.0: # Downsample the coordinates and size coords = tuple(int(coord * down_sampling_ratio) for coord in coords) @@ -145,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam draw.text(text_position, str(index), font=font, anchor="lb", fill="white") # each mark is an x, y, w, h tuple - marks.append([coords[0], coords[1], size[0], size[1]]) + marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]]) drew_nodes.append(_node) if _node.text: diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 4b27968..5ed3d27 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -8,12 +8,14 @@ import uuid import xml.etree.ElementTree as ET from http import HTTPStatus from io import BytesIO -from typing import Dict, List, Tuple, Union +from typing import Dict, List + import backoff import dashscope import google.generativeai as genai import openai import requests +import tiktoken from PIL import Image from google.api_core.exceptions import InvalidArgument @@ -32,49 +34,49 @@ def encode_image(image_path): return base64.b64encode(image_file.read()).decode('utf-8') -def linearize_accessibility_tree(accessibility_tree): +def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): # leaf_nodes = find_leaf_nodes(accessibility_tree) - filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree)) + filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform) linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"] # Linearize the accessibility tree nodes into a table format for node in filtered_nodes: - #linearized_accessibility_tree += node.tag + "\t" - #linearized_accessibility_tree += node.attrib.get('name') + "\t" + # linearized_accessibility_tree += node.tag + "\t" + # linearized_accessibility_tree += node.attrib.get('name') + "\t" if node.text: - text = ( node.text if '"' not in node.text\ - else '"{:}"'.format(node.text.replace('"', '""')) - ) + text = (node.text if '"' not in node.text \ + else '"{:}"'.format(node.text.replace('"', '""')) + ) elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \ and node.get("{uri:deskat:value.at-spi.gnome.org}value"): text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value") - text = (text if '"' not in text\ - else '"{:}"'.format(text.replace('"', '""')) - ) + text = (text if '"' not in text \ + else '"{:}"'.format(text.replace('"', '""')) + ) else: text = '""' - #linearized_accessibility_tree += node.attrib.get( - #, "") + "\t" - #linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n" + # linearized_accessibility_tree += node.attrib.get( + # , "") + "\t" + # linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n" linearized_accessibility_tree.append( - "{:}\t{:}\t{:}\t{:}\t{:}".format( - node.tag, node.get("name", ""), text - , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "") - , node.get('{uri:deskat:component.at-spi.gnome.org}size', "") - ) - ) + "{:}\t{:}\t{:}\t{:}\t{:}".format( + node.tag, node.get("name", ""), text + , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "") + , node.get('{uri:deskat:component.at-spi.gnome.org}size', "") + ) + ) return "\n".join(linearized_accessibility_tree) -def tag_screenshot(screenshot, accessibility_tree): +def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"): # Creat a tmp file to store the screenshot in random name uuid_str = str(uuid.uuid4()) os.makedirs("tmp/images", exist_ok=True) tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") # nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) - nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True) + nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True) # Make tag screenshot marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) @@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks): return actions +def trim_accessibility_tree(linearized_accessibility_tree, max_tokens): + enc = tiktoken.encoding_for_model("gpt-4") + tokens = enc.encode(linearized_accessibility_tree) + if len(tokens) > max_tokens: + linearized_accessibility_tree = enc.decode(tokens[:max_tokens]) + linearized_accessibility_tree += "[...]\n" + return linearized_accessibility_tree + class PromptAgent: def __init__( self, + platform="ubuntu", model="gpt-4-vision-preview", max_tokens=1500, top_p=0.9, @@ -180,8 +191,10 @@ class PromptAgent: action_space="computer_13", observation_type="screenshot_a11y_tree", # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"] - max_trajectory_length=3 + max_trajectory_length=3, + a11y_tree_max_tokens=10000 ): + self.platform = platform self.model = model self.max_tokens = max_tokens self.top_p = top_p @@ -189,6 +202,7 @@ class PromptAgent: self.action_space = action_space self.observation_type = observation_type self.max_trajectory_length = max_trajectory_length + self.a11y_tree_max_tokens = a11y_tree_max_tokens self.thoughts = [] self.actions = [] @@ -349,9 +363,14 @@ class PromptAgent: # {{{1 if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: base64_image = encode_image(obs["screenshot"]) - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"], + platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None logger.debug("LINEAR AT: %s", linearized_accessibility_tree) + if linearized_accessibility_tree: + linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree, + self.a11y_tree_max_tokens) + if self.observation_type == "screenshot_a11y_tree": self.observations.append({ "screenshot": base64_image, @@ -383,9 +402,14 @@ class PromptAgent: ] }) elif self.observation_type == "a11y_tree": - linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) + linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"], + platform=self.platform) logger.debug("LINEAR AT: %s", linearized_accessibility_tree) + if linearized_accessibility_tree: + linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree, + self.a11y_tree_max_tokens) + self.observations.append({ "screenshot": None, "accessibility_tree": linearized_accessibility_tree @@ -403,10 +427,15 @@ class PromptAgent: }) elif self.observation_type == "som": # Add som to the screenshot - masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"]) + masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[ + "accessibility_tree"], self.platform) base64_image = encode_image(tagged_screenshot) logger.debug("LINEAR AT: %s", linearized_accessibility_tree) + if linearized_accessibility_tree: + linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree, + self.a11y_tree_max_tokens) + self.observations.append({ "screenshot": base64_image, "accessibility_tree": linearized_accessibility_tree @@ -435,7 +464,7 @@ class PromptAgent: # with open("messages.json", "w") as f: # f.write(json.dumps(messages, indent=4)) - #logger.info("PROMPT: %s", messages) + # logger.info("PROMPT: %s", messages) response = self.call_llm({ "model": self.model, @@ -556,8 +585,6 @@ class PromptAgent: "Content-Type": "application/json" } - - payload = { "model": self.model, "max_tokens": max_tokens, @@ -570,7 +597,8 @@ class PromptAgent: attempt = 0 while attempt < max_attempts: # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload) - response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload) + response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, + json=payload) if response.status_code == 200: result = response.json()['choices'][0]['message']['content'] break @@ -581,7 +609,7 @@ class PromptAgent: else: print("Exceeded maximum attempts to call LLM.") result = "" - + return result @@ -605,14 +633,13 @@ class PromptAgent: mistral_messages.append(mistral_message) - from openai import OpenAI client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"], base_url='https://api.together.xyz', ) logger.info("Generating content with Mistral model: %s", self.model) - + flag = 0 while True: try: From 607cf8e55475cea5f299b75e6c346c6e6c96cbb2 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Mon, 25 Mar 2024 18:09:43 +0800 Subject: [PATCH 8/8] Fix max traj length --- mm_agents/agent.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index 5ed3d27..a4b1b3a 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -264,9 +264,14 @@ class PromptAgent: , "The number of observations and actions should be the same." if len(self.observations) > self.max_trajectory_length: - _observations = self.observations[-self.max_trajectory_length:] - _actions = self.actions[-self.max_trajectory_length:] - _thoughts = self.thoughts[-self.max_trajectory_length:] + if self.max_trajectory_length == 0: + _observations = [] + _actions = [] + _thoughts = [] + else: + _observations = self.observations[-self.max_trajectory_length:] + _actions = self.actions[-self.max_trajectory_length:] + _thoughts = self.thoughts[-self.max_trajectory_length:] else: _observations = self.observations _actions = self.actions