ver Mar19thv2

Added accessibility-tree (AT) info back for the SoM setting
This commit is contained in:
David Chang
2024-03-19 18:41:55 +08:00
parent 05336a8ecf
commit 4df088e2ad
3 changed files with 59 additions and 26 deletions

View File

@@ -37,27 +37,36 @@ def linearize_accessibility_tree(accessibility_tree):
# leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree))
linearized_accessibility_tree = "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)\n"
linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
# Linearize the accessibility tree nodes into a table format
for node in filtered_nodes:
linearized_accessibility_tree += node.tag + "\t"
linearized_accessibility_tree += node.attrib.get('name') + "\t"
#linearized_accessibility_tree += node.tag + "\t"
#linearized_accessibility_tree += node.attrib.get('name') + "\t"
if node.text:
linearized_accessibility_tree += (node.text if '"' not in node.text else '"{:}"'.format(
node.text.replace('"', '""'))) + "\t"
text = ( node.text if '"' not in node.text\
else '"{:}"'.format(node.text.replace('"', '""'))
)
elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
linearized_accessibility_tree += (text if '"' not in text else '"{:}"'.format(
text.replace('"', '""'))) + "\t"
text = (text if '"' not in text\
else '"{:}"'.format(text.replace('"', '""'))
)
else:
linearized_accessibility_tree += '""\t'
linearized_accessibility_tree += node.attrib.get(
'{uri:deskat:component.at-spi.gnome.org}screencoord', "") + "\t"
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
text = '""'
#linearized_accessibility_tree += node.attrib.get(
#, "") + "\t"
#linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size', "") + "\n"
linearized_accessibility_tree.append(
"{:}\t{:}\t{:}\t{:}\t{:}".format(
node.tag, node.get("name", ""), text
, node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', "")
, node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
)
)
return linearized_accessibility_tree
return "\n".join(linearized_accessibility_tree)
def tag_screenshot(screenshot, accessibility_tree):
@@ -68,9 +77,9 @@ def tag_screenshot(screenshot, accessibility_tree):
# nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
nodes = filter_nodes(ET.fromstring(accessibility_tree), check_image=True)
# Make tag screenshot
marks, drew_nodes = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
return marks, drew_nodes, tagged_screenshot_file_path
return marks, drew_nodes, tagged_screenshot_file_path, element_list
def parse_actions_from_string(input_string):
@@ -395,11 +404,13 @@ class PromptAgent:
})
elif self.observation_type == "som":
# Add som to the screenshot
masks, drew_nodes, tagged_screenshot = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(obs["screenshot"], obs["accessibility_tree"])
base64_image = encode_image(tagged_screenshot)
logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
self.observations.append({
"screenshot": base64_image
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree
})
messages.append({
@@ -407,7 +418,8 @@ class PromptAgent:
"content": [
{
"type": "text",
"text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
"text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
linearized_accessibility_tree)
},
{
"type": "image_url",
@@ -774,7 +786,7 @@ class PromptAgent:
if response.status_code == HTTPStatus.OK:
try:
return response.json()['output']['choices'][0]['message']['content']
except Exception as e:
except Exception:
return ""
else:
print(response.code) # The error code.