Modify the namespace of a11y tree (#62)

This commit is contained in:
Tianbao Xie
2024-07-25 20:20:34 +08:00
committed by GitHub
parent 1fd8b66fde
commit a156f8a3d6
3 changed files with 118 additions and 67 deletions

View File

@@ -24,9 +24,21 @@ logger.setLevel(logging.INFO)
MAX_RETRY_TIMES = 10 MAX_RETRY_TIMES = 10
RETRY_INTERVAL = 5 RETRY_INTERVAL = 5
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip" UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip"
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip" UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
DOWNLOADED_FILE_NAME = "Ubuntu.zip"
# Determine the platform and CPU architecture to decide the correct VM image to download
# NOTE(review): the Apple Silicon check below is commented out, so every macOS
# host is given the ARM image regardless of its CPU -- confirm this is intended.
if platform.system() == 'Darwin': # macOS
# if os.uname().machine == 'arm64': # Apple Silicon
URL = UBUNTU_ARM_URL
# else:
# url = UBUNTU_X86_URL
# Non-macOS hosts only get an image when the machine reports amd64/x86_64.
elif platform.machine().lower() in ['amd64', 'x86_64']:
URL = UBUNTU_X86_URL
else:
# Every other system/architecture combination is rejected outright.
raise Exception("Unsupported platform or architecture")
# The downloaded archive is named after the last path segment of the chosen URL.
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
REGISTRY_PATH = '.vmware_vms' REGISTRY_PATH = '.vmware_vms'
LOCK_FILE_NAME = '.vmware_lck' LOCK_FILE_NAME = '.vmware_lck'
VMS_DIR = "./vmware_vm_data" VMS_DIR = "./vmware_vm_data"
@@ -109,17 +121,6 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
os.makedirs(vms_dir, exist_ok=True) os.makedirs(vms_dir, exist_ok=True)
def __download_and_unzip_vm(): def __download_and_unzip_vm():
# Determine the platform and CPU architecture to decide the correct VM image to download
if platform.system() == 'Darwin': # macOS
# if os.uname().machine == 'arm64': # Apple Silicon
url = UBUNTU_ARM_URL
# else:
# url = UBUNTU_X86_URL
elif platform.machine().lower() in ['amd64', 'x86_64']:
url = UBUNTU_X86_URL
else:
raise Exception("Unsupported platform or architecture")
# Download the virtual machine image # Download the virtual machine image
logger.info("Downloading the virtual machine image...") logger.info("Downloading the virtual machine image...")
downloaded_size = 0 downloaded_size = 0
@@ -131,7 +132,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
downloaded_size = os.path.getsize(downloaded_file_path) downloaded_size = os.path.getsize(downloaded_file_path)
headers["Range"] = f"bytes={downloaded_size}-" headers["Range"] = f"bytes={downloaded_size}-"
with requests.get(url, headers=headers, stream=True) as response: with requests.get(URL, headers=headers, stream=True) as response:
if response.status_code == 416: if response.status_code == 416:
# This means the range was not satisfiable, possibly the file was fully downloaded # This means the range was not satisfiable, possibly the file was fully downloaded
logger.info("Fully downloaded or the file size changed.") logger.info("Fully downloaded or the file size changed.")

View File

@@ -26,11 +26,25 @@ def find_leaf_nodes(xlm_file_str):
return leaf_nodes return leaf_nodes
state_ns = "uri:deskat:state.at-spi.gnome.org" state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
component_ns = "uri:deskat:component.at-spi.gnome.org" state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
if platform == "ubuntu":
_state_ns = state_ns_ubuntu
_component_ns = component_ns_ubuntu
elif platform == "windows":
_state_ns = state_ns_windows
_component_ns = component_ns_windows
else:
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
keeps: bool = node.tag.startswith("document") \ keeps: bool = node.tag.startswith("document") \
or node.tag.endswith("item") \ or node.tag.endswith("item") \
or node.tag.endswith("button") \ or node.tag.endswith("button") \
@@ -53,23 +67,26 @@ def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
, "traydummysearchcontrol", "uiimage", "uiproperty" , "traydummysearchcontrol", "uiimage", "uiproperty"
, "uiribboncommandbar" , "uiribboncommandbar"
} }
keeps = keeps and (platform == "ubuntu" \ keeps = keeps and (
and node.get("{{{:}}}showing".format(state_ns), "false") == "true" \ platform == "ubuntu"
and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \ and node.get("{{{:}}}showing".format(_state_ns), "false") == "true"
or platform == "windows" \ and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \ or platform == "windows"
) \ and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
and (node.get("{{{:}}}enabled".format(state_ns), "false") == "true" \ ) \
or node.get("{{{:}}}editable".format(state_ns), "false") == "true" \ and (
or node.get("{{{:}}}expandable".format(state_ns), "false") == "true" \ node.get("{{{:}}}enabled".format(_state_ns), "false") == "true"
or node.get("{{{:}}}checkable".format(state_ns), "false") == "true" or node.get("{{{:}}}editable".format(_state_ns), "false") == "true"
) \ or node.get("{{{:}}}expandable".format(_state_ns), "false") == "true"
and (node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \ or node.get("{{{:}}}checkable".format(_state_ns), "false") == "true"
or check_image and node.get("image", "false") == "true" ) \
) and (
node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \
or check_image and node.get("image", "false") == "true"
)
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)")) coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)"))
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)")) sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)"))
keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0 keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
return keeps return keeps
@@ -85,7 +102,19 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
return filtered_nodes return filtered_nodes
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0): def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="ubuntu"):
if platform == "ubuntu":
_state_ns = state_ns_ubuntu
_component_ns = component_ns_ubuntu
_value_ns = value_ns_ubuntu
elif platform == "windows":
_state_ns = state_ns_windows
_component_ns = component_ns_windows
_value_ns = value_ns_windows
else:
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
# Load the screenshot image # Load the screenshot image
image_stream = io.BytesIO(image_file_content) image_stream = io.BytesIO(image_file_content)
image = Image.open(image_stream) image = Image.open(image_stream)
@@ -107,8 +136,8 @@ def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
# Loop over all the visible nodes and draw their bounding boxes # Loop over all the visible nodes and draw their bounding boxes
for _node in nodes: for _node in nodes:
coords_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord') coords_str = _node.attrib.get('{{{:}}}screencoord'.format(_component_ns))
size_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') size_str = _node.attrib.get('{{{:}}}size'.format(_component_ns))
if coords_str and size_str: if coords_str and size_str:
try: try:
@@ -162,19 +191,15 @@ def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
node_text = (_node.text if '"' not in _node.text \ node_text = (_node.text if '"' not in _node.text \
else '"{:}"'.format(_node.text.replace('"', '""')) else '"{:}"'.format(_node.text.replace('"', '""'))
) )
elif _node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \ elif _node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") \
and _node.get("{uri:deskat:value.at-spi.gnome.org}value"): and _node.get("{{{:}}}value".format(_value_ns)):
node_text: str = _node.get("{uri:deskat:value.at-spi.gnome.org}value") node_text = _node.get("{{{:}}}value".format(_value_ns), "")
node_text = (node_text if '"' not in node_text \ node_text = (node_text if '"' not in node_text \
else '"{:}"'.format(node_text.replace('"', '""')) else '"{:}"'.format(node_text.replace('"', '""'))
) )
else: else:
node_text = '""' node_text = '""'
text_information: str = "{:d}\t{:}\t{:}\t{:}" \ text_information: str = "{:d}\t{:}\t{:}\t{:}".format(index, _node.tag, _node.get("name", ""), node_text)
.format(index, _node.tag
, _node.get("name", "")
, node_text
)
text_informations.append(text_information) text_informations.append(text_information)
index += 1 index += 1

