Implement heuristic cutting on the accessibility tree to get the important nodes; Finish accessibility tree text agent
@@ -83,6 +83,8 @@ class PythonController:
         """
         Executes an action on the server computer.
         """
+        if action in ['WAIT', 'FAIL', 'DONE']:
+            return

         action_type = action["action_type"]
         parameters = action["parameters"] if "parameters" in action else {}
@@ -111,17 +111,17 @@ def run_one_example(example, agent, max_steps=10, example_trajectory_dir="exp_tr
 if __name__ == "__main__":
     action_space = "pyautogui"
     example_class = "chrome"
-    example_id = "bb5e4c0d-f964-439c-97b6-bdb9747de3f4"
+    example_id = "06fe7178-4491-4589-810f-2e2bc9502122"

     with open(f"evaluation_examples/examples/{example_class}/{example_id}.json", "r") as f:
         example = json.load(f)
-    example["snapshot"] = "exp_setup2"
+    example["snapshot"] = "exp_setup4"

-    # api_key = os.environ.get("OPENAI_API_KEY")
-    # agent = GPT4v_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
+    api_key = os.environ.get("OPENAI_API_KEY")
+    agent = GPT4_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)

-    api_key = os.environ.get("GENAI_API_KEY")
-    agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)
+    # api_key = os.environ.get("GENAI_API_KEY")
+    # agent = GeminiPro_Agent(api_key=api_key, instruction=example['instruction'], action_space=action_space)

     root_trajectory_dir = "exp_trajectory"

mm_agents/accessibility_tree_wrap/__init__.py            0    Normal file
mm_agents/accessibility_tree_wrap/heuristic_retrieve.py  115  Normal file
@@ -0,0 +1,115 @@
+import xml.etree.ElementTree as ET
+
+from PIL import Image, ImageDraw, ImageFont
+
+
+def find_leaf_nodes(xml_file_path):
+    root = ET.fromstring(xml_file_path)
+
+    # Recursive function to traverse the XML tree and collect leaf nodes
+    def collect_leaf_nodes(node, leaf_nodes):
+        # If the node has no children, it is a leaf node, add it to the list
+        if not list(node):
+            leaf_nodes.append(node)
+        # If the node has children, recurse on each child
+        for child in node:
+            collect_leaf_nodes(child, leaf_nodes)
+
+    # List to hold all leaf nodes
+    leaf_nodes = []
+    collect_leaf_nodes(root, leaf_nodes)
+    return leaf_nodes
+
+
+def filter_nodes(nodes):
+    filtered_nodes = []
+
+    for node in nodes:
+        if not node.get('{uri:deskat:state.at-spi.gnome.org}visible', None) == 'true':
+            # Not visible
+            continue
+        # Check if the node is a 'panel'
+        if node.tag == 'panel':
+            # Check if the 'panel' represents an interactive element
+            # or if it has certain attributes that are of interest.
+            # Add your conditions here...
+            if node.get('{uri:deskat:state.at-spi.gnome.org}focusable', 'false') == 'true':
+                filtered_nodes.append(node)
+        elif node.tag == 'text':
+            continue
+        else:
+            coords = tuple(map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord').strip('()').split(', ')))
+            if coords[0] < 0 or coords[1] < 0:
+                continue
+            size = tuple(map(int, node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size').strip('()').split(', ')))
+            if size[0] <= 0 or size[1] <= 0:
+                continue
+            # Node is not a 'panel', add to the list.
+            filtered_nodes.append(node)
+
+    return filtered_nodes
+
+
+def draw_bounding_boxes(nodes, image_file_path, output_image_file_path):
+    # Load the screenshot image
+    image = Image.open(image_file_path)
+    draw = ImageDraw.Draw(image)
+
+    # Optional: Load a font. If you don't specify a font, a default one will be used.
+    try:
+        # Adjust the path to the font file you have or use a default one
+        font = ImageFont.truetype("arial.ttf", 20)
+    except IOError:
+        # Fallback to a basic font if the specified font can't be loaded
+        font = ImageFont.load_default()
+
+    # Loop over all the visible nodes and draw their bounding boxes
+    for index, _node in enumerate(nodes):
+        coords_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord')
+        size_str = _node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size')
+
+        if coords_str and size_str:
+            try:
+                # Parse the coordinates and size from the strings
+                coords = tuple(map(int, coords_str.strip('()').split(', ')))
+                size = tuple(map(int, size_str.strip('()').split(', ')))
+
+                # Check for negative sizes
+                if size[0] <= 0 or size[1] <= 0:
+                    raise ValueError(f"Size must be positive, got: {size}")
+
+                # Calculate the bottom-right corner of the bounding box
+                bottom_right = (coords[0] + size[0], coords[1] + size[1])
+
+                # Check that bottom_right > coords (x1 >= x0, y1 >= y0)
+                if bottom_right[0] < coords[0] or bottom_right[1] < coords[1]:
+                    raise ValueError(f"Invalid coordinates or size, coords: {coords}, size: {size}")
+
+                # Draw rectangle on image
+                draw.rectangle([coords, bottom_right], outline="red", width=2)
+
+                # Draw index number at the bottom left of the bounding box
+                text_position = (coords[0], bottom_right[1])  # Adjust Y to be above the bottom right
+                draw.text(text_position, str(index), font=font, fill="purple")
+
+            except ValueError as e:
+                pass
+
+    # Save the result
+    image.save(output_image_file_path)
+
+
+if __name__ == '__main__':
+    with open('chrome_desktop_example_1.xml', 'r', encoding='utf-8') as f:
+        xml_string = f.read()
+    image_file_path = 'screenshot.png'  # Replace with your actual screenshot image path
+    output_image_file_path = 'annotated_screenshot.png'  # Replace with your desired output image path
+
+    leaf_nodes = find_leaf_nodes(xml_string)
+    filtered_nodes = filter_nodes(leaf_nodes)
+    print(f"Found {len(filtered_nodes)} filtered nodes")
+
+    for node in filtered_nodes:
+        print(node.tag, node.attrib)
+
+    draw_bounding_boxes(filtered_nodes, image_file_path, output_image_file_path)
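Taken together, the helpers in heuristic_retrieve.py are meant to be chained: parse the accessibility-tree XML, keep only the visible, interactive leaf nodes, and flatten them into a small table that a text-only agent can consume. Below is a minimal sketch of that pipeline; the wrapper name linearize_accessibility_tree and the guards against missing attributes (the `or ""` fallbacks) are my own additions for illustration, while the loop body mirrors the agent code further down in this diff.

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes


def linearize_accessibility_tree(accessibility_tree_xml: str) -> str:
    # Keep only the filtered leaf nodes and flatten them into a
    # tab-separated table with one row per node.
    nodes = filter_nodes(find_leaf_nodes(accessibility_tree_xml))
    table = "tag\ttext\tposition\tsize\n"
    for node in nodes:
        table += "\t".join([
            node.tag,
            node.attrib.get('name') or "",  # guard against missing names (my assumption, not in the commit)
            node.attrib.get('{uri:deskat:component.at-spi.gnome.org}screencoord') or "",
            node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') or "",
        ]) + "\n"
    return table

The agents changed below inline this same loop directly in their predict() methods rather than calling a shared helper.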
@@ -1,11 +1,12 @@
+import time
 from typing import Dict, List

-import PIL.Image
 import google.generativeai as genai

-from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
 from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
 from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
+from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string


 class GeminiPro_Agent:
@@ -36,9 +37,25 @@ class GeminiPro_Agent:
         Only support single-round conversation, only fill-in the last desktop screenshot.
         """
         accessibility_tree = obs["accessibility_tree"]
+
+        leaf_nodes = find_leaf_nodes(accessibility_tree)
+        filtered_nodes = filter_nodes(leaf_nodes)
+
+        linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
+        # Linearize the accessibility tree nodes into a table format
+
+        for node in filtered_nodes:
+            linearized_accessibility_tree += node.tag + "\t"
+            linearized_accessibility_tree += node.attrib.get('name') + "\t"
+            linearized_accessibility_tree += node.attrib.get(
+                '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
+            linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
+
         self.trajectory.append({
             "role": "user",
-            "parts": ["Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(accessibility_tree)]
+            "parts": [
+                "Given the XML format of accessibility tree (convert and formatted into table) as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
+                    linearized_accessibility_tree)]
         })

         # todo: Remove this step once the Gemini supports multi-round conversation
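For reference, the linearized table handed to the model is tab-separated, one row per retained node. A hypothetical fragment (values invented for illustration, not taken from the commit) might look like:

tag            text                          position      size
push-button    Close                         (1890, 12)    (24, 24)
entry          Search or enter web address   (320, 64)     (900, 36)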
@@ -71,13 +88,20 @@ class GeminiPro_Agent:

         print("Trajectory:", traj_to_show)

-        response = self.model.generate_content(
-            message_for_gemini,
-            generation_config={
-                "max_output_tokens": self.max_tokens,
-                "temperature": self.temperature
-            }
-        )
+        while True:
+            try:
+                response = self.model.generate_content(
+                    message_for_gemini,
+                    generation_config={
+                        "max_output_tokens": self.max_tokens,
+                        "temperature": self.temperature
+                    }
+                )
+                break
+            except:
+                print("Failed to generate response, retrying...")
+                time.sleep(5)
+                pass

         try:
             response_text = response.text
@@ -1,3 +1,4 @@
+import time
 from typing import Dict, List

 import PIL.Image
@@ -66,13 +67,20 @@ class GeminiProV_Agent:

         print("Trajectory:", traj_to_show)

-        response = self.model.generate_content(
-            message_for_gemini,
-            generation_config={
-                "max_output_tokens": self.max_tokens,
-                "temperature": self.temperature
-            }
-        )
+        while True:
+            try:
+                response = self.model.generate_content(
+                    message_for_gemini,
+                    generation_config={
+                        "max_output_tokens": self.max_tokens,
+                        "temperature": self.temperature
+                    }
+                )
+                break
+            except:
+                print("Failed to generate response, retrying...")
+                time.sleep(5)
+                pass

         try:
             response_text = response.text
@@ -1,10 +1,12 @@
 import base64
 import json
 import re
+import time
 from typing import Dict, List

 import requests

+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
 from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
 from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE

@@ -81,9 +83,9 @@ class GPT4_Agent:
                 {
                     "type": "text",
                     "text": {
                         "computer_13": SYS_PROMPT_ACTION,
                         "pyautogui": SYS_PROMPT_CODE
                     }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
                 },
             ]
         }
@@ -94,12 +96,27 @@ class GPT4_Agent:
         Predict the next action(s) based on the current observation.
         """
         accessibility_tree = obs["accessibility_tree"]
+
+        leaf_nodes = find_leaf_nodes(accessibility_tree)
+        filtered_nodes = filter_nodes(leaf_nodes)
+
+        linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
+        # Linearize the accessibility tree nodes into a table format
+
+        for node in filtered_nodes:
+            linearized_accessibility_tree += node.tag + "\t"
+            linearized_accessibility_tree += node.attrib.get('name') + "\t"
+            linearized_accessibility_tree += node.attrib.get(
+                '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
+            linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
+
         self.trajectory.append({
             "role": "user",
             "content": [
                 {
                     "type": "text",
-                    "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(accessibility_tree)
+                    "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
+                        linearized_accessibility_tree)
                 }
             ]
         })
@@ -117,7 +134,16 @@ class GPT4_Agent:
             "messages": self.trajectory,
             "max_tokens": self.max_tokens
         }
-        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)
+
+        while True:
+            try:
+                response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers,
+                                         json=payload)
+                break
+            except:
+                print("Failed to generate response, retrying...")
+                time.sleep(5)
+                pass

         try:
             actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
@@ -11,7 +11,7 @@ You ONLY need to return the code inside a code block, like this:
 ```
 Specially, it is also allowed to return the following special code:
 When you think you have to wait for some time, return ```WAIT```;
-When you think the task can not be done, return ```FAIL```;
+When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
 When you think the task is done, return ```DONE```.

 First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.
@@ -1,6 +1,7 @@
 import base64
 import json
 import re
+import time
 from typing import Dict, List

 import requests
@@ -81,9 +82,9 @@ class GPT4v_Agent:
                 {
                     "type": "text",
                     "text": {
                         "computer_13": SYS_PROMPT_ACTION,
                         "pyautogui": SYS_PROMPT_CODE
                     }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction)
                 },
             ]
         }
@@ -123,8 +124,16 @@ class GPT4v_Agent:
             "messages": self.trajectory,
             "max_tokens": self.max_tokens
         }
-        response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload)

+        while True:
+            try:
+                response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers,
+                                         json=payload)
+                break
+            except:
+                print("Failed to generate response, retrying...")
+                time.sleep(5)
+                pass
         try:
             actions = self.parse_actions(response.json()['choices'][0]['message']['content'])
         except:
@@ -11,7 +11,7 @@ You ONLY need to return the code inside a code block, like this:
 ```
 Specially, it is also allowed to return the following special code:
 When you think you have to wait for some time, return ```WAIT```;
-When you think the task can not be done, return ```FAIL```;
+When you think the task can not be done, return ```FAIL```, don't easily say ```FAIL```, try your best to do the task;
 When you think the task is done, return ```DONE```.

 First give the current screenshot and previous things we did a reflection, then RETURN ME THE CODE OR SPECIAL CODE I ASKED FOR. NEVER EVER RETURN ME ANYTHING ELSE.