Files
sci-gui-agent-benchmark/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py
David Chang e95e8e55ea ver Mar11th
updated filter_nodes
2024-03-11 12:33:47 +08:00

168 lines
7.0 KiB
Python

import ast
import xml.etree.ElementTree as ET
from typing import Tuple

from PIL import Image, ImageDraw, ImageFont
def find_leaf_nodes(xlm_file_str):
    """Return every leaf element (an element with no children) of an XML document.

    Args:
        xlm_file_str: XML text to parse. The parameter keeps its historical
            spelling so keyword callers continue to work.

    Returns:
        List of leaf ``Element`` objects in document (pre-)order, or an empty
        list when ``xlm_file_str`` is empty or ``None``.

    Raises:
        xml.etree.ElementTree.ParseError: if the text is not well-formed XML.
    """
    if not xlm_file_str:
        return []
    root = ET.fromstring(xlm_file_str)
    # Element.iter() yields the root and then its descendants in the same
    # pre-order a recursive descent would. It walks iteratively, so a deeply
    # nested tree cannot exceed Python's recursion limit.
    # len(element) == 0 <=> the element has no children.
    return [element for element in root.iter() if len(element) == 0]
# Namespaces of the accessibility-tree attributes read below: state flags
# (showing / visible / enabled / ...) and geometry (screencoord / size).
state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org"

# Tag suffixes that make a node a candidate for keeping.
_KEEP_TAG_SUFFIXES = (
    "item", "button", "heading", "label", "scrollbar", "searchbox",
    "textbox", "link", "tabelement", "textfield", "textarea", "menu",
)
# Exact tag names that make a node a candidate for keeping.
_KEEP_TAG_NAMES = frozenset({
    "alert", "canvas", "check-box", "combo-box", "entry", "icon",
    "image", "paragraph", "scroll-bar", "section", "slider", "static",
    "table-cell", "terminal", "text", "netuiribbontab", "start",
    "trayclockwclass", "traydummysearchcontrol", "uiimage", "uiproperty",
    "uiribboncommandbar",
})
# At least one of these state flags must be "true" for a node to be kept.
_INTERACTION_STATES = ("enabled", "editable", "expandable", "checkable")
# Sentinel returned for missing or unparseable geometry.
_MISSING_PAIR = (-1, -1)


def _parse_int_pair(text):
    """Parse a ``"(a, b)"`` attribute value into a numeric 2-tuple.

    Uses ``ast.literal_eval``, so only Python literals are accepted and no code
    embedded in the accessibility tree is ever executed. Anything that is not
    a pair of numbers yields ``(-1, -1)``.
    """
    try:
        value = ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return _MISSING_PAIR
    if (
        isinstance(value, tuple)
        and len(value) == 2
        and all(isinstance(v, (int, float)) for v in value)
    ):
        return value
    return _MISSING_PAIR


def judge_node(node: ET.Element, platform="ubuntu") -> bool:
    """Decide whether an accessibility-tree node should be kept.

    A node is kept only when all of the following hold:

    * its tag starts with ``"document"``, ends with one of
      ``_KEEP_TAG_SUFFIXES``, or is one of ``_KEEP_TAG_NAMES``;
    * its visibility states are "true": ``showing`` and ``visible`` on
      ``"ubuntu"``, only ``visible`` on ``"windows"``. Any other platform
      keeps nothing;
    * at least one of enabled / editable / expandable / checkable is "true";
    * it has a non-empty ``name`` attribute or non-empty text;
    * its ``screencoord`` is non-negative and its ``size`` strictly positive.

    Args:
        node: Element from the parsed accessibility tree.
        platform: ``"ubuntu"`` or ``"windows"``.

    Returns:
        True if the node should be kept, False otherwise.
    """
    tag = node.tag
    if not isinstance(tag, str):
        # Non-element nodes (e.g. comments) have a non-string tag.
        return False
    if not (
        tag.startswith("document")
        or tag.endswith(_KEEP_TAG_SUFFIXES)
        or tag in _KEEP_TAG_NAMES
    ):
        return False

    def state(flag):
        # Missing state attributes count as "false".
        return node.get(f"{{{state_ns}}}{flag}", "false") == "true"

    if platform == "ubuntu":
        on_screen = state("showing") and state("visible")
    elif platform == "windows":
        on_screen = state("visible")
    else:
        on_screen = False
    if not on_screen:
        return False
    if not any(state(flag) for flag in _INTERACTION_STATES):
        return False
    if node.get("name", "") == "" and not node.text:
        return False

    x, y = _parse_int_pair(node.get(f"{{{component_ns}}}screencoord", "(-1, -1)"))
    width, height = _parse_int_pair(node.get(f"{{{component_ns}}}size", "(-1, -1)"))
    # 0 is a legitimate coordinate (a widget on the top or left edge).
    # Negative values, including the (-1, -1) default for a missing
    # attribute, are rejected.
    return x >= 0 and y >= 0 and width > 0 and height > 0
