Support downsampling; Fix bugs in Windows a11y tree; Add a11y_tree trim
@@ -146,7 +146,13 @@ class DesktopEnv(gym.Env):
         image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))

         # Get the screenshot and save to the image_path
-        screenshot = self.controller.get_screenshot()
+        max_retries = 20
+        for _ in range(max_retries):
+            screenshot = self.controller.get_screenshot()
+            if screenshot is not None:
+                break
+            time.sleep(1)

         with open(image_path, "wb") as f:
             f.write(screenshot)
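Note on the hunk above: `controller.get_screenshot()` can transiently return no data, so the step loop now polls up to 20 times with a 1-second sleep before writing the file. A minimal, self-contained sketch of the same retry pattern; the helper name and signature are illustrative, not code from the repo:

import time
from typing import Callable, Optional

def fetch_with_retry(get_screenshot: Callable[[], Optional[bytes]],
                     max_retries: int = 20,
                     delay_s: float = 1.0) -> Optional[bytes]:
    # Poll until a non-None screenshot arrives or the retry budget is spent.
    screenshot = None
    for _ in range(max_retries):
        screenshot = get_screenshot()
        if screenshot is not None:
            break
        time.sleep(delay_s)
    return screenshot

If every attempt fails the caller still receives None, matching the patched code, which only guards against transient failures and would still fail at the file write.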
@@ -531,21 +531,45 @@ def _create_pywinauto_node(node: BaseWrapper, depth: int = 0, flag: Optional[str

    # Value {{{ #
    if hasattr(node, "get_step"):
-        attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        try:
+            attribute_dict["{{{:}}}step".format(_accessibility_ns_map["val"])] = str(node.get_step())
+        except:
+            pass
    if hasattr(node, "value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.value())
+        except:
+            pass
    if hasattr(node, "get_value"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_value())
+        except:
+            pass
    elif hasattr(node, "get_position"):
-        attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        try:
+            attribute_dict["{{{:}}}value".format(_accessibility_ns_map["val"])] = str(node.get_position())
+        except:
+            pass
    if hasattr(node, "min_value"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.min_value())
+        except:
+            pass
    elif hasattr(node, "get_range_min"):
-        attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        try:
+            attribute_dict["{{{:}}}min".format(_accessibility_ns_map["val"])] = str(node.get_range_min())
+        except:
+            pass
    if hasattr(node, "max_value"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.max_value())
+        except:
+            pass
    elif hasattr(node, "get_range_max"):
-        attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        try:
+            attribute_dict["{{{:}}}max".format(_accessibility_ns_map["val"])] = str(node.get_range_max())
+        except:
+            pass
    # }}} Value #

    attribute_dict["{{{:}}}class".format(_accessibility_ns_map["win"])] = str(type(node))
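The Windows a11y-tree fix above wraps every optional pywinauto value getter (`get_step`, `value`, `get_value`, `get_position`, `min_value`, `get_range_min`, `max_value`, `get_range_max`) in try/except, since a wrapper can expose one of these methods yet still raise when the underlying control does not actually support the corresponding pattern. A hedged sketch of the same guard factored into a helper; the helper is illustrative only, not part of the repo:

def safe_call(node, method_name):
    # Call an optional wrapper method and swallow any error it raises,
    # mirroring the try/except/pass blocks in the hunk above.
    if not hasattr(node, method_name):
        return None
    try:
        return str(getattr(node, method_name)())
    except Exception:
        return None

# e.g. value = safe_call(node, "get_value") or safe_call(node, "get_position")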
@@ -110,6 +110,10 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
         coords = tuple(map(int, coords_str.strip('()').split(', ')))
         size = tuple(map(int, size_str.strip('()').split(', ')))

+        import copy
+        original_coords = copy.deepcopy(coords)
+        original_size = copy.deepcopy(size)
+
         if float(down_sampling_ratio) != 1.0:
             # Downsample the coordinates and size
             coords = tuple(int(coord * down_sampling_ratio) for coord in coords)
@@ -145,7 +149,7 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
         draw.text(text_position, str(index), font=font, anchor="lb", fill="white")

         # each mark is an x, y, w, h tuple
-        marks.append([coords[0], coords[1], size[0], size[1]])
+        marks.append([original_coords[0], original_coords[1], original_size[0], original_size[1]])
         drew_nodes.append(_node)

         if _node.text:
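In the two hunks above, `coords` and `size` are scaled by `down_sampling_ratio` only for drawing on the resized screenshot, while `original_coords`/`original_size` preserve the full-resolution values that go into `marks`, so later click targets are not shifted by the downsampling. A small sketch of that bookkeeping, assuming the ratio applies uniformly to x, y, width and height:

import copy

def split_coords(coords, size, down_sampling_ratio=1.0):
    # Keep untouched copies for the mark list; scale the working copies for drawing.
    original_coords, original_size = copy.deepcopy(coords), copy.deepcopy(size)
    if float(down_sampling_ratio) != 1.0:
        coords = tuple(int(c * down_sampling_ratio) for c in coords)
        size = tuple(int(s * down_sampling_ratio) for s in size)
    return (original_coords, original_size), (coords, size)

# Example: split_coords((200, 120), (64, 32), 0.5)
# -> ((200, 120), (64, 32)) kept for marks, ((100, 60), (32, 16)) used for drawing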
@@ -8,12 +8,14 @@ import uuid
 import xml.etree.ElementTree as ET
 from http import HTTPStatus
 from io import BytesIO
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List

 import backoff
 import dashscope
 import google.generativeai as genai
 import openai
 import requests
+import tiktoken
 from PIL import Image
 from google.api_core.exceptions import InvalidArgument

@@ -32,31 +34,31 @@ def encode_image(image_path):
         return base64.b64encode(image_file.read()).decode('utf-8')


-def linearize_accessibility_tree(accessibility_tree):
+def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
     # leaf_nodes = find_leaf_nodes(accessibility_tree)
-    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
+    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)

     linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
     # Linearize the accessibility tree nodes into a table format

     for node in filtered_nodes:
-        #linearized_accessibility_tree += node.tag + "\t"
-        #linearized_accessibility_tree += node.attrib.get('name') + "\t"
+        # linearized_accessibility_tree += node.tag + "\t"
+        # linearized_accessibility_tree += node.attrib.get('name') + "\t"
         if node.text:
-            text = ( node.text if '"' not in node.text\
+            text = (node.text if '"' not in node.text \
                 else '"{:}"'.format(node.text.replace('"', '""'))
             )
         elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
                 and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
             text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
-            text = (text if '"' not in text\
+            text = (text if '"' not in text \
                 else '"{:}"'.format(text.replace('"', '""'))
             )
         else:
             text = '""'
-        #linearized_accessibility_tree += node.attrib.get(
-        #, "") + "\t"
-        #linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
+        # linearized_accessibility_tree += node.attrib.get(
+        # , "") + "\t"
+        # linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
         linearized_accessibility_tree.append(
             "{:}\t{:}\t{:}\t{:}\t{:}".format(
                 node.tag, node.get("name", ""), text
@@ -68,13 +70,13 @@ def linearize_accessibility_tree(accessibility_tree):
     return "\n".join(linearized_accessibility_tree)


-def tag_screenshot(screenshot, accessibility_tree):
+def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
     # Creat a tmp file to store the screenshot in random name
     uuid_str = str(uuid.uuid4())
     os.makedirs("tmp/images", exist_ok=True)
     tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
     # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
-    nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
+    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
     # Make tag screenshot
     marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)

@@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
     return actions


+def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
+    enc = tiktoken.encoding_for_model("gpt-4")
+    tokens = enc.encode(linearized_accessibility_tree)
+    if len(tokens) > max_tokens:
+        linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
+        linearized_accessibility_tree += "[...]\n"
+    return linearized_accessibility_tree
+
 class PromptAgent:
     def __init__(
             self,
+            platform="ubuntu",
             model="gpt-4-vision-preview",
             max_tokens=1500,
             top_p=0.9,
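`trim_accessibility_tree` (added above) caps the linearized tree at `max_tokens` GPT-4 tokens via tiktoken and appends a "[...]" marker when it truncates. A usage sketch; the function body is copied from the hunk, but the sample tree text is made up for illustration:

import tiktoken

def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    enc = tiktoken.encoding_for_model("gpt-4")
    tokens = enc.encode(linearized_accessibility_tree)
    if len(tokens) > max_tokens:
        linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
        linearized_accessibility_tree += "[...]\n"
    return linearized_accessibility_tree

# Illustrative input: a header row plus many repeated node rows.
tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
tree += "push-button\tOK\t\"\"\t(10, 20)\t(60, 24)\n" * 500
trimmed = trim_accessibility_tree(tree, 1000)
# trimmed now holds roughly 1000 tokens of the table, ending with the "[...]" marker.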
@@ -180,8 +191,10 @@ class PromptAgent:
             action_space="computer_13",
             observation_type="screenshot_a11y_tree",
             # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
-            max_trajectory_length=3
+            max_trajectory_length=3,
+            a11y_tree_max_tokens=10000
     ):
+        self.platform = platform
         self.model = model
         self.max_tokens = max_tokens
         self.top_p = top_p
@@ -189,6 +202,7 @@ class PromptAgent:
         self.action_space = action_space
         self.observation_type = observation_type
         self.max_trajectory_length = max_trajectory_length
+        self.a11y_tree_max_tokens = a11y_tree_max_tokens

         self.thoughts = []
         self.actions = []
@@ -349,9 +363,14 @@ class PromptAgent:
         # {{{1
         if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
             base64_image = encode_image(obs["screenshot"])
-            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
+            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                         platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
             logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
+
             if self.observation_type == "screenshot_a11y_tree":
                 self.observations.append({
                     "screenshot": base64_image,
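The same preprocessing order is repeated for each a11y-based observation type in this and the following hunks: linearize the tree with the platform hint, then trim it to `self.a11y_tree_max_tokens` (default 10000) before storing it in the observation. A compact sketch of that order, assuming `linearize_accessibility_tree` and `trim_accessibility_tree` from the hunks above are importable and that `obs` carries the raw tree under "accessibility_tree":

def preprocess_a11y_tree(obs, platform, a11y_tree_max_tokens):
    # Linearize first (platform-aware filtering), then cap the token count.
    linearized = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
                                              platform=platform)
    if linearized:
        linearized = trim_accessibility_tree(linearized, a11y_tree_max_tokens)
    return linearized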
@@ -383,9 +402,14 @@ class PromptAgent:
                 ]
             })
         elif self.observation_type == "a11y_tree":
-            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
+            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
+                                                                         platform=self.platform)
             logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
+
             self.observations.append({
                 "screenshot": None,
                 "accessibility_tree": linearized_accessibility_tree
@@ -403,10 +427,15 @@ class PromptAgent:
             })
         elif self.observation_type == "som":
             # Add som to the screenshot
-            masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
+            masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
+                "accessibility_tree"], self.platform)
             base64_image = encode_image(tagged_screenshot)
             logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
+                                                                        self.a11y_tree_max_tokens)
+
             self.observations.append({
                 "screenshot": base64_image,
                 "accessibility_tree": linearized_accessibility_tree
@@ -435,7 +464,7 @@ class PromptAgent:
         # with open("messages.json", "w") as f:
         # f.write(json.dumps(messages, indent=4))

-        #logger.info("PROMPT: %s", messages)
+        # logger.info("PROMPT: %s", messages)

         response = self.call_llm({
             "model": self.model,
@@ -556,8 +585,6 @@ class PromptAgent:
                 "Content-Type": "application/json"
             }

-
-
             payload = {
                 "model": self.model,
                 "max_tokens": max_tokens,
@@ -570,7 +597,8 @@ class PromptAgent:
             attempt = 0
             while attempt < max_attempts:
                 # response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
-                response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
+                response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
+                                          json=payload)
                 if response.status_code == 200:
                     result = response.json()['choices'][0]['message']['content']
                     break
@@ -605,7 +633,6 @@ class PromptAgent:

                 mistral_messages.append(mistral_message)

-
                 from openai import OpenAI

                 client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],