[Feature] Initialize and Implement Aguvis Evaluation on OSWorld (#98)

* Initialize Aguvis eval on OSWorld

* Debug

* Debug

* v1, internal version

* Add experiments script

* Fix minor bugs

* Update new endpoint

* Update ip

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Fix model name

* Fix docker close issues; update prompting

* Fix missed

* Fix the default port to avoid crashing on examples like '_update_browse_history_setup'

* Fix server and chromium ports in setup

* Revert and add missed dependency

* Add VLC port for docker

* Update

* Clean

---------

Co-authored-by: Tianbao Xie <tianbaoxie@U-492FC39R-0217.local>
Co-authored-by: FredWuCZ <fredwucz@outlook.com>
This commit is contained in:
Tianbao Xie
2024-11-11 12:36:16 +08:00
committed by GitHub
parent b35dc40ff4
commit 20442244fa
7 changed files with 910 additions and 10 deletions

View File

@@ -28,10 +28,11 @@ FILE_PATH = os.path.dirname(os.path.abspath(__file__))
class SetupController:
def __init__(self, vm_ip: str, server_port: int, chromium_port: int, cache_dir: str):
def __init__(self, vm_ip: str, server_port: int = 5000, chromium_port: int = 9222, vlc_port: int = 8080, cache_dir: str = "cache"):
self.vm_ip: str = vm_ip
self.server_port: int = server_port
self.chromium_port: int = chromium_port
self.vlc_port: int = vlc_port
self.http_server: str = f"http://{vm_ip}:{server_port}"
self.http_server_setup_root: str = f"http://{vm_ip}:{server_port}/setup"
self.cache_dir: str = cache_dir
@@ -532,7 +533,7 @@ class SetupController:
"""
host = self.vm_ip
port = 9222 # fixme: this port is hard-coded, need to be changed from config file
port = self.chromium_port
remote_debugging_url = f"http://{host}:{port}"
with sync_playwright() as p:
@@ -643,7 +644,7 @@ class SetupController:
logger.info('Fake browsing history added successfully.')
controller = PythonController(self.vm_ip)
controller = PythonController(self.vm_ip, self.server_port)
# get the path of the history file according to the platform
os_type = controller.get_vm_platform()

View File

@@ -59,6 +59,7 @@ class DesktopEnv(gym.Env):
self.server_port = 5000
self.chromium_port = 9222
self.vnc_port = 8006
self.vlc_port = 8080
self.manager, self.provider = create_vm_manager_and_provider(provider_name, region)
self.os_type = os_type
@@ -104,8 +105,9 @@ class DesktopEnv(gym.Env):
self.server_port = int(vm_ip_ports[1])
self.chromium_port = int(vm_ip_ports[2])
self.vnc_port = int(vm_ip_ports[3])
self.vlc_port = int(vm_ip_ports[4])
self.controller = PythonController(vm_ip=self.vm_ip, server_port=self.server_port)
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, cache_dir=self.cache_dir_base)
self.setup_controller = SetupController(vm_ip=self.vm_ip, server_port=self.server_port, chromium_port=self.chromium_port, vlc_port=self.vlc_port, cache_dir=self.cache_dir_base)
def _revert_to_snapshot(self):
# Revert to certain snapshot of the virtual machine, and refresh the path to vm and ip of vm

View File

@@ -14,7 +14,7 @@ def get_vlc_playing_info(env, config: Dict[str, str]):
"""
host = env.vm_ip
port = 8080
port = env.vlc_port
password = 'password'
_path = os.path.join(env.cache_dir, config["dest"])

View File

@@ -28,6 +28,8 @@ class DockerProvider(Provider):
self.server_port = None
self.vnc_port = None
self.chromium_port = None
self.vlc_port = None
self.container = None
self.environment = {"DISK_SIZE": "32G", "RAM_SIZE": "4G", "CPU_CORES": "4"} # Modify if needed
temp_dir = Path(os.getenv('TEMP') if platform.system() == 'Windows' else '/tmp')
@@ -92,6 +94,7 @@ class DockerProvider(Provider):
self.vnc_port = self._get_available_port(8006)
self.server_port = self._get_available_port(5000)
self.chromium_port = self._get_available_port(9222)
self.vlc_port = self._get_available_port(8080)
# Start container while still holding the lock
self.container = self.client.containers.run(
@@ -108,13 +111,14 @@ class DockerProvider(Provider):
ports={
8006: self.vnc_port,
5000: self.server_port,
9222: self.chromium_port
9222: self.chromium_port,
8080: self.vlc_port
},
detach=True
)
logger.info(f"Started container with ports - VNC: {self.vnc_port}, "
f"Server: {self.server_port}, Chrome: {self.chromium_port}")
f"Server: {self.server_port}, Chrome: {self.chromium_port}, VLC: {self.vlc_port}")
# Wait for VM to be ready
self._wait_for_vm_ready()
@@ -130,15 +134,15 @@ class DockerProvider(Provider):
raise e
def get_ip_address(self, path_to_vm: str) -> str:
if not all([self.server_port, self.chromium_port, self.vnc_port]):
if not all([self.server_port, self.chromium_port, self.vnc_port, self.vlc_port]):
raise RuntimeError("VM not started - ports not allocated")
return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}"
return f"localhost:{self.server_port}:{self.chromium_port}:{self.vnc_port}:{self.vlc_port}"
def save_state(self, path_to_vm: str, snapshot_name: str):
raise NotImplementedError("Snapshots not available for Docker provider")
def revert_to_snapshot(self, path_to_vm: str, snapshot_name: str):
pass
self.stop_emulator(path_to_vm)
def stop_emulator(self, path_to_vm: str):
if self.container:
@@ -154,3 +158,4 @@ class DockerProvider(Provider):
self.server_port = None
self.vnc_port = None
self.chromium_port = None
self.vlc_port = None

530
mm_agents/aguvis_agent.py Normal file
View File

@@ -0,0 +1,530 @@
import base64
import json
import logging
import os
import re
import tempfile
import time
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import backoff
import openai
import requests
from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from requests.exceptions import SSLError
logger = logging.getLogger("desktopenv.aguvis_agent")
# Function to encode the image
def encode_image(image_content):
return base64.b64encode(image_content).decode('utf-8')
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
return image
def save_to_tmp_img_file(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
image.save(tmp_img_path)
return tmp_img_path
# TODO: hardcoded screen size, need to be fixed
SCREEN_LOGIC_SIZE = (1280, 800)
wait_func = {"name": "WAIT", "description": "wait for a moment"}
done_func = {"name": "DONE", "description": "done with the task"}
fail_func = {"name": "FAIL", "description": "fail to complete the task"}
SYS_PROMPT = f"""You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.
"""
# TODO: let GPT not to predict non-atomic actions,
PLANNER_OUTPUT_FORMAT_SYS_PROMPT = """Your response should be formatted as follows:
Thought: *Describe your understanding of the current situation and consider what you need to do next.*
Action: *State the specific action you have decided to perform, described in natural language.*
**Note:** Please **do not** predict non-atomic actions. For example, for multi-step operations like "click then select the date," only predict the first atomic action (e.g., "click") at this time, and leave subsequent steps (like click for selecting the date) for the next planning phase.
**Example:**
Thought: To proceed with booking a hotel, I must first specify the check-in and check-out dates for the stay. Since the objective is to book a three-night stay starting from the 1st of June, I need to input these dates into the form to find available accommodations.
Action: Click on the "Choose date" button in the Check-in field to start selecting the stay dates.
Addtionally, you can use the following functions:
- {json.dumps(wait_func)}
- {json.dumps(done_func)}
- {json.dumps(fail_func)}
**Example 1:**
Thought: I need to wait for a moment before proceeding.
Action: WAIT
**Example 2:**
Thought: I have completed the task.
Action: DONE
"""
INSTRUCTION_PROMPT = """Please generate the next move according to the UI screenshot, instruction and previous actions.
Instruction: {instruction}
"""
ACTION_PROMPT = """Previous actions:
"""
def _pyautogui_code_to_absolute_coordinates(pyautogui_code_relative_coordinates, logical_screen_size=SCREEN_LOGIC_SIZE):
"""
Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.
"""
import re
import ast
width, height = logical_screen_size
pattern = r'(pyautogui\.\w+\([^\)]*\))'
matches = re.findall(pattern, pyautogui_code_relative_coordinates)
new_code = pyautogui_code_relative_coordinates
for full_call in matches:
func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
func_match = re.match(func_name_pattern, full_call, re.DOTALL)
if not func_match:
continue
func_name = func_match.group(1)
args_str = func_match.group(2)
try:
parsed = ast.parse(f"func({args_str})").body[0].value
parsed_args = parsed.args
parsed_keywords = parsed.keywords
except SyntaxError:
continue
function_parameters = {
'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
'moveRel': ['xOffset', 'yOffset', 'duration', 'tween', 'pause'],
'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
'dragRel': ['xOffset', 'yOffset', 'duration', 'button', 'mouseDownUp', 'pause'],
'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
}
func_base_name = func_name.split('.')[-1]
param_names = function_parameters.get(func_base_name, [])
args = {}
for idx, arg in enumerate(parsed_args):
if idx < len(param_names):
param_name = param_names[idx]
arg_value = ast.literal_eval(arg)
args[param_name] = arg_value
for kw in parsed_keywords:
param_name = kw.arg
arg_value = ast.literal_eval(kw.value)
args[param_name] = arg_value
updated = False
if 'x' in args:
try:
x_rel = float(args['x'])
x_abs = int(round(x_rel * width))
args['x'] = x_abs
updated = True
except ValueError:
pass
if 'y' in args:
try:
y_rel = float(args['y'])
y_abs = int(round(y_rel * height))
args['y'] = y_abs
updated = True
except ValueError:
pass
if 'xOffset' in args:
try:
x_rel = float(args['xOffset'])
x_abs = int(round(x_rel * width))
args['xOffset'] = x_abs
updated = True
except ValueError:
pass
if 'yOffset' in args:
try:
y_rel = float(args['yOffset'])
y_abs = int(round(y_rel * height))
args['yOffset'] = y_abs
updated = True
except ValueError:
pass
if updated:
reconstructed_args = []
for idx, param_name in enumerate(param_names):
if param_name in args:
arg_value = args[param_name]
if isinstance(arg_value, str):
arg_repr = f"'{arg_value}'"
else:
arg_repr = str(arg_value)
reconstructed_args.append(arg_repr)
else:
break
used_params = set(param_names[:len(reconstructed_args)])
for kw in parsed_keywords:
if kw.arg not in used_params:
arg_value = args[kw.arg]
if isinstance(arg_value, str):
arg_repr = f"{kw.arg}='{arg_value}'"
else:
arg_repr = f"{kw.arg}={arg_value}"
reconstructed_args.append(arg_repr)
new_args_str = ', '.join(reconstructed_args)
new_full_call = f"{func_name}({new_args_str})"
new_code = new_code.replace(full_call, new_full_call)
return new_code
def _parse(text, screen_logic_size=SCREEN_LOGIC_SIZE):
if text.lower().startswith("wait"):
return "WAIT", "WAIT"
elif text.lower().startswith("done"):
return "DONE", "DONE"
elif text.lower().startswith("fail"):
return "FAIL", "FAIL"
try:
lines = text.strip().split("\n")
lines = [line for line in lines if line.strip() != ""] # Remove empty lines
pyautogui_index = -1
for i, line in enumerate(lines):
if line.strip() == "assistantos" or line.strip().startswith("pyautogui"):
pyautogui_index = i
break
if pyautogui_index == -1:
print(f"Error: Could not parse response {text}")
return None, None # Return None or handle the error as needed
pyautogui_code_relative_coordinates = "\n".join(lines[pyautogui_index:])
# remove the assistantos prefix, ugly, fix later
pyautogui_code_relative_coordinates = pyautogui_code_relative_coordinates.replace("assistantos", "")
parsed_action = _pyautogui_code_to_absolute_coordinates(pyautogui_code_relative_coordinates, screen_logic_size)
return parsed_action
except Exception as e:
print(f"Error: Could not parse response {text}")
return None
def parse_planner_response(planner_response):
try:
# Split the response into lines for easier parsing
lines = planner_response.splitlines()
# Initialize variables to store thought and action
thought = None
action_description = None
# Iterate over each line to find the thought and action
for line in lines:
# Check if the line starts with 'Thought:'
if line.startswith("Thought:"):
# Extract the part after 'Thought: ' as the thought
thought = line[len("Thought: "):].strip()
# Check if the line starts with 'Action:'
elif line.startswith("Action:"):
# Extract the part after 'Action: ' as the action
action_description = line[len("Action: "):].strip()
# Return the thought and action as a dictionary
return thought, action_description
except Exception as e:
print(f"Error: Could not parse response {planner_response}")
return "", ""
class AguvisAgent:
def __init__(
self,
platform="ubuntu",
planner_model="gpt-4o",
executor_model="qwen-aguvis-7b",
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot",
):
self.platform = platform
self.planner_model = planner_model
self.executor_model = executor_model
assert self.executor_model is not None, "Executor model cannot be None"
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
assert action_space in ["pyautogui"], "Invalid action space"
assert observation_type in ["screenshot"], "Invalid observation type"
self.thoughts = []
self.actions = []
self.observations = []
def predict(self, instruction: str, obs: Dict) -> List:
"""
Predict the next action(s) based on the current observation.
"""
# Prepare the payload for the API call
messages = []
masks = None
self.observations.append(obs["screenshot"])
messages.append({
"role": "system",
"content": [
{
"type": "text",
"text": SYS_PROMPT
},
]
})
instruction_prompt = INSTRUCTION_PROMPT.format(instruction=instruction)
history_actions_prompt = ACTION_PROMPT
# thought, or so called action description
for i, action_description in enumerate(self.action_descriptions):
history_actions_prompt += f"Step {i+1}: {action_description}\n"
if len(history_actions_prompt) > 0:
instruction_prompt += "\n\n" + history_actions_prompt
base64_img = encode_image(obs["screenshot"])
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": instruction_prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_img}",
"detail": "high"
}
}
]
})
if self.planner_model is None:
# For now, we call the same model twice, one for planner and one for executor,
# This can be improved later when the inference stop token fixed
messages.append({
"role": "assistant",
"content": [
{
"type": "text",
"text": """<|recipient|>all\nAction: """
}
]
})
with open("messages_direct_executor.json", "w") as f:
f.write(json.dumps(messages, indent=4))
executor_response = self.call_llm({
"model": self.executor_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.executor_model)
logger.info("EXECUTOR RESPONSE: %s", executor_response)
pyautogui_action = _parse(executor_response)
thought, action_description = parse_planner_response("Action: " + executor_response)
self.thoughts.append(thought)
self.action_descriptions.append(action_description)
self.actions.append(pyautogui_action)
return executor_response, [pyautogui_action]
else:
# Planner stage
messages.append({
"role": "user",
"content": [
{
"type": "text",
"text": PLANNER_OUTPUT_FORMAT_SYS_PROMPT + "\nThought:"
}
]
})
planner_response = self.call_llm({
"model": self.planner_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.planner_model)
logger.info("PLANNER RESPONSE: %s", planner_response)
thought, action_description = parse_planner_response(planner_response)
self.thoughts.append(thought)
self.action_descriptions.append(action_description)
if "WAIT" in action_description:
self.actions.append("WAIT")
return planner_response, ["WAIT"]
elif "DONE" in action_description:
self.actions.append("DONE")
return planner_response, ["DONE"]
elif "FAIL" in action_description:
self.actions.append("FAIL")
return planner_response, ["FAIL"]
messages[1]["content"][0]["text"] = INSTRUCTION_PROMPT.format(instruction=action_description)
# pretend nothing happend with stronger planner model
messages[-1] = {
"role": "assistant",
"content": [
{
"type": "text",
# "text": f"""<|recipient|>all\nAction: {action_description}<|im_end|>\n<|im_start|>assistant<|recipient|>os"""
"text": f"""<|recipient|>os"""
}
]
}
with open("messages_executor.json", "w") as f:
f.write(json.dumps(messages, indent=4))
# Executor stage
executor_response = self.call_llm({
"model": self.executor_model,
"messages": messages,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"temperature": self.temperature
}, self.executor_model)
logger.info("EXECUTOR RESPONSE: %s", executor_response)
pyautogui_action = _parse(executor_response)
self.actions.append(pyautogui_action)
return planner_response + "\n\n" + executor_response, [pyautogui_action]
@backoff.on_exception(
backoff.constant,
# here you should add more model exceptions as you want,
# but you are forbidden to add "Exception", that is, a common type of exception
# because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
# Groq exceptions
# todo: check
),
interval=30,
max_tries=10
)
def call_llm(self, payload, model):
if model.startswith("gpt"):
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
# "Authorization": f"Bearer {os.environ['MIT_SPIDER_TOKEN']}"
}
logger.info("Generating content with GPT model: %s", model)
response = requests.post(
"https://api.openai.com/v1/chat/completions",
# "http://47.88.8.18:8088/v1/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
elif "aguvis" in model:
headers = {
"Content-Type": "application/json",
}
logger.info("Generating content with Aguvis model: %s", model)
response = requests.post(
"http://101.132.136.195:7908/v1/chat/completions",
headers=headers,
json=payload
)
if response.status_code != 200:
logger.error("Failed to call LLM: " + response.text)
time.sleep(5)
return ""
else:
return response.json()['choices'][0]['message']['content']
def reset(self):
self.thoughts = []
self.action_descriptions = []
self.actions = []
self.observations = []
if __name__ == "__main__":
agent = AguvisAgent()
with open("screenshot.png", "rb") as f:
screenshot = f.read()
agent.predict("Add a new paper to my list", {"screenshot": screenshot})
# relative_code = """pyautogui.typewrite("Hello, world! I have a float number 0.172")
# pyautogui.click(0, 1, n_click=1)
# pyautogui.moveTo(0.5342, 0.5342)
# """
# absolute_code = _pyautogui_code_to_absolute_coordinates(relative_code, logical_screen_size=(1920, 1080))
# print(absolute_code)

View File

@@ -25,6 +25,7 @@ pypdf
PyGetWindow
rapidfuzz
pyacoustid
pygame
opencv-python
ImageHash
scikit-image

361
run_multienv_aguvis.py Normal file
View File

@@ -0,0 +1,361 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.aguvis_agent import AguvisAgent
# import wandb
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--planner_model", type=str, default="gpt-4o")
parser.add_argument("--executor_model", type=str, default="/mnt/chuzhe.hby/hf_ckpts/qwen-aguvis-7b")
parser.add_argument("--temperature", type=float, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
args = parser.parse_args()
return args
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
"""Distribute tasks evenly across environments."""
# Flatten the tasks into a single list
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: AguvisAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment."""
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
"planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model),
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close()
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
# First, set up all environments
logger.info("Setting up all environments...")
envs = []
agents = []
for env_idx in range(args.num_envs):
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")
agent = AguvisAgent(
planner_model=args.planner_model,
executor_model=args.executor_model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
)
agents.append(agent)
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=agent.action_space,
screen_size=(args.screen_width, args.screen_height),
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
provider_name = "docker"
)
envs.append(env)
logger.info("All environments are ready. Starting parallel task execution...")
# Create a shared list for scores across processes
with Manager() as manager:
shared_scores = manager.list()
# Create and start processes for each environment
processes = []
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
p = Process(
target=run_env_tasks,
args=(env_idx, env, agent, env_tasks, args, shared_scores)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
# Convert shared list to regular list
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
except:
all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
"planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model),
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
"planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model),
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)