Support downsampling; Fix bugs in windows a11y tree; Add a11y_tree trim
This commit is contained in:
@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
|
||||
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
||||
|
||||
# Get the screenshot and save to the image_path
|
||||
screenshot = self.controller.get_screenshot()
|
||||
max_retries = 20
|
||||
for _ in range(max_retries):
|
||||
screenshot = self.controller.get_screenshot()
|
||||
if screenshot is not None:
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(screenshot)
|
||||
|
||||
|
||||
@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str
|
||||
|
||||
# Value {{{ #
|
||||
if hasattr(node, "get_step"):
|
||||
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
||||
try:
|
||||
attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "value"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "get_value"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_position"):
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
||||
try:
|
||||
attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "min_value"):
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_range_min"):
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
||||
try:
|
||||
attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
|
||||
except:
|
||||
pass
|
||||
if hasattr(node, "max_value"):
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
||||
try:
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
|
||||
except:
|
||||
pass
|
||||
elif hasattr(node, "get_range_max"):
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
||||
try:
|
||||
attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
|
||||
except:
|
||||
pass
|
||||
# }}} Value #
|
||||
|
||||
attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
|
||||
|
||||
@@ -110,6 +110,10 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
|
||||
coords = tuple(map(int, coords_str.strip('()').split(', ')))
|
||||
size = tuple(map(int, size_str.strip('()').split(', ')))
|
||||
|
||||
import copy
|
||||
original_coords = copy.deepcopy(coords)
|
||||
original_size = copy.deepcopy(size)
|
||||
|
||||
if float(down_sampling_ratio) != 1.0:
|
||||
# Downsample the coordinates and size
|
||||
coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
|
||||
@@ -145,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
|
||||
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
|
||||
|
||||
# each mark is an x, y, w, h tuple
|
||||
marks.append([coords[0], coords[1], size[0], size[1]])
|
||||
marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
|
||||
drew_nodes.append(_node)
|
||||
|
||||
if _node.text:
|
||||
|
||||
@@ -8,12 +8,14 @@ import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List
|
||||
|
||||
import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
import tiktoken
|
||||
from PIL import Image
|
||||
from google.api_core.exceptions import InvalidArgument
|
||||
|
||||
@@ -32,49 +34,49 @@ def encode_image(image_path):
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree):
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
#linearized_accessibility_tree += node.tag + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
# linearized_accessibility_tree += node.tag + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
if node.text:
|
||||
text = ( node.text if '"' not in node.text\
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
text = (node.text if '"' not in node.text \
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
text = (text if '"' not in text\
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
text = (text if '"' not in text \
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
else:
|
||||
text = '""'
|
||||
#linearized_accessibility_tree += node.attrib.get(
|
||||
#, "") + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
# linearized_accessibility_tree += node.attrib.get(
|
||||
# , "") + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
linearized_accessibility_tree.append(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(linearized_accessibility_tree)
|
||||
|
||||
|
||||
def tag_screenshot(screenshot, accessibility_tree):
|
||||
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
|
||||
# Creat a tmp file to store the screenshot in random name
|
||||
uuid_str = str(uuid.uuid4())
|
||||
os.makedirs("tmp/images", exist_ok=True)
|
||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
|
||||
# Make tag screenshot
|
||||
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
|
||||
@@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
|
||||
return actions
|
||||
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
|
||||
enc = tiktoken.encoding_for_model("gpt-4")
|
||||
tokens = enc.encode(linearized_accessibility_tree)
|
||||
if len(tokens) > max_tokens:
|
||||
linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
|
||||
linearized_accessibility_tree += "[...]\n"
|
||||
return linearized_accessibility_tree
|
||||
|
||||
class PromptAgent:
|
||||
def __init__(
|
||||
self,
|
||||
platform="ubuntu",
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=1500,
|
||||
top_p=0.9,
|
||||
@@ -180,8 +191,10 @@ class PromptAgent:
|
||||
action_space="computer_13",
|
||||
observation_type="screenshot_a11y_tree",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=3
|
||||
max_trajectory_length=3,
|
||||
a11y_tree_max_tokens=10000
|
||||
):
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
@@ -189,6 +202,7 @@ class PromptAgent:
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
@@ -349,9 +363,14 @@ class PromptAgent:
|
||||
# {{{1
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
@@ -383,9 +402,14 @@ class PromptAgent:
|
||||
]
|
||||
})
|
||||
elif self.observation_type == "a11y_tree":
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": None,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -403,10 +427,15 @@ class PromptAgent:
|
||||
})
|
||||
elif self.observation_type == "som":
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||
"accessibility_tree"], self.platform)
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -435,7 +464,7 @@ class PromptAgent:
|
||||
# with open("messages.json", "w") as f:
|
||||
# f.write(json.dumps(messages, indent=4))
|
||||
|
||||
#logger.info("PROMPT: %s", messages)
|
||||
# logger.info("PROMPT: %s", messages)
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
@@ -556,8 +585,6 @@ class PromptAgent:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -570,7 +597,8 @@ class PromptAgent:
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
|
||||
json=payload)
|
||||
if response.status_code == 200:
|
||||
result = response.json()['choices'][0]['message']['content']
|
||||
break
|
||||
@@ -581,7 +609,7 @@ class PromptAgent:
|
||||
else:
|
||||
print("Exceeded maximum attempts to call LLM.")
|
||||
result = ""
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -605,14 +633,13 @@ class PromptAgent:
|
||||
|
||||
mistral_messages.append(mistral_message)
|
||||
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
logger.info("Generating content with Mistral model: %s", self.model)
|
||||
|
||||
|
||||
flag = 0
|
||||
while True:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user