Support input screenshot and a11y tree altogether
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Dict, List
|
|||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
|
||||||
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
|
||||||
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
|
||||||
|
|
||||||
@@ -64,11 +65,12 @@ def parse_code_from_string(input_string):
|
|||||||
|
|
||||||
|
|
||||||
class GPT4v_Agent:
|
class GPT4v_Agent:
|
||||||
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"):
|
def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13", add_a11y_tree=False):
|
||||||
self.instruction = instruction
|
self.instruction = instruction
|
||||||
self.model = model
|
self.model = model
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
|
self.add_a11y_tree = add_a11y_tree
|
||||||
|
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@@ -95,17 +97,34 @@ class GPT4v_Agent:
|
|||||||
Predict the next action(s) based on the current observation.
|
Predict the next action(s) based on the current observation.
|
||||||
"""
|
"""
|
||||||
base64_image = encode_image(obs["screenshot"])
|
base64_image = encode_image(obs["screenshot"])
|
||||||
|
accessibility_tree = obs["accessibility_tree"]
|
||||||
|
|
||||||
|
leaf_nodes = find_leaf_nodes(accessibility_tree)
|
||||||
|
filtered_nodes = filter_nodes(leaf_nodes)
|
||||||
|
|
||||||
|
linearized_accessibility_tree = "tag\ttext\tposition\tsize\n"
|
||||||
|
# Linearize the accessibility tree nodes into a table format
|
||||||
|
|
||||||
|
for node in filtered_nodes:
|
||||||
|
linearized_accessibility_tree += node.tag + "\t"
|
||||||
|
linearized_accessibility_tree += node.attrib.get('name') + "\t"
|
||||||
|
linearized_accessibility_tree += node.attrib.get(
|
||||||
|
'{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t"
|
||||||
|
linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n"
|
||||||
|
|
||||||
self.trajectory.append({
|
self.trajectory.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "What's the next step that you will do to help with the task?"
|
"text": "What's the next step that you will do to help with the task?" if not self.add_a11y_tree
|
||||||
|
else "And given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(linearized_accessibility_tree)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": f"data:image/jpeg;base64,{base64_image}"
|
"url": f"data:image/jpeg;base64,{base64_image}",
|
||||||
|
"detail": "high"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user