View File

@@ -15,11 +15,11 @@ import dashscope
import google.generativeai as genai import google.generativeai as genai
import openai import openai
import requests import requests
from requests.exceptions import SSLError
import tiktoken import tiktoken
from PIL import Image from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from groq import Groq from groq import Groq
from requests.exceptions import SSLError
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \ from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
@@ -31,6 +31,17 @@ logger = logging.getLogger("desktopenv.agent")
pure_text_settings = ['a11y_tree'] pure_text_settings = ['a11y_tree']
# XML namespace URIs for accessibility-tree attributes, one set per platform.
# They are interpolated into "{namespace}attribute" keys when node attributes
# are read (e.g. "{<component_ns>}screencoord", "{<attributes_ns>}description").
#
# NOTE(review): attributes_ns_ubuntu previously held the *windows* URI, making it
# identical to attributes_ns_windows. It now follows the host used by every other
# *_ubuntu constant -- confirm against desktop_env/server/main.py.
attributes_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
# The "class" namespace is only defined for Windows here.
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces are defined in OSWorld; see desktop_env/server/main.py
# Function to encode the image # Function to encode the image
def encode_image(image_content): def encode_image(image_content):
@@ -57,35 +68,48 @@ def save_to_tmp_img_file(data_str):
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
# leaf_nodes = find_leaf_nodes(accessibility_tree)
if platform == "ubuntu":
_attributes_ns = attributes_ns_ubuntu
_state_ns = state_ns_ubuntu
_component_ns = component_ns_ubuntu
_value_ns = value_ns_ubuntu
elif platform == "windows":
_attributes_ns = attributes_ns_windows
_state_ns = state_ns_windows
_component_ns = component_ns_windows
_value_ns = value_ns_windows
else:
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform) filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
linearized_accessibility_tree = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"]
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
# Linearize the accessibility tree nodes into a table format # Linearize the accessibility tree nodes into a table format
for node in filtered_nodes: for node in filtered_nodes:
# linearized_accessibility_tree += node.tag + "\t"
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
if node.text: if node.text:
text = (node.text if '"' not in node.text \ text = (
else '"{:}"'.format(node.text.replace('"', '""')) node.text if '"' not in node.text \
) else '"{:}"'.format(node.text.replace('"', '""'))
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \ )
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value") elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") \
text = (text if '"' not in text \ and node.get("{{{:}}}value".format(_value_ns)):
else '"{:}"'.format(text.replace('"', '""')) node_text = node.get("{{{:}}}value".format(_value_ns), "")
text = (node_text if '"' not in node_text \
else '"{:}"'.format(node_text.replace('"', '""'))
) )
else: else:
text = '""' text = '""'
# linearized_accessibility_tree += node.attrib.get(
# , "") + "\t"
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
linearized_accessibility_tree.append( linearized_accessibility_tree.append(
"{:}\t{:}\t{:}\t{:}\t{:}".format( "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
node.tag, node.get("name", ""), text node.tag, node.get("name", ""),
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "") text,
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "") node.get("{{{:}}}class".format(_attributes_ns), "") if platform == "ubuntu" else node.get("{{{:}}}class".format(class_ns_windows), ""),
node.get("{{{:}}}description".format(_attributes_ns), ""),
node.get('{{{:}}}screencoord'.format(_component_ns), ""),
node.get('{{{:}}}size'.format(_component_ns), "")
) )
) )
@@ -957,8 +981,9 @@ class PromptAgent:
} }
assert len(message["content"]) in [1, 2], "One text, or one text with one image" assert len(message["content"]) in [1, 2], "One text, or one text with one image"
for part in message["content"]: for part in message["content"]:
qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[ qwen_message['content'].append(
'type'] == "image_url" else None {"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
'type'] == "image_url" else None
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
qwen_messages.append(qwen_message) qwen_messages.append(qwen_message)