Support downsampling; fix bugs in Windows a11y tree; add a11y_tree trim

Author: Timothyxxx
Date:   2024-03-25 18:02:48 +08:00
Parent: 635b6717b3
Commit: 172123ab2c

4 changed files with 104 additions and 43 deletions

View File

@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
         image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
         # Get the screenshot and save to the image_path
-        screenshot = self.controller.get_screenshot()
+        max_retries = 20
+        for _ in range(max_retries):
+            screenshot = self.controller.get_screenshot()
+            if screenshot is not None:
+                break
+            time.sleep(1)
         with open(image_path, "wb") as f:
             f.write(screenshot)
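One caveat in the retry loop above: if all 20 attempts return None, `screenshot` is still None and `f.write(screenshot)` raises a `TypeError`. A minimal sketch of the same polling pattern with an explicit failure path; the helper name and signature are hypothetical, not part of this commit:

```python
import time

def get_screenshot_with_retry(controller, max_retries: int = 20, delay: float = 1.0) -> bytes:
    # Poll the controller until it returns screenshot bytes.
    for _ in range(max_retries):
        screenshot = controller.get_screenshot()
        if screenshot is not None:
            return screenshot
        time.sleep(delay)  # give the VM time to recover before retrying
    # Fail loudly instead of letting None reach the file write.
    raise RuntimeError(f"get_screenshot() returned None after {max_retries} retries")
```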

View File

@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
     # Value {{{ #
     if hasattr(node, "get_step"):
-        attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        try:
+            attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        except:
+            pass
     if hasattr(node, "value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        except:
+            pass
     if hasattr(node, "get_value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        except:
+            pass
     elif hasattr(node, "get_position"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        except:
+            pass
     if hasattr(node, "min_value"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        except:
+            pass
     elif hasattr(node, "get_range_min"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        except:
+            pass
     if hasattr(node, "max_value"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        except:
+            pass
     elif hasattr(node, "get_range_max"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        except:
+            pass
     # }}} Value #
     attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))

View File

@@ -110,6 +110,10 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
         coords = tuple(map(int, coords_str.strip('()').split(', ')))
         size = tuple(map(int, size_str.strip('()').split(', ')))
+        import copy
+        original_coords = copy.deepcopy(coords)
+        original_size = copy.deepcopy(size)
         if float(down_sampling_ratio) != 1.0:
             # Downsample the coordinates and size
             coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
@@ -145,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
         draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
         # each mark is an x, y, w, h tuple
-        marks.append([coords[0], coords[1], size[0], size[1]])
+        marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
         drew_nodes.append(_node)
         if _node.text:
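The point of this change: boxes are drawn in the (possibly downsampled) image's coordinate space, while the returned marks keep the original screen coordinates so later actions land on the real screen. Since `coords` and `size` are tuples of ints, plain assignment would preserve the originals just as well as `copy.deepcopy`. A minimal sketch of the scaling step, with hypothetical names:

```python
def downsample_box(coords, size, ratio):
    # Scale a bounding box for drawing on a resized screenshot; the caller
    # keeps the untouched (coords, size) pair for the marks list, since
    # tuples of ints are immutable and need no deep copy.
    if float(ratio) == 1.0:
        return coords, size
    return (tuple(int(c * ratio) for c in coords),
            tuple(int(s * ratio) for s in size))

# downsample_box((100, 200), (80, 30), 0.5) -> ((50, 100), (40, 15))
```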

View File

@@ -8,12 +8,14 @@ import uuid
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List
 
 import backoff
 import dashscope
 import google.generativeai as genai
 import openai
 import requests
+import tiktoken
 from PIL import Image
 from google.api_core.exceptions import InvalidArgument
@@ -32,49 +34,49 @@ def encode_image(image_path):
     return base64.b64encode(image_file.read()).decode('utf-8')
 
 
-def linearize_accessibility_tree(accessibility_tree):
+def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
-    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
+    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
 
     linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
 
     # Linearize the accessibility tree nodes into a table format
     for node in filtered_nodes:
-        #linearized_accessibility_tree += node.tag + "\t"
-        #linearized_accessibility_tree += node.attrib.get('name') + "\t"
+        # linearized_accessibility_tree += node.tag + "\t"
+        # linearized_accessibility_tree += node.attrib.get('name') + "\t"
         if node.text:
-            text = ( node.text if '"' not in node.text\
-                else '"{:}"'.format(node.text.replace('"', '""'))
-            )
+            text = (node.text if '"' not in node.text \
+                        else '"{:}"'.format(node.text.replace('"', '""'))
+                    )
         elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
                 and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
             text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
-            text = (text if '"' not in text\
-                else '"{:}"'.format(text.replace('"', '""'))
-            )
+            text = (text if '"' not in text \
+                        else '"{:}"'.format(text.replace('"', '""'))
+                    )
         else:
             text = '""'
-        #linearized_accessibility_tree += node.attrib.get(
-        #, "") + "\t"
-        #linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
+        # linearized_accessibility_tree += node.attrib.get(
+        # , "") + "\t"
+        # linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
         linearized_accessibility_tree.append(
-            "{:}\t{:}\t{:}\t{:}\t{:}".format(
-                node.tag, node.get("name", ""), text
-                , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
-                , node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
-            )
-        )
+            "{:}\t{:}\t{:}\t{:}\t{:}".format(
+                node.tag, node.get("name", ""), text
+                , node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
+                , node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
+            )
+        )
 
     return "\n".join(linearized_accessibility_tree)
-def tag_screenshot(screenshot, accessibility_tree):
+def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
     # Create a tmp file with a random name to store the screenshot
     uuid_str = str(uuid.uuid4())
     os.makedirs("tmp/images", exist_ok=True)
     tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
     # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
-    nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
+    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
 
     # Make tag screenshot
     marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
@@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
     return actions
 
 
+def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
+    enc = tiktoken.encoding_for_model("gpt-4")
+    tokens = enc.encode(linearized_accessibility_tree)
+    if len(tokens) > max_tokens:
+        linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
+        linearized_accessibility_tree += "[...]\n"
+    return linearized_accessibility_tree
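A quick usage sketch of the new helper, assuming it is importable from this module; `tiktoken.encoding_for_model("gpt-4")` only loads a tokenizer locally, no API call is involved:

```python
# ~5000 rows of a linearized tree comfortably exceeds 10,000 tokens.
tree = 'frame\tSome Window\t""\t(0, 0)\t(1920, 1080)\n' * 5000
trimmed = trim_accessibility_tree(tree, max_tokens=10000)

assert trimmed.endswith("[...]\n")  # marker is appended only when truncated
assert trim_accessibility_tree("short", 10000) == "short"  # small input untouched
```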
 
 class PromptAgent:
     def __init__(
             self,
+            platform="ubuntu",
             model="gpt-4-vision-preview",
             max_tokens=1500,
             top_p=0.9,
@@ -180,8 +191,10 @@ class PromptAgent:
             action_space="computer_13",
             observation_type="screenshot_a11y_tree",
             # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
-            max_trajectory_length=3
+            max_trajectory_length=3,
+            a11y_tree_max_tokens=10000
     ):
+        self.platform = platform
         self.model = model
         self.max_tokens = max_tokens
         self.top_p = top_p
@@ -189,6 +202,7 @@ class PromptAgent:
         self.action_space = action_space
         self.observation_type = observation_type
         self.max_trajectory_length = max_trajectory_length
+        self.a11y_tree_max_tokens = a11y_tree_max_tokens
 
         self.thoughts = []
         self.actions = []
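Taken together, the constructor changes thread two new knobs through the agent. A minimal construction sketch using only the parameters visible in this diff (all of which have defaults above):

```python
agent = PromptAgent(
    platform="windows",          # forwarded to filter_nodes / linearize_accessibility_tree
    observation_type="a11y_tree",
    a11y_tree_max_tokens=10000,  # cap enforced by trim_accessibility_tree each step
)
```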
@@ -349,9 +363,14 @@ class PromptAgent:
         # {{{1
         if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
             base64_image = encode_image(obs["screenshot"])
-            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
+            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                         platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
+            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
+
+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
 
             if self.observation_type == "screenshot_a11y_tree":
                 self.observations.append({
                     "screenshot": base64_image,
@@ -383,9 +402,14 @@ class PromptAgent:
                 ]
             })
         elif self.observation_type == "a11y_tree":
-            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
+            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                         platform=self.platform)
+            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
+
+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
 
             self.observations.append({
                 "screenshot": None,
                 "accessibility_tree": linearized_accessibility_tree
@@ -403,10 +427,15 @@ class PromptAgent:
             })
         elif self.observation_type == "som":
             # Add som to the screenshot
-            masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
+            masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
+                "accessibility_tree"], self.platform)
             base64_image = encode_image(tagged_screenshot)
+            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
+
+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
 
             self.observations.append({
                 "screenshot": base64_image,
                 "accessibility_tree": linearized_accessibility_tree
@@ -435,7 +464,7 @@ class PromptAgent:
         # with open("messages.json", "w") as f:
         #     f.write(json.dumps(messages, indent=4))
 
-        #logger.info("PROMPT: %s", messages)
+        # logger.info("PROMPT: %s", messages)
 
         response = self.call_llm({
             "model": self.model,
@@ -556,8 +585,6 @@ class PromptAgent:
"Content-Type": "application/json"
}
payload = {
"model": self.model,
"max_tokens": max_tokens,
@@ -570,7 +597,8 @@ class PromptAgent:
         attempt = 0
         while attempt < max_attempts:
             # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
-            response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
+            response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
+                                     json=payload)
             if response.status_code == 200:
                 result = response.json()['choices'][0]['message']['content']
                 break
@@ -581,7 +609,7 @@ class PromptAgent:
         else:
             print("Exceeded maximum attempts to call LLM.")
             result = ""
 
         return result
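Since `backoff` is already imported at the top of this file, the hand-rolled attempt counter could equally be expressed with it. A hedged sketch, not the commit's code; `_post_chat` is a hypothetical name:

```python
import backoff
import requests

@backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_tries=10)
def _post_chat(url: str, headers: dict, payload: dict) -> str:
    response = requests.post(url, headers=headers, json=payload)
    response.raise_for_status()  # map HTTP errors onto RequestException so backoff retries
    return response.json()["choices"][0]["message"]["content"]
```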
@@ -605,14 +633,13 @@ class PromptAgent:
             mistral_messages.append(mistral_message)
 
         from openai import OpenAI
         client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                         base_url='https://api.together.xyz',
                         )
 
         logger.info("Generating content with Mistral model: %s", self.model)
 
         flag = 0
         while True:
             try: