Modify the namespace of a11y tree (#62)
This commit is contained in:
@@ -15,11 +15,11 @@ import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
from requests.exceptions import SSLError
|
||||
import tiktoken
|
||||
from PIL import Image
|
||||
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
|
||||
from groq import Groq
|
||||
from requests.exceptions import SSLError
|
||||
|
||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||
@@ -31,6 +31,17 @@ logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
pure_text_settings = ['a11y_tree']
|
||||
|
||||
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
||||
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
|
||||
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
|
||||
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
|
||||
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
|
||||
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
|
||||
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
|
||||
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
|
||||
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
||||
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_content):
|
||||
@@ -57,35 +68,48 @@ def save_to_tmp_img_file(data_str):
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
|
||||
if platform == "ubuntu":
|
||||
_attributes_ns = attributes_ns_ubuntu
|
||||
_state_ns = state_ns_ubuntu
|
||||
_component_ns = component_ns_ubuntu
|
||||
_value_ns = value_ns_ubuntu
|
||||
elif platform == "windows":
|
||||
_attributes_ns = attributes_ns_windows
|
||||
_state_ns = state_ns_windows
|
||||
_component_ns = component_ns_windows
|
||||
_value_ns = value_ns_windows
|
||||
else:
|
||||
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
|
||||
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"]
|
||||
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
# linearized_accessibility_tree += node.tag + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
if node.text:
|
||||
text = (node.text if '"' not in node.text \
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
text = (text if '"' not in text \
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
text = (
|
||||
node.text if '"' not in node.text \
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
|
||||
elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") \
|
||||
and node.get("{{{:}}}value".format(_value_ns)):
|
||||
node_text = node.get("{{{:}}}value".format(_value_ns), "")
|
||||
text = (node_text if '"' not in node_text \
|
||||
else '"{:}"'.format(node_text.replace('"', '""'))
|
||||
)
|
||||
else:
|
||||
text = '""'
|
||||
# linearized_accessibility_tree += node.attrib.get(
|
||||
# , "") + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
|
||||
linearized_accessibility_tree.append(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""),
|
||||
text,
|
||||
node.get("{{{:}}}class".format(_attributes_ns), "") if platform == "ubuntu" else node.get("{{{:}}}class".format(class_ns_windows), ""),
|
||||
node.get("{{{:}}}description".format(_attributes_ns), ""),
|
||||
node.get('{{{:}}}screencoord'.format(_component_ns), ""),
|
||||
node.get('{{{:}}}size'.format(_component_ns), "")
|
||||
)
|
||||
)
|
||||
|
||||
@@ -957,8 +981,9 @@ class PromptAgent:
|
||||
}
|
||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||
for part in message["content"]:
|
||||
qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
|
||||
'type'] == "image_url" else None
|
||||
qwen_message['content'].append(
|
||||
{"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
|
||||
'type'] == "image_url" else None
|
||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||
|
||||
qwen_messages.append(qwen_message)
|
||||
|
||||
Reference in New Issue
Block a user