# XML namespaces under which the accessibility tree stores per-node state
# flags (showing/visible/enabled/...) and geometry (screencoord/size).
state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org"

# A node is a candidate when its tag starts with "document", ends with one of
# these suffixes, or is exactly one of the tags in _KEPT_TAGS.
_KEPT_TAG_SUFFIXES = (
    "item", "button", "heading", "label", "scrollbar", "searchbox",
    "textbox", "link", "tabelement", "textfield", "textarea", "menu",
)
_KEPT_TAGS = frozenset({
    "alert", "canvas", "check-box",
    "combo-box", "entry", "icon",
    "image", "paragraph", "scroll-bar",
    "section", "slider", "static",
    "table-cell", "terminal", "text",
    "netuiribbontab", "start", "trayclockwclass",
    "traydummysearchcontrol", "uiimage", "uiproperty",
    "uiribboncommandbar",
})

# At least one of these state flags must be "true" for a node to be kept.
_INTERACTIVE_STATES = ("enabled", "editable", "expandable", "checkable")

# Geometry used when an attribute is absent or cannot be parsed; it fails the
# "strictly positive" test below, so such nodes are rejected.
_INVALID_PAIR = (-1, -1)


def _has_state(node, state) -> bool:
    """Return True if the ``state_ns`` attribute *state* of *node* is "true"."""
    return node.get("{{{}}}{}".format(state_ns, state), "false") == "true"


def _parse_pair(raw) -> tuple:
    """Parse a ``"(a, b)"`` attribute string into a 2-tuple of numbers.

    The string comes from the accessibility tree, i.e. from whatever
    application is on screen, so it must not be handed to ``eval``.
    ``ast.literal_eval`` accepts the same tuple literals without executing
    code. Anything that is not a sequence of at least two numbers yields
    ``(-1, -1)``.
    """
    import ast  # local import: keeps this block self-contained

    try:
        value = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        return _INVALID_PAIR
    if not isinstance(value, (tuple, list)) or len(value) < 2:
        return _INVALID_PAIR
    first, second = value[0], value[1]
    if not isinstance(first, (int, float)) or not isinstance(second, (int, float)):
        return _INVALID_PAIR
    return first, second


def judge_node(node: ET.Element, platform="ubuntu") -> bool:
    """Decide whether *node* should be kept in the filtered accessibility tree.

    A node is kept only if all of the following hold:

    * its tag is one of the recognised kinds (see ``_KEPT_TAG_SUFFIXES`` and
      ``_KEPT_TAGS``, or any tag starting with ``"document"``);
    * it is on screen: ``showing`` and ``visible`` for ``platform="ubuntu"``,
      ``visible`` alone for ``platform="windows"``; any other platform value
      rejects every node;
    * at least one of ``enabled`` / ``editable`` / ``expandable`` /
      ``checkable`` is ``"true"``;
    * it has a non-empty ``name`` attribute or non-empty text;
    * both ``screencoord`` components and both ``size`` components are
      strictly greater than zero.

    Args:
        node: element of the parsed accessibility tree.
        platform: ``"ubuntu"`` or ``"windows"``.

    Returns:
        True if the node passes every check, False otherwise. Malformed or
        missing geometry attributes reject the node instead of raising.
    """
    tag = node.tag
    if not isinstance(tag, str):
        return False
    if not (tag.startswith("document")
            or tag.endswith(_KEPT_TAG_SUFFIXES)
            or tag in _KEPT_TAGS):
        return False

    if platform == "ubuntu":
        on_screen = _has_state(node, "showing") and _has_state(node, "visible")
    elif platform == "windows":
        on_screen = _has_state(node, "visible")
    else:
        on_screen = False
    if not on_screen:
        return False

    if not any(_has_state(node, state) for state in _INTERACTIVE_STATES):
        return False

    # Unlabelled nodes carry nothing useful for the linearised tree.
    if not (node.get("name", "") or node.text):
        return False

    # Geometry is parsed last so rejected nodes never pay for it.
    # NOTE(review): the strict ">" also drops nodes whose x or y is exactly 0;
    # kept as in the original code - confirm this is intended.
    x, y = _parse_pair(node.get("{{{}}}screencoord".format(component_ns), "(-1, -1)"))
    width, height = _parse_pair(node.get("{{{}}}size".format(component_ns), "(-1, -1)"))
    return x > 0 and y > 0 and width > 0 and height > 0


def filter_nodes(root: ET.Element, platform="ubuntu"):
    """Return every node under *root* (inclusive) accepted by ``judge_node``.

    The whole tree is walked in document order via ``root.iter()``, so callers
    pass the parsed root element, e.g.
    ``filter_nodes(ET.fromstring(xml_string))``, rather than a pre-computed
    list of leaf nodes.

    Args:
        root: root element of the parsed accessibility tree.
        platform: forwarded to ``judge_node`` (``"ubuntu"`` or ``"windows"``).

    Returns:
        List of the accepted elements, in document order.
    """
    return [node for node in root.iter() if judge_node(node, platform)]