def filter_nodes(root: ET.Element, platform="ubuntu"):
    """Collect every element in the tree rooted at *root* that ``judge_node`` keeps.

    Args:
        root: Root element of the parsed accessibility tree. The root itself
            is also tested.
        platform: Passed through unchanged to ``judge_node``.

    Returns:
        List of kept elements, in the document order produced by ``root.iter()``.
    """
    return [element for element in root.iter() if judge_node(element, platform)]
def _parse_box_pair(text):
    """Parse an ``"(a, b)"`` attribute value into a 2-tuple of ints.

    Whitespace after the comma is optional: ``"(3, 4)"`` and ``"(3,4)"`` are
    both accepted.

    Raises:
        ValueError: if *text* is not exactly two comma-separated integers.
    """
    values = tuple(int(part) for part in text.strip().strip("()").split(","))
    if len(values) != 2:
        raise ValueError(f"Expected two integers, got: {text!r}")
    return values


def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
    """Draw a numbered red box around each drawable node on a screenshot.

    A node is drawn when it carries parseable ``screencoord`` and ``size``
    attributes (component namespace) with a strictly positive size, and its
    region is not a single flat colour in the original screenshot. Each drawn
    node gets a red outline and a white-on-black index label (1, 2, ...)
    anchored at the box's bottom-left corner. Nodes with missing or malformed
    geometry are skipped.

    Args:
        nodes: Iterable of accessibility-tree elements.
        image_file_path: Path of the screenshot to annotate.
        output_image_file_path: Path where the annotated copy is saved.

    Returns:
        ``(marks, drew_nodes)``: ``marks[i]`` is ``[x, y, w, h]`` for the node
        labelled ``i + 1``, and ``drew_nodes[i]`` is that node.
    """
    coord_key = f"{{{component_ns}}}screencoord"
    size_key = f"{{{component_ns}}}size"
    marks = []
    drew_nodes = []

    try:
        # Prefer a TrueType font; fall back to Pillow's default font when
        # arial.ttf cannot be loaded.
        font = ImageFont.truetype("arial.ttf", 15)
    except OSError:
        font = ImageFont.load_default()

    # Copy the pixels out so the source file handle is released immediately.
    with Image.open(image_file_path) as screenshot:
        image = screenshot.copy()
    # Untouched reference for the flat-colour test. Outlines and labels drawn
    # for earlier nodes must not make a blank region look non-uniform.
    reference = image.copy()
    draw = ImageDraw.Draw(image)

    for node in nodes:
        coords_str = node.attrib.get(coord_key)
        size_str = node.attrib.get(size_key)
        if not (coords_str and size_str):
            continue
        try:
            coords = _parse_box_pair(coords_str)
            size = _parse_box_pair(size_str)
        except ValueError:
            # Malformed geometry: skip the node (deliberate best-effort).
            continue
        if size[0] <= 0 or size[1] <= 0:
            continue
        bottom_right = (coords[0] + size[0], coords[1] + size[1])

        # Skip regions that contain a single colour in the screenshot;
        # getcolors(maxcolors=1) returns None once a second colour appears.
        region = reference.crop((*coords, *bottom_right))
        if region.getcolors(maxcolors=1) is not None:
            continue

        label = str(len(marks) + 1)
        text_position = (coords[0], bottom_right[1])
        try:
            # Measure the label before drawing anything. If Pillow cannot
            # measure it with this font/anchor (it may raise ValueError), the
            # node is skipped without leaving an unlabeled outline behind.
            text_bbox = draw.textbbox(text_position, label, font=font, anchor="lb")
        except ValueError:
            continue

        draw.rectangle([coords, bottom_right], outline="red", width=1)
        draw.rectangle(text_bbox, fill="black")
        draw.text(text_position, label, font=font, anchor="lb", fill="white")
        # Each mark is an [x, y, w, h] list, aligned with its printed label.
        marks.append([coords[0], coords[1], size[0], size[1]])
        drew_nodes.append(node)

    image.save(output_image_file_path)
    return marks, drew_nodes
def print_nodes_with_indent(nodes, indent=0):
    """Print each node's tag and attributes, followed by its subtree, as an outline.

    Args:
        nodes: Iterable of elements. An ``Element`` iterates over its children.
        indent: Number of spaces used for the top level. Each nested level
            adds two more.
    """
    exhausted = object()
    pending = [(iter(nodes), indent)]
    while pending:
        siblings, depth = pending[-1]
        current = next(siblings, exhausted)
        if current is exhausted:
            pending.pop()
            continue
        print(' ' * depth, current.tag, current.attrib)
        # Descend before moving to the next sibling, so children appear
        # directly beneath their parent.
        pending.append((iter(current), depth + 2))
if __name__ == '__main__':
    import json

    # Demo: filter the accessibility tree stored under "AT" in 4.json and
    # annotate the screenshot 4.png, writing the result to 4.a.png.
    with open('4.json', 'r', encoding='utf-8') as f:
        xml_file_str = json.load(f)["AT"]
    filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
    print(len(filtered_nodes))
    # draw_bounding_boxes returns (marks, drew_nodes). Unpack it so the
    # printed count is the number of boxes drawn, not len() of the 2-tuple.
    marks, drew_nodes = draw_bounding_boxes(filtered_nodes, '4.png', '4.a.png')
    print(len(marks))