from typing import Dict

import PIL.Image
import google.generativeai as genai

from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string
from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION
from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE


class GeminiPro_Agent:
    """Agent that asks a Gemini vision model for the next desktop action.

    The conversation is kept in ``self.trajectory`` as a list of
    ``{"role": ..., "parts": [...]}`` dicts:

    - one leading ``"system"`` entry holding the action-space prompt;
    - a ``"user"`` entry per ``predict`` call (instruction text + screenshot);
    - an ``"assistant"`` entry per successfully parsed model reply.
    """

    def __init__(self, api_key, model='gemini-pro-vision', max_tokens=300, action_space="computer_13"):
        """Configure the Gemini client and seed the trajectory with the system prompt.

        Args:
            api_key: Gemini API key passed to ``genai.configure``.
            model: Gemini model name given to ``genai.GenerativeModel``.
            max_tokens: maximum number of output tokens per model reply.
            action_space: ``"computer_13"`` (JSON actions) or ``"pyautogui"`` (code).

        Raises:
            KeyError: if ``action_space`` is not one of the supported values.
        """
        # genai.configure accepts keyword arguments only; passing the key
        # positionally raises TypeError.
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model)
        self.max_tokens = max_tokens
        self.action_space = action_space

        self.trajectory = [
            {
                "role": "system",
                "parts": [
                    {
                        "computer_13": SYS_PROMPT_ACTION,
                        "pyautogui": SYS_PROMPT_CODE
                    }[action_space]
                ]
            }
        ]

    def predict(self, obs: Dict):
        """Predict the next action(s) based on the current observation.

        Args:
            obs: mapping with ``"screenshot"`` (anything ``PIL.Image.open``
                accepts, e.g. a file path) and ``"instruction"`` (task text).

        Returns:
            The actions returned by ``parse_actions``. Returns ``None`` when the
            model gave no usable text or its reply could not be parsed.
        """
        img = PIL.Image.open(obs["screenshot"])
        self.trajectory.append({
            "role": "user",
            "parts": ["To accomplish the task '{}' and given the current screenshot, what's the next step?".format(
                obs["instruction"]), img]
        })

        # Print a readable view of the history, with images replaced by a placeholder.
        traj_to_show = []
        for turn in self.trajectory:
            traj_to_show.append(turn["parts"][0])
            if len(turn["parts"]) > 1:
                traj_to_show.append("screenshot_obs")
        print("Trajectory:", traj_to_show)

        # generate_content has no `max_tokens` keyword; the output limit goes
        # in generation_config as `max_output_tokens`.
        response = self.model.generate_content(
            self._build_request(img),
            generation_config={"max_output_tokens": self.max_tokens},
        )

        try:
            # `.text` raises ValueError when the reply carries no text part
            # (e.g. the candidate was blocked).
            response_text = response.text
        except ValueError as err:
            print("Gemini returned no usable text:", err)
            return None

        try:
            actions = self.parse_actions(response_text)
        except Exception as err:  # the parsers' failure modes are not visible here
            print("Failed to parse action from response:", response_text, err)
            actions = None

        return actions

    def _build_request(self, screenshot):
        """Return a single-turn request: the trajectory as text plus the current screenshot.

        The Gemini API accepts only the roles "user" and "model", and
        gemini-pro-vision does not accept multi-turn history. The whole
        trajectory, including the "system" and "assistant" entries, is
        therefore serialised into one text part, and only the latest
        screenshot is attached.
        """
        transcript = "\n\n".join(
            "{}:\n{}".format(turn["role"].upper(), turn["parts"][0])
            for turn in self.trajectory
        )
        return [transcript, screenshot]

    def parse_actions(self, response: str):
        """Parse actions from a raw model reply and record the reply in the trajectory.

        Example reply for the ``computer_13`` action space::

            ```json
            {
              "action_type": "CLICK",
              "click_type": "RIGHT"
            }
            ```

        Args:
            response: raw text returned by the model.

        Returns:
            The result of ``parse_actions_from_string`` (``computer_13``) or
            ``parse_code_from_string`` (``pyautogui``).

        Raises:
            ValueError: if ``self.action_space`` is not a supported value.
        """
        if self.action_space == "computer_13":
            actions = parse_actions_from_string(response)
        elif self.action_space == "pyautogui":
            actions = parse_code_from_string(response)
        else:
            # Without this branch `actions` would be unbound below.
            raise ValueError("Unsupported action space: {}".format(self.action_space))

        # Record the model's reply so later requests include it.
        self.trajectory.append({
            "role": "assistant",
            "parts": [response]
        })
        return actions