From 20442244fa75cece352441492813ed94f74a2ec9 Mon Sep 17 00:00:00 2001 From: Tianbao Xie <47296835+Timothyxxx@users.noreply.github.com> Date: Mon, 11 Nov 2024 12:36:16 +0800 Subject: [PATCH] [Feature] Initialize and Implement Aguvis Evaluation on OSWorld (#98) * Initialize Aguvis eval on OSWorld * Debug * Debug * v1, internal version * Add experiments script * Fix minor bugs * Update new endpoint * Update ip * Update * Update * Update * Update * Update * Update * Update * Update * Fix model name * Fix docker close issues; update prompting * Fix missed * Fix the default port to avoid crashing on examples like '_update_browse_history_setup' * Fix server and chromium ports in setup * Revert and add missed dependency * Add VLC port for docker * Update * Clean --------- Co-authored-by: Tianbao Xie Co-authored-by: FredWuCZ --- desktop_env/controllers/setup.py | 7 +- desktop_env/desktop_env.py | 4 +- desktop_env/evaluators/getters/vlc.py | 2 +- desktop_env/providers/docker/provider.py | 15 +- mm_agents/aguvis_agent.py | 530 +++++++++++++++++++++++ requirements.txt | 1 + run_multienv_aguvis.py | 361 +++++++++++++++ 7 files changed, 910 insertions(+), 10 deletions(-) create mode 100644 mm_agents/aguvis_agent.py create mode 100644 run_multienv_aguvis.py diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index 06695f5..ac1c372 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -28,10 +28,11 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__)) class SetupController: - def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str): + def __init__(self, vm_ip: str, server_port: int = 5000, chromium_port: int = 9222, vlc_port: int = 8080, cache_dir: str = "cache"): self.vm_ip: str = vm_ip self.server_port: int = server_port self.chromium_port: int = chromium_port + self.vlc_port: int = vlc_port self.http_server: str = f"http://{vm_ip}:{server_port}" self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup" self.cache_dir: str = cache_dir @@ -532,7 +533,7 @@ class SetupController: """ host = self.vm_ip - port = 9222 # fixme: this port is hard-coded, need to be changed from config file + port = self.chromium_port remote_debugging_url = f"http://{host}:{port}" with sync_playwright() as p: @@ -643,7 +644,7 @@ class SetupController: logger.info('Fake browsing history added successfully.') - controller = PythonController(self.vm_ip) + controller = PythonController(self.vm_ip, self.server_port) # get the path of the history file according to the platform os_type = controller.get_vm_platform() diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index cf05d03..301fdf4 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -59,6 +59,7 @@ class DesktopEnv(gym.Env): self.server_port = 5000 self.chromium_port = 9222 self.vnc_port = 8006 + self.vlc_port = 8080 self.manager, self.provider = create_vm_manager_and_provider(provider_name, region) self.os_type = os_type @@ -104,8 +105,9 @@ class DesktopEnv(gym.Env): self.server_port = int(vm_ip_ports[1]) self.chromium_port = int(vm_ip_ports[2]) self.vnc_port = int(vm_ip_ports[3]) + self.vlc_port = int(vm_ip_ports[4]) self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port) - self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base) + self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base) def _revert_to_snapshot(self): # Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm diff --git a/desktop_env/evaluators/getters/vlc.py b/desktop_env/evaluators/getters/vlc.py index ef0adfb..911a023 100644 --- a/desktop_env/evaluators/getters/vlc.py +++ b/desktop_env/evaluators/getters/vlc.py @@ -14,7 +14,7 @@ def get_vlc_playing_info(env, config: Dict[str, str]): """ host = env.vm_ip - port = 8080 + port = env.vlc_port password = 'password' _path = os.path.join(env.cache_dir, config["dest"]) diff --git a/desktop_env/providers/docker/provider.py b/desktop_env/providers/docker/provider.py index faed091..5c58b27 100644 --- a/desktop_env/providers/docker/provider.py +++ b/desktop_env/providers/docker/provider.py @@ -28,6 +28,8 @@ class DockerProvider(Provider): self.server_port = None self.vnc_port = None self.chromium_port = None + self.vlc_port = None + self.container = None self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp') @@ -92,6 +94,7 @@ class DockerProvider(Provider): self.vnc_port = self._get_available_port(8006) self.server_port = self._get_available_port(5000) self.chromium_port = self._get_available_port(9222) + self.vlc_port = self._get_available_port(8080) # Start container while still holding the lock self.container = self.client.containers.run( @@ -108,13 +111,14 @@ class DockerProvider(Provider): ports={ 8006: self.vnc_port, 5000: self.server_port, - 9222: self.chromium_port + 9222: self.chromium_port, + 8080: self.vlc_port }, detach=True ) logger.info(f"Started container with ports - VNC: {self.vnc_port}, " - f"Server: {self.server_port}, Chrome: {self.chromium_port}") + f"Server: {self.server_port}, Chrome: {self.chromium_port}, VLC: {self.vlc_port}") # Wait for VM to be ready self._wait_for_vm_ready() @@ -130,15 +134,15 @@ class DockerProvider(Provider): raise e def get_ip_address(self, path_to_vm: str) -> str: - if not all([self.server_port, self.chromium_port, self.vnc_port]): + if not all([self.server_port, self.chromium_port, self.vnc_port, self.vlc_port]): raise RuntimeError("VM not started - ports not allocated") - return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}" + return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}:{self.vlc_port}" def save_state(self, path_to_vm: str, snapshot_name: str): raise NotImplementedError("Snapshots not available for Docker provider") def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str): - pass + self.stop_emulator(path_to_vm) def stop_emulator(self, path_to_vm: str): if self.container: @@ -154,3 +158,4 @@ class DockerProvider(Provider): self.server_port = None self.vnc_port = None self.chromium_port = None + self.vlc_port = None diff --git a/mm_agents/aguvis_agent.py b/mm_agents/aguvis_agent.py new file mode 100644 index 0000000..a1de1fd --- /dev/null +++ b/mm_agents/aguvis_agent.py @@ -0,0 +1,530 @@ +import base64 +import json +import logging +import os +import re +import tempfile +import time +from http import HTTPStatus +from io import BytesIO +from typing import Dict, List + +import backoff +import openai +import requests +from PIL import Image +from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest +from requests.exceptions import SSLError + +logger = logging.getLogger("desktopenv.aguvis_agent") + +# Function to encode the image +def encode_image(image_content): + return base64.b64encode(image_content).decode('utf-8') + + +def encoded_img_to_pil_img(data_str): + base64_str = data_str.replace("data:image/png;base64,", "") + image_data = base64.b64decode(base64_str) + image = Image.open(BytesIO(image_data)) + + return image + + +def save_to_tmp_img_file(data_str): + base64_str = data_str.replace("data:image/png;base64,", "") + image_data = base64.b64decode(base64_str) + image = Image.open(BytesIO(image_data)) + + tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png") + image.save(tmp_img_path) + + return tmp_img_path + +# TODO: hardcoded screen size, need to be fixed +SCREEN_LOGIC_SIZE = (1280, 800) + +wait_func = {"name": "WAIT", "description": "wait for a moment"} +done_func = {"name": "DONE", "description": "done with the task"} +fail_func = {"name": "FAIL", "description": "fail to complete the task"} + +SYS_PROMPT = f"""You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task. +""" + +# TODO: let GPT not to predict non-atomic actions, +PLANNER_OUTPUT_FORMAT_SYS_PROMPT = """Your response should be formatted as follows: +Thought: *Describe your understanding of the current situation and consider what you need to do next.* +Action: *State the specific action you have decided to perform, described in natural language.* + +**Note:** Please **do not** predict non-atomic actions. For example, for multi-step operations like "click then select the date," only predict the first atomic action (e.g., "click") at this time, and leave subsequent steps (like click for selecting the date) for the next planning phase. + +**Example:** +Thought: To proceed with booking a hotel, I must first specify the check-in and check-out dates for the stay. Since the objective is to book a three-night stay starting from the 1st of June, I need to input these dates into the form to find available accommodations. +Action: Click on the "Choose date" button in the Check-in field to start selecting the stay dates. + +Addtionally, you can use the following functions: +- {json.dumps(wait_func)} +- {json.dumps(done_func)} +- {json.dumps(fail_func)} + +**Example 1:** +Thought: I need to wait for a moment before proceeding. +Action: WAIT + +**Example 2:** +Thought: I have completed the task. +Action: DONE +""" + +INSTRUCTION_PROMPT = """Please generate the next move according to the UI screenshot, instruction and previous actions. + +Instruction: {instruction} +""" + +ACTION_PROMPT = """Previous actions: +""" + +def _pyautogui_code_to_absolute_coordinates(pyautogui_code_relative_coordinates, logical_screen_size=SCREEN_LOGIC_SIZE): + """ + Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size. + """ + import re + import ast + + width, height = logical_screen_size + + pattern = r'(pyautogui\.\w+\([^\)]*\))' + + matches = re.findall(pattern, pyautogui_code_relative_coordinates) + + new_code = pyautogui_code_relative_coordinates + + for full_call in matches: + func_name_pattern = r'(pyautogui\.\w+)\((.*)\)' + func_match = re.match(func_name_pattern, full_call, re.DOTALL) + if not func_match: + continue + + func_name = func_match.group(1) + args_str = func_match.group(2) + + try: + parsed = ast.parse(f"func({args_str})").body[0].value + parsed_args = parsed.args + parsed_keywords = parsed.keywords + except SyntaxError: + continue + + function_parameters = { + 'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'], + 'moveTo': ['x', 'y', 'duration', 'tween', 'pause'], + 'moveRel': ['xOffset', 'yOffset', 'duration', 'tween', 'pause'], + 'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'], + 'dragRel': ['xOffset', 'yOffset', 'duration', 'button', 'mouseDownUp', 'pause'], + 'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'], + } + + func_base_name = func_name.split('.')[-1] + + param_names = function_parameters.get(func_base_name, []) + + args = {} + for idx, arg in enumerate(parsed_args): + if idx < len(param_names): + param_name = param_names[idx] + arg_value = ast.literal_eval(arg) + args[param_name] = arg_value + + for kw in parsed_keywords: + param_name = kw.arg + arg_value = ast.literal_eval(kw.value) + args[param_name] = arg_value + + updated = False + if 'x' in args: + try: + x_rel = float(args['x']) + x_abs = int(round(x_rel * width)) + args['x'] = x_abs + updated = True + except ValueError: + pass + if 'y' in args: + try: + y_rel = float(args['y']) + y_abs = int(round(y_rel * height)) + args['y'] = y_abs + updated = True + except ValueError: + pass + if 'xOffset' in args: + try: + x_rel = float(args['xOffset']) + x_abs = int(round(x_rel * width)) + args['xOffset'] = x_abs + updated = True + except ValueError: + pass + if 'yOffset' in args: + try: + y_rel = float(args['yOffset']) + y_abs = int(round(y_rel * height)) + args['yOffset'] = y_abs + updated = True + except ValueError: + pass + + if updated: + reconstructed_args = [] + for idx, param_name in enumerate(param_names): + if param_name in args: + arg_value = args[param_name] + if isinstance(arg_value, str): + arg_repr = f"'{arg_value}'" + else: + arg_repr = str(arg_value) + reconstructed_args.append(arg_repr) + else: + break + + used_params = set(param_names[:len(reconstructed_args)]) + for kw in parsed_keywords: + if kw.arg not in used_params: + arg_value = args[kw.arg] + if isinstance(arg_value, str): + arg_repr = f"{kw.arg}='{arg_value}'" + else: + arg_repr = f"{kw.arg}={arg_value}" + reconstructed_args.append(arg_repr) + + new_args_str = ', '.join(reconstructed_args) + new_full_call = f"{func_name}({new_args_str})" + new_code = new_code.replace(full_call, new_full_call) + + return new_code + +def _parse(text, screen_logic_size=SCREEN_LOGIC_SIZE): + if text.lower().startswith("wait"): + return "WAIT", "WAIT" + elif text.lower().startswith("done"): + return "DONE", "DONE" + elif text.lower().startswith("fail"): + return "FAIL", "FAIL" + + try: + lines = text.strip().split("\n") + lines = [line for line in lines if line.strip() != ""] # Remove empty lines + + pyautogui_index = -1 + + for i, line in enumerate(lines): + if line.strip() == "assistantos" or line.strip().startswith("pyautogui"): + pyautogui_index = i + break + + if pyautogui_index == -1: + print(f"Error: Could not parse response {text}") + return None, None # Return None or handle the error as needed + + pyautogui_code_relative_coordinates = "\n".join(lines[pyautogui_index:]) + # remove the assistantos prefix, ugly, fix later + pyautogui_code_relative_coordinates = pyautogui_code_relative_coordinates.replace("assistantos", "") + parsed_action = _pyautogui_code_to_absolute_coordinates(pyautogui_code_relative_coordinates, screen_logic_size) + return parsed_action + except Exception as e: + print(f"Error: Could not parse response {text}") + return None + + +def parse_planner_response(planner_response): + try: + # Split the response into lines for easier parsing + lines = planner_response.splitlines() + + # Initialize variables to store thought and action + thought = None + action_description = None + + # Iterate over each line to find the thought and action + for line in lines: + # Check if the line starts with 'Thought:' + if line.startswith("Thought:"): + # Extract the part after 'Thought: ' as the thought + thought = line[len("Thought: "):].strip() + + # Check if the line starts with 'Action:' + elif line.startswith("Action:"): + # Extract the part after 'Action: ' as the action + action_description = line[len("Action: "):].strip() + + # Return the thought and action as a dictionary + return thought, action_description + except Exception as e: + print(f"Error: Could not parse response {planner_response}") + return "", "" + +class AguvisAgent: + def __init__( + self, + platform="ubuntu", + planner_model="gpt-4o", + executor_model="qwen-aguvis-7b", + max_tokens=1500, + top_p=0.9, + temperature=0.5, + action_space="pyautogui", + observation_type="screenshot", + ): + self.platform = platform + self.planner_model = planner_model + self.executor_model = executor_model + assert self.executor_model is not None, "Executor model cannot be None" + self.max_tokens = max_tokens + self.top_p = top_p + self.temperature = temperature + self.action_space = action_space + self.observation_type = observation_type + assert action_space in ["pyautogui"], "Invalid action space" + assert observation_type in ["screenshot"], "Invalid observation type" + self.thoughts = [] + self.actions = [] + self.observations = [] + + def predict(self, instruction: str, obs: Dict) -> List: + """ + Predict the next action(s) based on the current observation. + """ + + # Prepare the payload for the API call + messages = [] + masks = None + self.observations.append(obs["screenshot"]) + + messages.append({ + "role": "system", + "content": [ + { + "type": "text", + "text": SYS_PROMPT + }, + ] + }) + + instruction_prompt = INSTRUCTION_PROMPT.format(instruction=instruction) + history_actions_prompt = ACTION_PROMPT + + # thought, or so called action description + for i, action_description in enumerate(self.action_descriptions): + history_actions_prompt += f"Step {i+1}: {action_description}\n" + + if len(history_actions_prompt) > 0: + instruction_prompt += "\n\n" + history_actions_prompt + + base64_img = encode_image(obs["screenshot"]) + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": instruction_prompt + }, + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{base64_img}", + "detail": "high" + } + } + ] + }) + + if self.planner_model is None: + # For now, we call the same model twice, one for planner and one for executor, + # This can be improved later when the inference stop token fixed + messages.append({ + "role": "assistant", + "content": [ + { + "type": "text", + "text": """<|recipient|>all\nAction: """ + } + ] + }) + + with open("messages_direct_executor.json", "w") as f: + f.write(json.dumps(messages, indent=4)) + + executor_response = self.call_llm({ + "model": self.executor_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature + }, self.executor_model) + + logger.info("EXECUTOR RESPONSE: %s", executor_response) + + pyautogui_action = _parse(executor_response) + + thought, action_description = parse_planner_response("Action: " + executor_response) + + self.thoughts.append(thought) + self.action_descriptions.append(action_description) + self.actions.append(pyautogui_action) + + return executor_response, [pyautogui_action] + + else: + # Planner stage + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": PLANNER_OUTPUT_FORMAT_SYS_PROMPT + "\nThought:" + } + ] + }) + + planner_response = self.call_llm({ + "model": self.planner_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature + }, self.planner_model) + + logger.info("PLANNER RESPONSE: %s", planner_response) + thought, action_description = parse_planner_response(planner_response) + self.thoughts.append(thought) + self.action_descriptions.append(action_description) + + if "WAIT" in action_description: + self.actions.append("WAIT") + return planner_response, ["WAIT"] + elif "DONE" in action_description: + self.actions.append("DONE") + return planner_response, ["DONE"] + elif "FAIL" in action_description: + self.actions.append("FAIL") + return planner_response, ["FAIL"] + + messages[1]["content"][0]["text"] = INSTRUCTION_PROMPT.format(instruction=action_description) + + # pretend nothing happend with stronger planner model + messages[-1] = { + "role": "assistant", + "content": [ + { + "type": "text", + # "text": f"""<|recipient|>all\nAction: {action_description}<|im_end|>\n<|im_start|>assistant<|recipient|>os""" + "text": f"""<|recipient|>os""" + } + ] + } + + with open("messages_executor.json", "w") as f: + f.write(json.dumps(messages, indent=4)) + + # Executor stage + executor_response = self.call_llm({ + "model": self.executor_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature + }, self.executor_model) + + logger.info("EXECUTOR RESPONSE: %s", executor_response) + + pyautogui_action = _parse(executor_response) + self.actions.append(pyautogui_action) + + return planner_response + "\n\n" + executor_response, [pyautogui_action] + + @backoff.on_exception( + backoff.constant, + # here you should add more model exceptions as you want, + # but you are forbidden to add "Exception", that is, a common type of exception + # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit + ( + # General exceptions + SSLError, + + # OpenAI exceptions + openai.RateLimitError, + openai.BadRequestError, + openai.InternalServerError, + + # Google exceptions + InvalidArgument, + ResourceExhausted, + InternalServerError, + BadRequest, + + # Groq exceptions + # todo: check + ), + interval=30, + max_tries=10 + ) + def call_llm(self, payload, model): + + if model.startswith("gpt"): + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" + # "Authorization": f"Bearer {os.environ['MIT_SPIDER_TOKEN']}" + } + logger.info("Generating content with GPT model: %s", model) + response = requests.post( + "https://api.openai.com/v1/chat/completions", + # "http://47.88.8.18:8088/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code != 200: + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + else: + return response.json()['choices'][0]['message']['content'] + + elif "aguvis" in model: + headers = { + "Content-Type": "application/json", + } + logger.info("Generating content with Aguvis model: %s", model) + response = requests.post( + "http://101.132.136.195:7908/v1/chat/completions", + headers=headers, + json=payload + ) + + if response.status_code != 200: + logger.error("Failed to call LLM: " + response.text) + time.sleep(5) + return "" + else: + return response.json()['choices'][0]['message']['content'] + + + def reset(self): + self.thoughts = [] + self.action_descriptions = [] + self.actions = [] + self.observations = [] + +if __name__ == "__main__": + agent = AguvisAgent() + with open("screenshot.png", "rb") as f: + screenshot = f.read() + agent.predict("Add a new paper to my list", {"screenshot": screenshot}) +# relative_code = """pyautogui.typewrite("Hello, world! I have a float number 0.172") +# pyautogui.click(0, 1, n_click=1) +# pyautogui.moveTo(0.5342, 0.5342) +# """ +# absolute_code = _pyautogui_code_to_absolute_coordinates(relative_code, logical_screen_size=(1920, 1080)) +# print(absolute_code) + diff --git a/requirements.txt b/requirements.txt index 4e0e167..427cf96 100644 --- a/requirements.txt +++ b/requirements.txt @@ -25,6 +25,7 @@ pypdf PyGetWindow rapidfuzz pyacoustid +pygame opencv-python ImageHash scikit-image diff --git a/run_multienv_aguvis.py b/run_multienv_aguvis.py new file mode 100644 index 0000000..84e5b50 --- /dev/null +++ b/run_multienv_aguvis.py @@ -0,0 +1,361 @@ +"""Script to run end-to-end evaluation on the benchmark. +Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. +""" + +import argparse +import datetime +import json +import logging +import os +import sys +from typing import List, Dict +import math +from tqdm import tqdm +from multiprocessing import Process, Manager +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.aguvis_agent import AguvisAgent + +# import wandb + + +# Logger Configs {{{ # +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) +sdebug_handler = logging.FileHandler( + os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" +) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(logging.INFO) +sdebug_handler.setLevel(logging.DEBUG) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) +sdebug_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) +sdebug_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +logger.addHandler(sdebug_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--screen_width", type=int, default=1920) + parser.add_argument("--screen_height", type=int, default=1080) + parser.add_argument("--sleep_after_execution", type=float, default=0.0) + parser.add_argument("--max_steps", type=int, default=15) + + # agent config + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config + parser.add_argument("--planner_model", type=str, default="gpt-4o") + parser.add_argument("--executor_model", type=str, default="/mnt/chuzhe.hby/hf_ckpts/qwen-aguvis-7b") + parser.add_argument("--temperature", type=float, default=0) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--max_tokens", type=int, default=1500) + parser.add_argument("--stop_token", type=str, default=None) + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + + args = parser.parse_args() + return args + + +def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: + """Distribute tasks evenly across environments.""" + # Flatten the tasks into a single list + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + + # Calculate tasks per environment + tasks_per_env = math.ceil(len(all_tasks) / num_envs) + + # Distribute tasks + distributed_tasks = [] + for i in range(num_envs): + env_tasks = {} + start_idx = i * tasks_per_env + end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) + + for domain, example_id in all_tasks[start_idx:end_idx]: + if domain not in env_tasks: + env_tasks[domain] = [] + env_tasks[domain].append(example_id) + + distributed_tasks.append(env_tasks) + + return distributed_tasks + + + +def run_env_tasks(env_idx: int, env: DesktopEnv, agent: AguvisAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): + """Run tasks for a single environment.""" + logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") + + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model), + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"Time limit exceeded in {domain}/{example_id}"} + ) + ) + f.write("\n") + + env.close() + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + logger.info("Args: %s", args) + + distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) + + # First, set up all environments + logger.info("Setting up all environments...") + envs = [] + agents = [] + + for env_idx in range(args.num_envs): + logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") + + agent = AguvisAgent( + planner_model=args.planner_model, + executor_model=args.executor_model, + max_tokens=args.max_tokens, + top_p=args.top_p, + temperature=args.temperature, + action_space=args.action_space, + observation_type=args.observation_type, + ) + agents.append(agent) + + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=agent.action_space, + screen_size=(args.screen_width, args.screen_height), + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type + in ["a11y_tree", "screenshot_a11y_tree", "som"], + provider_name = "docker" + ) + envs.append(env) + + logger.info("All environments are ready. Starting parallel task execution...") + + # Create a shared list for scores across processes + with Manager() as manager: + shared_scores = manager.list() + + # Create and start processes for each environment + processes = [] + for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): + p = Process( + target=run_env_tasks, + args=(env_idx, env, agent, env_tasks, args, shared_scores) + ) + processes.append(p) + p.start() + + # Wait for all processes to complete + for p in processes: + p.join() + + # Convert shared list to regular list + scores = list(shared_scores) + + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + args = config() + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + test_file_list = get_unfinished( + args.action_space, + "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model), + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model), + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list)