From 8efa6929511471ebd05e69b5e81c2ab520c442b5 Mon Sep 17 00:00:00 2001 From: Timothyxxx <384084775@qq.com> Date: Tue, 16 Jan 2024 11:58:23 +0800 Subject: [PATCH] Add raw accessibility-tree based prompting method (but the tokens are too large); Minor fix some small bugs --- desktop_env/envs/desktop_env.py | 5 +- desktop_env/evaluators/metrics/general.py | 2 +- experiment.py | 2 +- mm_agents/gemini_agent_text.py | 0 mm_agents/gemini_pro_agent.py | 110 +++++++++++++ ...ni_agent.py => gemini_pro_vision_agent.py} | 4 +- mm_agents/gpt_4_agent.py | 150 ++++++++++++++++++ mm_agents/gpt_4v_agent.py | 2 + mm_agents/gpt_4v_agent_text.py | 0 mm_agents/gpt_4v_prompt_code.py | 1 + 10 files changed, 272 insertions(+), 4 deletions(-) delete mode 100644 mm_agents/gemini_agent_text.py create mode 100644 mm_agents/gemini_pro_agent.py rename mm_agents/{gemini_agent.py => gemini_pro_vision_agent.py} (96%) create mode 100644 mm_agents/gpt_4_agent.py delete mode 100644 mm_agents/gpt_4v_agent_text.py diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 8a79d72..786ed72 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -204,7 +204,10 @@ class DesktopEnv(gym.Env): time.sleep(5) logger.info("Environment setup complete.") - observation = {"screenshot": self._get_obs()} + observation = { + "screenshot": self._get_obs(), + "accessibility_tree": self.controller.get_accessibility_tree(), + } return observation def step(self, action, pause=0.5): diff --git a/desktop_env/evaluators/metrics/general.py b/desktop_env/evaluators/metrics/general.py index b0433c3..6246861 100644 --- a/desktop_env/evaluators/metrics/general.py +++ b/desktop_env/evaluators/metrics/general.py @@ -4,7 +4,7 @@ import functools import operator import re from numbers import Number -from typing import Callable, Any +from typing import Callable, Any, Union from typing import Dict, List, Pattern import lxml.etree diff --git a/experiment.py b/experiment.py index f6d1232..8e7f8b5 100644 --- a/experiment.py +++ b/experiment.py @@ -6,7 +6,7 @@ import sys from desktop_env.envs.desktop_env import DesktopEnv from mm_agents.gpt_4v_agent import GPT4v_Agent -from mm_agents.gemini_agent import GeminiPro_Agent +from mm_agents.gemini_pro_agent import GeminiPro_Agent # Logger Configs {{{ # logger = logging.getLogger() diff --git a/mm_agents/gemini_agent_text.py b/mm_agents/gemini_agent_text.py deleted file mode 100644 index e69de29..0000000 diff --git a/mm_agents/gemini_pro_agent.py b/mm_agents/gemini_pro_agent.py new file mode 100644 index 0000000..21a54e7 --- /dev/null +++ b/mm_agents/gemini_pro_agent.py @@ -0,0 +1,110 @@ +from typing import Dict, List + +import PIL.Image +import google.generativeai as genai + +from mm_agents.gpt_4v_agent import parse_actions_from_string, parse_code_from_string +from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION +from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE + + +class GeminiPro_Agent: + def __init__(self, api_key, instruction, model='gemini-pro', max_tokens=300, temperature=0.0, + action_space="computer_13"): + genai.configure(api_key=api_key) + self.instruction = instruction + self.model = genai.GenerativeModel(model) + self.max_tokens = max_tokens + self.temperature = temperature + self.action_space = action_space + + self.trajectory = [ + { + "role": "system", + "parts": [ + { + "computer_13": SYS_PROMPT_ACTION, + "pyautogui": SYS_PROMPT_CODE + }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction) + ] + } + ] + + def predict(self, obs: Dict) -> List: + """ + Predict the next action(s) based on the current observation. + Only support single-round conversation, only fill-in the last desktop screenshot. + """ + accessibility_tree = obs["accessibility_tree"] + self.trajectory.append({ + "role": "user", + "parts": ["Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(accessibility_tree)] + }) + + # todo: Remove this step once the Gemini supports multi-round conversation + all_message_str = "" + for i in range(len(self.trajectory)): + if i == 0: + all_message_template = "<|im_start|>system\n{}\n<|im_end|>\n" + elif i % 2 == 1: + all_message_template = "<|im_start|>user\n{}\n<|im_end|>\n" + else: + all_message_template = "<|im_start|>assistant\n{}\n<|im_end|>\n" + + all_message_str += all_message_template.format(self.trajectory[i]["parts"][0]) + + print("All message: >>>>>>>>>>>>>>>> ") + print( + all_message_str + ) + + message_for_gemini = { + "role": "user", + "parts": [all_message_str] + } + + traj_to_show = [] + for i in range(len(self.trajectory)): + traj_to_show.append(self.trajectory[i]["parts"][0]) + if len(self.trajectory[i]["parts"]) > 1: + traj_to_show.append("screenshot_obs") + + print("Trajectory:", traj_to_show) + + response = self.model.generate_content( + message_for_gemini, + generation_config={ + "max_output_tokens": self.max_tokens, + "temperature": self.temperature + } + ) + + try: + response_text = response.text + except: + return [] + + try: + actions = self.parse_actions(response_text) + except: + print("Failed to parse action from response:", response_text) + actions = [] + + return actions + + def parse_actions(self, response: str): + # parse from the response + if self.action_space == "computer_13": + actions = parse_actions_from_string(response) + elif self.action_space == "pyautogui": + actions = parse_code_from_string(response) + else: + raise ValueError("Invalid action space: " + self.action_space) + + # add action into the trajectory + self.trajectory.append({ + "role": "assistant", + "parts": [response] + }) + + return actions diff --git a/mm_agents/gemini_agent.py b/mm_agents/gemini_pro_vision_agent.py similarity index 96% rename from mm_agents/gemini_agent.py rename to mm_agents/gemini_pro_vision_agent.py index 1593c9a..e4bb9d1 100644 --- a/mm_agents/gemini_agent.py +++ b/mm_agents/gemini_pro_vision_agent.py @@ -8,7 +8,7 @@ from mm_agents.gpt_4v_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION from mm_agents.gpt_4v_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE -class GeminiPro_Agent: +class GeminiProV_Agent: def __init__(self, api_key, instruction, model='gemini-pro-vision', max_tokens=300, temperature=0.0, action_space="computer_13"): genai.configure(api_key=api_key) @@ -93,6 +93,8 @@ class GeminiPro_Agent: actions = parse_actions_from_string(response) elif self.action_space == "pyautogui": actions = parse_code_from_string(response) + else: + raise ValueError("Invalid action space: " + self.action_space) # add action into the trajectory self.trajectory.append({ diff --git a/mm_agents/gpt_4_agent.py b/mm_agents/gpt_4_agent.py new file mode 100644 index 0000000..8d2b82d --- /dev/null +++ b/mm_agents/gpt_4_agent.py @@ -0,0 +1,150 @@ +import base64 +import json +import re +from typing import Dict, List + +import requests + +from mm_agents.gpt_4_prompt_action import SYS_PROMPT as SYS_PROMPT_ACTION +from mm_agents.gpt_4_prompt_code import SYS_PROMPT as SYS_PROMPT_CODE + + +# Function to encode the image +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode('utf-8') + + +def parse_actions_from_string(input_string): + # Search for a JSON string within the input string + actions = [] + matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL) + if matches: + # Assuming there's only one match, parse the JSON string into a dictionary + try: + for match in matches: + action_dict = json.loads(match) + actions.append(action_dict) + return actions + except json.JSONDecodeError as e: + return f"Failed to parse JSON: {e}" + else: + matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL) + if matches: + # Assuming there's only one match, parse the JSON string into a dictionary + try: + for match in matches: + action_dict = json.loads(match) + actions.append(action_dict) + return actions + except json.JSONDecodeError as e: + return f"Failed to parse JSON: {e}" + else: + try: + action_dict = json.loads(input_string) + return [action_dict] + except json.JSONDecodeError as e: + raise ValueError("Invalid response format: " + input_string) + + +def parse_code_from_string(input_string): + # This regular expression will match both ```code``` and ```python code``` + # and capture the `code` part. It uses a non-greedy match for the content inside. + pattern = r"```(?:\w+\s+)?(.*?)```" + # Find all non-overlapping matches in the string + matches = re.findall(pattern, input_string, re.DOTALL) + + # The regex above captures the content inside the triple backticks. + # The `re.DOTALL` flag allows the dot `.` to match newline characters as well, + # so the code inside backticks can span multiple lines. + + # matches now contains all the captured code snippets + return matches + + +class GPT4_Agent: + def __init__(self, api_key, instruction, model="gpt-4-1106-preview", max_tokens=300, action_space="computer_13"): + self.instruction = instruction + self.model = model + self.max_tokens = max_tokens + self.action_space = action_space + + self.headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}" + } + + self.trajectory = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": { + "computer_13": SYS_PROMPT_ACTION, + "pyautogui": SYS_PROMPT_CODE + }[action_space] + "\nHere is the instruction for the task: {}".format(self.instruction) + }, + ] + } + ] + + def predict(self, obs: Dict) -> List: + """ + Predict the next action(s) based on the current observation. + """ + accessibility_tree = obs["accessibility_tree"] + self.trajectory.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "Given the XML format of accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(accessibility_tree) + } + ] + }) + + traj_to_show = [] + for i in range(len(self.trajectory)): + traj_to_show.append(self.trajectory[i]["content"][0]["text"]) + if len(self.trajectory[i]["content"]) > 1: + traj_to_show.append("screenshot_obs") + + print("Trajectory:", traj_to_show) + + payload = { + "model": self.model, + "messages": self.trajectory, + "max_tokens": self.max_tokens + } + response = requests.post("https://api.openai.com/v1/chat/completions", headers=self.headers, json=payload) + + try: + actions = self.parse_actions(response.json()['choices'][0]['message']['content']) + except: + print("Failed to parse action from response:", response.json()) + actions = None + + return actions + + def parse_actions(self, response: str): + # parse from the response + if self.action_space == "computer_13": + actions = parse_actions_from_string(response) + elif self.action_space == "pyautogui": + actions = parse_code_from_string(response) + else: + raise ValueError("Invalid action space: " + self.action_space) + + # add action into the trajectory + self.trajectory.append({ + "role": "assistant", + "content": [ + { + "type": "text", + "text": response + }, + ] + }) + + return actions diff --git a/mm_agents/gpt_4v_agent.py b/mm_agents/gpt_4v_agent.py index 81d128e..fdf6adb 100644 --- a/mm_agents/gpt_4v_agent.py +++ b/mm_agents/gpt_4v_agent.py @@ -139,6 +139,8 @@ class GPT4v_Agent: actions = parse_actions_from_string(response) elif self.action_space == "pyautogui": actions = parse_code_from_string(response) + else: + raise ValueError("Invalid action space: " + self.action_space) # add action into the trajectory self.trajectory.append({ diff --git a/mm_agents/gpt_4v_agent_text.py b/mm_agents/gpt_4v_agent_text.py deleted file mode 100644 index e69de29..0000000 diff --git a/mm_agents/gpt_4v_prompt_code.py b/mm_agents/gpt_4v_prompt_code.py index aa768e9..8dba73b 100644 --- a/mm_agents/gpt_4v_prompt_code.py +++ b/mm_agents/gpt_4v_prompt_code.py @@ -2,6 +2,7 @@ SYS_PROMPT = """ You are an agent which follow my instruction and perform desktop computer tasks as instructed. You have good knowledge of computer and good internet connection and assume your code will run on a computer for controlling the mouse and keyboard. For each step, you will get an observation of an image, which is the screenshot of the computer screen and you will predict the action of the computer based on the image. + You are required to use `pyautogui` to perform the action. Return one line or multiple lines of python code to perform the action each time, be time efficient. You ONLY need to return the code inside a code block, like this: