From 46bd3386dd23f25626190923891385f3ce566568 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Fri, 19 Jan 2024 20:34:47 +0800 Subject: [PATCH] Support input screenshot and a11y tree together --- mm_agents/gpt_4v_agent.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/mm_agents/gpt_4v_agent.py b/mm_agents/gpt_4v_agent.py index d594b76..0dc3cb1 100644 --- a/mm_agents/gpt_4v_agent.py +++ b/mm_agents/gpt_4v_agent.py @@ -6,6 +6,7 @@ from typing import Dict, List import requests +from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE @@ -64,11 +65,12 @@ def parse_code_from_string(input_string): class GPT4v_Agent: - def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13"): + def __init__(self, api_key, instruction, model="gpt-4-vision-preview", max_tokens=300, action_space="computer_13", add_a11y_tree=False): self.instruction = instruction self.model = model self.max_tokens = max_tokens self.action_space = action_space + self.add_a11y_tree = add_a11y_tree self.headers = { "Content-Type": "application/json", @@ -95,17 +97,34 @@ class GPT4v_Agent: Predict the next action(s) based on the current observation. 
""" base64_image = encode_image(obs["screenshot"]) + accessibility_tree = obs["accessibility_tree"] + + leaf_nodes = find_leaf_nodes(accessibility_tree) + filtered_nodes = filter_nodes(leaf_nodes) + + linearized_accessibility_tree = "tag\ttext\tposition\tsize\n" + # Linearize the accessibility tree nodes into a table format + + for node in filtered_nodes: + linearized_accessibility_tree += node.tag + "\t" + linearized_accessibility_tree += node.attrib.get('name') + "\t" + linearized_accessibility_tree += node.attrib.get( + '{uri:deskat:component.at-spi.gnome.org}screencoord') + "\t" + linearized_accessibility_tree += node.attrib.get('{uri:deskat:component.at-spi.gnome.org}size') + "\n" + self.trajectory.append({ "role": "user", "content": [ { "type": "text", - "text": "What's the next step that you will do to help with the task?" + "text": "What's the next step that you will do to help with the task?" if not self.add_a11y_tree + else "And given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(linearized_accessibility_tree) }, { "type": "image_url", "image_url": { - "url": f"data:image/jpeg;base64,{base64_image}" + "url": f"data:image/jpeg;base64,{base64_image}", + "detail": "high" } } ]