Support downsampling; Fix bugs in windows a11y tree; Add a11y_tree trim
This commit is contained in:
@@ -8,12 +8,14 @@ import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List
|
||||
|
||||
import backoff
|
||||
import dashscope
|
||||
import google.generativeai as genai
|
||||
import openai
|
||||
import requests
|
||||
import tiktoken
|
||||
from PIL import Image
|
||||
from google.api_core.exceptions import InvalidArgument
|
||||
|
||||
@@ -32,49 +34,49 @@ def encode_image(image_path):
|
||||
return base64.b64encode(image_file.read()).decode('utf-8')
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree):
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
# leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
|
||||
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
|
||||
|
||||
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
|
||||
# Linearize the accessibility tree nodes into a table format
|
||||
|
||||
for node in filtered_nodes:
|
||||
#linearized_accessibility_tree += node.tag + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
# linearized_accessibility_tree += node.tag + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||
if node.text:
|
||||
text = ( node.text if '"' not in node.text\
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
text = (node.text if '"' not in node.text \
|
||||
else '"{:}"'.format(node.text.replace('"', '""'))
|
||||
)
|
||||
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
text = (text if '"' not in text\
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
text = (text if '"' not in text \
|
||||
else '"{:}"'.format(text.replace('"', '""'))
|
||||
)
|
||||
else:
|
||||
text = '""'
|
||||
#linearized_accessibility_tree += node.attrib.get(
|
||||
#, "") + "\t"
|
||||
#linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
# linearized_accessibility_tree += node.attrib.get(
|
||||
# , "") + "\t"
|
||||
# linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
|
||||
linearized_accessibility_tree.append(
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
"{:}\t{:}\t{:}\t{:}\t{:}".format(
|
||||
node.tag, node.get("name", ""), text
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
|
||||
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(linearized_accessibility_tree)
|
||||
|
||||
|
||||
def tag_screenshot(screenshot, accessibility_tree):
|
||||
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
|
||||
# Creat a tmp file to store the screenshot in random name
|
||||
uuid_str = str(uuid.uuid4())
|
||||
os.makedirs("tmp/images", exist_ok=True)
|
||||
tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png")
|
||||
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
|
||||
nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
|
||||
# Make tag screenshot
|
||||
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
|
||||
|
||||
@@ -170,9 +172,18 @@ def parse_code_from_som_string(input_string, masks):
|
||||
return actions
|
||||
|
||||
|
||||
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
|
||||
enc = tiktoken.encoding_for_model("gpt-4")
|
||||
tokens = enc.encode(linearized_accessibility_tree)
|
||||
if len(tokens) > max_tokens:
|
||||
linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
|
||||
linearized_accessibility_tree += "[...]\n"
|
||||
return linearized_accessibility_tree
|
||||
|
||||
class PromptAgent:
|
||||
def __init__(
|
||||
self,
|
||||
platform="ubuntu",
|
||||
model="gpt-4-vision-preview",
|
||||
max_tokens=1500,
|
||||
top_p=0.9,
|
||||
@@ -180,8 +191,10 @@ class PromptAgent:
|
||||
action_space="computer_13",
|
||||
observation_type="screenshot_a11y_tree",
|
||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||
max_trajectory_length=3
|
||||
max_trajectory_length=3,
|
||||
a11y_tree_max_tokens=10000
|
||||
):
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
@@ -189,6 +202,7 @@ class PromptAgent:
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.max_trajectory_length = max_trajectory_length
|
||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
@@ -349,9 +363,14 @@ class PromptAgent:
|
||||
# {{{1
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = encode_image(obs["screenshot"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"]) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
@@ -383,9 +402,14 @@ class PromptAgent:
|
||||
]
|
||||
})
|
||||
elif self.observation_type == "a11y_tree":
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"])
|
||||
linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
|
||||
platform=self.platform)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": None,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -403,10 +427,15 @@ class PromptAgent:
|
||||
})
|
||||
elif self.observation_type == "som":
|
||||
# Add som to the screenshot
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
|
||||
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs[
|
||||
"accessibility_tree"], self.platform)
|
||||
base64_image = encode_image(tagged_screenshot)
|
||||
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
|
||||
|
||||
if linearized_accessibility_tree:
|
||||
linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
|
||||
self.a11y_tree_max_tokens)
|
||||
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree
|
||||
@@ -435,7 +464,7 @@ class PromptAgent:
|
||||
# with open("messages.json", "w") as f:
|
||||
# f.write(json.dumps(messages, indent=4))
|
||||
|
||||
#logger.info("PROMPT: %s", messages)
|
||||
# logger.info("PROMPT: %s", messages)
|
||||
|
||||
response = self.call_llm({
|
||||
"model": self.model,
|
||||
@@ -556,8 +585,6 @@ class PromptAgent:
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -570,7 +597,8 @@ class PromptAgent:
|
||||
attempt = 0
|
||||
while attempt < max_attempts:
|
||||
# response = requests.post("https://api.aigcbest.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers, json=payload)
|
||||
response = requests.post("https://token.cluade-chat.top/v1/chat/completions", headers=headers,
|
||||
json=payload)
|
||||
if response.status_code == 200:
|
||||
result = response.json()['choices'][0]['message']['content']
|
||||
break
|
||||
@@ -581,7 +609,7 @@ class PromptAgent:
|
||||
else:
|
||||
print("Exceeded maximum attempts to call LLM.")
|
||||
result = ""
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -605,14 +633,13 @@ class PromptAgent:
|
||||
|
||||
mistral_messages.append(mistral_message)
|
||||
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
||||
base_url='https://api.together.xyz',
|
||||
)
|
||||
logger.info("Generating content with Mistral model: %s", self.model)
|
||||
|
||||
|
||||
flag = 0
|
||||
while True:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user