# todo: needs to be refactored

import time
from typing import Dict, List

import google.generativeai as genai

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import find_leaf_nodes, filter_nodes
from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE
from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string


class GeminiPro_Agent:
    """Agent that asks a Gemini model for the next action given an accessibility tree.

    The conversation so far is kept in ``self.trajectory`` as a list of
    ``{"role": ..., "parts": [...]}`` dicts, starting with a system entry that
    holds the action-space prompt followed by the task instruction.
    """

    def __init__(self, api_key, instruction, model='gemini-pro', max_tokens=300, temperature=0.0,
                 action_space="computer_13"):
        """Configure the Gemini client and seed the trajectory with the system prompt.

        Args:
            api_key: Key passed to ``genai.configure``.
            instruction: Task instruction appended to the system prompt.
            model: Model name passed to ``genai.GenerativeModel``.
            max_tokens: Sent as ``max_output_tokens`` in the generation config.
            temperature: Sent as ``temperature`` in the generation config.
            action_space: ``"computer_13"`` or ``"pyautogui"``; selects both the
                system prompt and the response parser.

        Raises:
            KeyError: If ``action_space`` is not one of the two supported values.
        """
        genai.configure(api_key=api_key)
        self.instruction = instruction
        self.model = genai.GenerativeModel(model)
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.action_space = action_space

        system_prompts = {
            "computer_13": SYS_PROMPT_ACTION,
            "pyautogui": SYS_PROMPT_CODE,
        }
        self.trajectory = [
            {
                "role": "system",
                "parts": [
                    system_prompts[action_space]
                    + "\nHere is the instruction for the task: {}".format(self.instruction)
                ],
            }
        ]

    @staticmethod
    def _linearize_accessibility_tree(nodes) -> str:
        """Render nodes as a tab-separated table with a tag/text/position/size header."""
        attribute_keys = (
            'name',
            '{uri:deskat:component.at-spi.gnome.org}screencoord',
            '{uri:deskat:component.at-spi.gnome.org}size',
        )
        rows = ["tag\ttext\tposition\tsize\n"]
        for node in nodes:
            # attrib.get returns None for a missing attribute; emit an empty cell
            # instead of failing on ``None + str``.
            cells = [node.tag] + [node.attrib.get(key) or "" for key in attribute_keys]
            rows.append("\t".join(cells) + "\n")
        return "".join(rows)

    def _render_trajectory(self) -> str:
        """Flatten the whole trajectory into one ``<|im_start|>role ... <|im_end|>`` string."""
        # Use the role stored on each entry rather than the entry's index parity:
        # a failed turn leaves a user entry without an assistant reply, after which
        # parity would label every later turn with the wrong role.
        return "".join(
            "<|im_start|>{}\n{}\n<|im_end|>\n".format(entry["role"], entry["parts"][0])
            for entry in self.trajectory
        )

    def predict(self, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.

        Reads ``obs["accessibility_tree"]``, appends it to the trajectory as a
        user turn, and sends the whole trajectory to the model as a single user
        message (single-round conversation only).

        Args:
            obs: Observation dict; must contain the key ``"accessibility_tree"``.

        Returns:
            The parsed actions, or an empty list when the response text cannot be
            read or parsed.
        """
        accessibility_tree = obs["accessibility_tree"]
        filtered_nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
        linearized_accessibility_tree = self._linearize_accessibility_tree(filtered_nodes)

        self.trajectory.append({
            "role": "user",
            "parts": [
                "Given the XML format of accessibility tree (convert and formatted into table) as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                    linearized_accessibility_tree)]
        })

        # todo: Remove this step once the Gemini supports multi-round conversation
        all_message_str = self._render_trajectory()

        print("All message: >>>>>>>>>>>>>>>> ")
        print(all_message_str)

        message_for_gemini = {
            "role": "user",
            "parts": [all_message_str]
        }

        traj_to_show = []
        for entry in self.trajectory:
            traj_to_show.append(entry["parts"][0])
            if len(entry["parts"]) > 1:
                traj_to_show.append("screenshot_obs")

        print("Trajectory:", traj_to_show)

        generation_config = {
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature
        }
        # Retry until the request succeeds. Only Exception is caught, so that
        # Ctrl-C (KeyboardInterrupt) and SystemExit can still stop the loop.
        while True:
            try:
                response = self.model.generate_content(
                    message_for_gemini,
                    generation_config=generation_config
                )
                break
            except Exception as error:
                print("Failed to generate response, retrying...", error)
                time.sleep(5)

        try:
            response_text = response.text
        except Exception as error:
            # Reading the text can fail even though the request itself succeeded.
            print("Failed to read text from response:", error)
            return []

        try:
            actions = self.parse_actions(response_text)
        except Exception:
            print("Failed to parse action from response:", response_text)
            actions = []

        return actions

    def parse_actions(self, response: str):
        """Parse the model response into actions and record it in the trajectory.

        Args:
            response: Raw text returned by the model.

        Returns:
            Whatever the parser for the configured action space returns.

        Raises:
            ValueError: If ``self.action_space`` is not a supported action space.
        """
        # parse from the response
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            raise ValueError("Invalid action space: " + self.action_space)

        # add action into the trajectory (only reached when parsing succeeded)
        self.trajectory.append({
            "role": "assistant",
            "parts": [response]
        })

        return actions