Modify the namespace of a11y tree (#62)
This commit is contained in:
@@ -24,9 +24,21 @@ logger.setLevel(logging.INFO)
|
|||||||
|
|
||||||
MAX_RETRY_TIMES = 10
|
MAX_RETRY_TIMES = 10
|
||||||
RETRY_INTERVAL = 5
|
RETRY_INTERVAL = 5
|
||||||
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_arm/resolve/main/Ubuntu.zip"
|
UBUNTU_ARM_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-arm.zip"
|
||||||
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_x86/resolve/main/Ubuntu.zip"
|
UBUNTU_X86_URL = "https://huggingface.co/datasets/xlangai/ubuntu_osworld/resolve/main/Ubuntu-x86.zip"
|
||||||
DOWNLOADED_FILE_NAME = "Ubuntu.zip"
|
|
||||||
|
# Determine the platform and CPU architecture to decide the correct VM image to download
|
||||||
|
if platform.system() == 'Darwin': # macOS
|
||||||
|
# if os.uname().machine == 'arm64': # Apple Silicon
|
||||||
|
URL = UBUNTU_ARM_URL
|
||||||
|
# else:
|
||||||
|
# url = UBUNTU_X86_URL
|
||||||
|
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
||||||
|
URL = UBUNTU_X86_URL
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported platform or architecture")
|
||||||
|
|
||||||
|
DOWNLOADED_FILE_NAME = URL.split('/')[-1]
|
||||||
REGISTRY_PATH = '.vmware_vms'
|
REGISTRY_PATH = '.vmware_vms'
|
||||||
LOCK_FILE_NAME = '.vmware_lck'
|
LOCK_FILE_NAME = '.vmware_lck'
|
||||||
VMS_DIR = "./vmware_vm_data"
|
VMS_DIR = "./vmware_vm_data"
|
||||||
@@ -109,17 +121,6 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
|
|||||||
os.makedirs(vms_dir, exist_ok=True)
|
os.makedirs(vms_dir, exist_ok=True)
|
||||||
|
|
||||||
def __download_and_unzip_vm():
|
def __download_and_unzip_vm():
|
||||||
# Determine the platform and CPU architecture to decide the correct VM image to download
|
|
||||||
if platform.system() == 'Darwin': # macOS
|
|
||||||
# if os.uname().machine == 'arm64': # Apple Silicon
|
|
||||||
url = UBUNTU_ARM_URL
|
|
||||||
# else:
|
|
||||||
# url = UBUNTU_X86_URL
|
|
||||||
elif platform.machine().lower() in ['amd64', 'x86_64']:
|
|
||||||
url = UBUNTU_X86_URL
|
|
||||||
else:
|
|
||||||
raise Exception("Unsupported platform or architecture")
|
|
||||||
|
|
||||||
# Download the virtual machine image
|
# Download the virtual machine image
|
||||||
logger.info("Downloading the virtual machine image...")
|
logger.info("Downloading the virtual machine image...")
|
||||||
downloaded_size = 0
|
downloaded_size = 0
|
||||||
@@ -131,7 +132,7 @@ def _install_vm(vm_name, vms_dir, downloaded_file_name, original_vm_name="Ubuntu
|
|||||||
downloaded_size = os.path.getsize(downloaded_file_path)
|
downloaded_size = os.path.getsize(downloaded_file_path)
|
||||||
headers["Range"] = f"bytes={downloaded_size}-"
|
headers["Range"] = f"bytes={downloaded_size}-"
|
||||||
|
|
||||||
with requests.get(url, headers=headers, stream=True) as response:
|
with requests.get(URL, headers=headers, stream=True) as response:
|
||||||
if response.status_code == 416:
|
if response.status_code == 416:
|
||||||
# This means the range was not satisfiable, possibly the file was fully downloaded
|
# This means the range was not satisfiable, possibly the file was fully downloaded
|
||||||
logger.info("Fully downloaded or the file size changed.")
|
logger.info("Fully downloaded or the file size changed.")
|
||||||
|
|||||||
@@ -26,11 +26,25 @@ def find_leaf_nodes(xlm_file_str):
|
|||||||
return leaf_nodes
|
return leaf_nodes
|
||||||
|
|
||||||
|
|
||||||
state_ns = "uri:deskat:state.at-spi.gnome.org"
|
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
|
||||||
component_ns = "uri:deskat:component.at-spi.gnome.org"
|
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
|
||||||
|
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
|
||||||
|
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
|
||||||
|
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
|
||||||
|
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
|
||||||
|
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
||||||
|
|
||||||
|
|
||||||
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
||||||
|
if platform == "ubuntu":
|
||||||
|
_state_ns = state_ns_ubuntu
|
||||||
|
_component_ns = component_ns_ubuntu
|
||||||
|
elif platform == "windows":
|
||||||
|
_state_ns = state_ns_windows
|
||||||
|
_component_ns = component_ns_windows
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
|
||||||
|
|
||||||
keeps: bool = node.tag.startswith("document") \
|
keeps: bool = node.tag.startswith("document") \
|
||||||
or node.tag.endswith("item") \
|
or node.tag.endswith("item") \
|
||||||
or node.tag.endswith("button") \
|
or node.tag.endswith("button") \
|
||||||
@@ -53,23 +67,26 @@ def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool:
|
|||||||
, "traydummysearchcontrol", "uiimage", "uiproperty"
|
, "traydummysearchcontrol", "uiimage", "uiproperty"
|
||||||
, "uiribboncommandbar"
|
, "uiribboncommandbar"
|
||||||
}
|
}
|
||||||
keeps = keeps and (platform == "ubuntu" \
|
keeps = keeps and (
|
||||||
and node.get("{{{:}}}showing".format(state_ns), "false") == "true" \
|
platform == "ubuntu"
|
||||||
and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \
|
and node.get("{{{:}}}showing".format(_state_ns), "false") == "true"
|
||||||
or platform == "windows" \
|
and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
|
||||||
and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \
|
or platform == "windows"
|
||||||
) \
|
and node.get("{{{:}}}visible".format(_state_ns), "false") == "true"
|
||||||
and (node.get("{{{:}}}enabled".format(state_ns), "false") == "true" \
|
) \
|
||||||
or node.get("{{{:}}}editable".format(state_ns), "false") == "true" \
|
and (
|
||||||
or node.get("{{{:}}}expandable".format(state_ns), "false") == "true" \
|
node.get("{{{:}}}enabled".format(_state_ns), "false") == "true"
|
||||||
or node.get("{{{:}}}checkable".format(state_ns), "false") == "true"
|
or node.get("{{{:}}}editable".format(_state_ns), "false") == "true"
|
||||||
) \
|
or node.get("{{{:}}}expandable".format(_state_ns), "false") == "true"
|
||||||
and (node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \
|
or node.get("{{{:}}}checkable".format(_state_ns), "false") == "true"
|
||||||
or check_image and node.get("image", "false") == "true"
|
) \
|
||||||
)
|
and (
|
||||||
|
node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \
|
||||||
|
or check_image and node.get("image", "false") == "true"
|
||||||
|
)
|
||||||
|
|
||||||
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
|
coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(_component_ns), "(-1, -1)"))
|
||||||
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
|
sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(_component_ns), "(-1, -1)"))
|
||||||
keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
|
keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0
|
||||||
return keeps
|
return keeps
|
||||||
|
|
||||||
@@ -85,7 +102,19 @@ def filter_nodes(root: ET, platform="ubuntu", check_image=False):
|
|||||||
return filtered_nodes
|
return filtered_nodes
|
||||||
|
|
||||||
|
|
||||||
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
|
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0, platform="ubuntu"):
|
||||||
|
|
||||||
|
if platform == "ubuntu":
|
||||||
|
_state_ns = state_ns_ubuntu
|
||||||
|
_component_ns = component_ns_ubuntu
|
||||||
|
_value_ns = value_ns_ubuntu
|
||||||
|
elif platform == "windows":
|
||||||
|
_state_ns = state_ns_windows
|
||||||
|
_component_ns = component_ns_windows
|
||||||
|
_value_ns = value_ns_windows
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
|
||||||
|
|
||||||
# Load the screenshot image
|
# Load the screenshot image
|
||||||
image_stream = io.BytesIO(image_file_content)
|
image_stream = io.BytesIO(image_file_content)
|
||||||
image = Image.open(image_stream)
|
image = Image.open(image_stream)
|
||||||
@@ -107,8 +136,8 @@ def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
|
|||||||
|
|
||||||
# Loop over all the visible nodes and draw their bounding boxes
|
# Loop over all the visible nodes and draw their bounding boxes
|
||||||
for _node in nodes:
|
for _node in nodes:
|
||||||
coords_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord')
|
coords_str = _node.attrib.get('{{{:}}}screencoord'.format(_component_ns))
|
||||||
size_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size')
|
size_str = _node.attrib.get('{{{:}}}size'.format(_component_ns))
|
||||||
|
|
||||||
if coords_str and size_str:
|
if coords_str and size_str:
|
||||||
try:
|
try:
|
||||||
@@ -162,19 +191,15 @@ def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
|
|||||||
node_text = (_node.text if '"' not in _node.text \
|
node_text = (_node.text if '"' not in _node.text \
|
||||||
else '"{:}"'.format(_node.text.replace('"', '""'))
|
else '"{:}"'.format(_node.text.replace('"', '""'))
|
||||||
)
|
)
|
||||||
elif _node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
elif _node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") \
|
||||||
and _node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
and _node.get("{{{:}}}value".format(_value_ns)):
|
||||||
node_text: str = _node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
node_text = _node.get("{{{:}}}value".format(_value_ns), "")
|
||||||
node_text = (node_text if '"' not in node_text \
|
node_text = (node_text if '"' not in node_text \
|
||||||
else '"{:}"'.format(node_text.replace('"', '""'))
|
else '"{:}"'.format(node_text.replace('"', '""'))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
node_text = '""'
|
node_text = '""'
|
||||||
text_information: str = "{:d}\t{:}\t{:}\t{:}" \
|
text_information: str = "{:d}\t{:}\t{:}\t{:}".format(index, _node.tag, _node.get("name", ""), node_text)
|
||||||
.format(index, _node.tag
|
|
||||||
, _node.get("name", "")
|
|
||||||
, node_text
|
|
||||||
)
|
|
||||||
text_informations.append(text_information)
|
text_informations.append(text_information)
|
||||||
|
|
||||||
index += 1
|
index += 1
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ import dashscope
|
|||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
import openai
|
import openai
|
||||||
import requests
|
import requests
|
||||||
from requests.exceptions import SSLError
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
|
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
|
||||||
from groq import Groq
|
from groq import Groq
|
||||||
|
from requests.exceptions import SSLError
|
||||||
|
|
||||||
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
|
||||||
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
|
||||||
@@ -31,6 +31,17 @@ logger = logging.getLogger("desktopenv.agent")
|
|||||||
|
|
||||||
pure_text_settings = ['a11y_tree']
|
pure_text_settings = ['a11y_tree']
|
||||||
|
|
||||||
|
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
||||||
|
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
|
||||||
|
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
|
||||||
|
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
|
||||||
|
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
|
||||||
|
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
|
||||||
|
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
|
||||||
|
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
|
||||||
|
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
|
||||||
|
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
|
||||||
|
|
||||||
|
|
||||||
# Function to encode the image
|
# Function to encode the image
|
||||||
def encode_image(image_content):
|
def encode_image(image_content):
|
||||||
@@ -57,35 +68,48 @@ def save_to_tmp_img_file(data_str):
|
|||||||
|
|
||||||
|
|
||||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
|
||||||
|
if platform == "ubuntu":
|
||||||
|
_attributes_ns = attributes_ns_ubuntu
|
||||||
|
_state_ns = state_ns_ubuntu
|
||||||
|
_component_ns = component_ns_ubuntu
|
||||||
|
_value_ns = value_ns_ubuntu
|
||||||
|
elif platform == "windows":
|
||||||
|
_attributes_ns = attributes_ns_windows
|
||||||
|
_state_ns = state_ns_windows
|
||||||
|
_component_ns = component_ns_windows
|
||||||
|
_value_ns = value_ns_windows
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
|
||||||
|
|
||||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||||
|
linearized_accessibility_tree = ["tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"]
|
||||||
|
|
||||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
|
||||||
# Linearize the accessibility tree nodes into a table format
|
# Linearize the accessibility tree nodes into a table format
|
||||||
|
|
||||||
for node in filtered_nodes:
|
for node in filtered_nodes:
|
||||||
# linearized_accessibility_tree += node.tag + "\t"
|
|
||||||
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
|
||||||
if node.text:
|
if node.text:
|
||||||
text = (node.text if '"' not in node.text \
|
text = (
|
||||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
node.text if '"' not in node.text \
|
||||||
)
|
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
)
|
||||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
|
||||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith("EditWrapper") \
|
||||||
text = (text if '"' not in text \
|
and node.get("{{{:}}}value".format(_value_ns)):
|
||||||
else '"{:}"'.format(text.replace('"', '""'))
|
node_text = node.get("{{{:}}}value".format(_value_ns), "")
|
||||||
|
text = (node_text if '"' not in node_text \
|
||||||
|
else '"{:}"'.format(node_text.replace('"', '""'))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
text = '""'
|
text = '""'
|
||||||
# linearized_accessibility_tree += node.attrib.get(
|
|
||||||
# , "") + "\t"
|
|
||||||
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
|
||||||
linearized_accessibility_tree.append(
|
linearized_accessibility_tree.append(
|
||||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
"{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||||
node.tag, node.get("name", ""), text
|
node.tag, node.get("name", ""),
|
||||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
text,
|
||||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
node.get("{{{:}}}class".format(_attributes_ns), "") if platform == "ubuntu" else node.get("{{{:}}}class".format(class_ns_windows), ""),
|
||||||
|
node.get("{{{:}}}description".format(_attributes_ns), ""),
|
||||||
|
node.get('{{{:}}}screencoord'.format(_component_ns), ""),
|
||||||
|
node.get('{{{:}}}size'.format(_component_ns), "")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -957,8 +981,9 @@ class PromptAgent:
|
|||||||
}
|
}
|
||||||
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
||||||
for part in message["content"]:
|
for part in message["content"]:
|
||||||
qwen_message['content'].append({"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
|
qwen_message['content'].append(
|
||||||
'type'] == "image_url" else None
|
{"image": "file://" + save_to_tmp_img_file(part['image_url']['url'])}) if part[
|
||||||
|
'type'] == "image_url" else None
|
||||||
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
||||||
|
|
||||||
qwen_messages.append(qwen_message)
|
qwen_messages.append(qwen_message)
|
||||||
|
|||||||
Reference in New Issue
Block a user