Merge pull request #203 from yuanmengqi/main

add openai cua agent code
Authored by Zilong Zhou, committed via GitHub on 2025-05-31 20:48:01 +08:00
3 changed files with 1171 additions and 0 deletions

lib_run_single.py

@@ -62,3 +62,58 @@ def setup_logger(example, example_result_dir):
runtime_logger.setLevel(logging.DEBUG)
runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
return runtime_logger
def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
runtime_logger = setup_logger(example, example_result_dir)
agent.reset(runtime_logger)
env.reset(task_config=example)
time.sleep(60) # Wait for the environment to be ready
obs = env._get_obs() # Get the initial observation
done = False
step_idx = 0
env.controller.start_recording()
while not done and step_idx < max_steps:
response, actions = agent.predict(
instruction,
obs
)
        # End the episode if the model's state is already inconsistent
        # (e.g. no computer_call was returned, so `actions` is empty).
        done = not response.get('state_correct', False)
for action in actions:
# Capture the timestamp before executing the action
action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
logger.info("Step %d: %s", step_idx + 1, action)
obs, reward, done, info, step_info = agent.step(action)
            # Also stop if the model flagged an inconsistent state.
            if not done and not response.get('state_correct', False):
                done = True
logger.info("Reward: %.2f", reward)
logger.info("Done: %s", done)
# Save screenshot and trajectory information
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
"wb") as _f:
_f.write(obs['screenshot'])
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(json.dumps({
"step_num": step_idx + 1,
"action_timestamp": action_timestamp,
"action": action,
"reward": reward,
"done": done,
"info": info,
"screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
}))
f.write("\n")
if done:
logger.info("The episode is done.")
break
step_idx += 1
result = env.evaluate()
logger.info("Result: %.2f", result)
scores.append(result)
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
f.write(f"{result}\n")
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
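# Sketch (not part of the harness): how the artifacts written above can be
# read back for analysis. Assumes the example_result_dir layout produced by
# run_single_example_openaicua: one traj.jsonl line per action plus result.txt.
#
#     import json, os
#     with open(os.path.join(example_result_dir, "traj.jsonl")) as f:
#         steps = [json.loads(line) for line in f]   # step_num, action, reward, ...
#     with open(os.path.join(example_result_dir, "result.txt")) as f:
#         score = float(f.read().strip())            # value appended to `scores`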

mm_agents/openai_cua_agent.py

@@ -0,0 +1,759 @@
import base64
import json
import logging
import os
import re
import tempfile
import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from groq import Groq
from requests.exceptions import SSLError
from typing import Any, Optional, Union, Tuple
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_OUT_TAG
logger = logging.getLogger("desktopenv.agent")
pure_text_settings = ['a11y_tree']
attributes_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
import ast  # typing names used below are already imported above
class Action:
"""Action class for the agent."""
def __init__(self, raw_action: Union[Dict, str], action_space: str):
"""Initialize the Action class.
Args:
raw_action: The raw action
action_space: The action space
"""
self._action_space = None
self._action = None
self.action_space = action_space
self.action = raw_action
@property
def action(self) -> str:
return self._action
@property
def action_space(self) -> str:
return self._action_space
@action_space.setter
def action_space(self, value: str):
"""
Set the action space for the agent.
Currently only supports 'pyautogui' as a valid action space.
Args:
value (str): The action space to set
Raises:
ValueError: If action_space is empty or invalid
"""
if not value:
raise ValueError("action_space is required")
if value not in ["pyautogui", "claude_computer_use"]:
raise ValueError(
"Invalid action space. Allowed spaces are: pyautogui")
self._action_space = value
@action.setter
def action(self, value: Optional[str]):
"""
Set the action for the agent.
For pyautogui action space, accepts special commands (WAIT, FAIL, DONE) or valid Python code.
For claude_computer_use action space, accepts a dict with keys "name", "input" and "id".
Args:
value (str | dict): The action to set
Raises:
ValueError: If action is empty or invalid
"""
if not value:
raise ValueError("action cannot be empty")
if self._action_space == "pyautogui":
self._action = value
# if value in ["WAIT", "FAIL", "DONE"]:
# self._action = value
# elif self._is_valid_python_code(value):
# self._action = value
# else:
# raise ValueError("Invalid action format for pyautogui")
elif self._action_space == "claude_computer_use":
self._action = value
# if self._is_valid_claude_computer_use_action(value):
# self._action = value
else:
raise ValueError(
f"Invalid action space: {self._action_space}, allowed spaces are: pyautogui, claude_computer_use")
def __str__(self) -> str:
"""Return a string representation of the Action instance.
Returns:
str: A string showing the action space and action value
"""
return f"Action(action_space='{self._action_space}', action='{self._action}')"
def get_action(self) -> Optional[str]:
"""Get the action.
Returns:
str: The action
"""
return self._action
def to_dict(self) -> Dict[str, Any]:
"""Convert the action to a dictionary.
Returns:
dict: The action as a dictionary
"""
return {"action_space": self._action_space, "action": self._action}
def _is_valid_python_code(self, code: str) -> bool:
"""
Validate if the given string is valid Python code syntax.
Args:
code (str): The code string to validate
Returns:
bool: True if code is valid Python syntax, False otherwise
"""
        try:
            ast.parse(code)
            return True
        except SyntaxError:
            return False
def _is_valid_claude_computer_use_action(self, action: Dict[str, Any]) -> bool:
"""Validate if the given action is valid for the claude_computer_use action space.
Args:
action: The action to validate
Returns:
bool: True if action is valid, False otherwise
"""
if not isinstance(action, dict):
raise ValueError("Invalid action format for claude_computer_use")
if not (action.get("name") and action.get("input") and action.get("id")):
raise ValueError(
"Invalid action format for claude_computer_use, 'name', 'input' and 'id' are required")
return True
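# Usage sketch for Action (illustrative; nothing in this file executes it):
#
#     act = Action("import pyautogui\npyautogui.click()", "pyautogui")
#     act.get_action()  # -> "import pyautogui\npyautogui.click()"
#     act.to_dict()     # -> {"action_space": "pyautogui", "action": "import pyautogui\n..."}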
class Timer:
"""Context manager for timing code blocks."""
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
self.duration = time.time() - self.start
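# Usage sketch for Timer, mirroring how predict() wraps the model call below:
#
#     with Timer() as t:
#         do_work()                     # hypothetical workload
#     print(f"took {t.duration:.2f}s")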
# Function to encode the image
def encode_image(image_content):
return base64.b64encode(image_content).decode('utf-8')
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
return image
def save_to_tmp_img_file(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
image.save(tmp_img_path)
return tmp_img_path
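# Round-trip sketch for the three helpers above; `raw` stands in for the PNG
# bytes of obs["screenshot"]:
#
#     data_str = "data:image/png;base64," + encode_image(raw)
#     img = encoded_img_to_pil_img(data_str)    # PIL.Image instance
#     path = save_to_tmp_img_file(data_str)     # path to a temp .png on disk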
class OpenAICUAAgent:
def __init__(
self,
env,
platform="ubuntu",
model="computer-use-preview",
max_tokens=1500,
top_p=0.9,
temperature=0.5,
action_space="pyautogui",
observation_type="screenshot_a11y_tree",
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
max_trajectory_length=100,
a11y_tree_max_tokens=10000
):
self.env = env
self.platform = platform
self.model = model
self.max_tokens = max_tokens
self.top_p = top_p
self.temperature = temperature
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.a11y_tree_max_tokens = a11y_tree_max_tokens
        self.cua_messages: List[Dict] = []
self.thoughts = []
self.actions = []
self.observations = []
self.tools = [{
"type": "computer_use_preview",
"display_width": 1920,
"display_height": 1080,
"environment": "linux" if platform == "ubuntu" else "windows"
}]
if observation_type == "screenshot":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "screenshot_a11y_tree":
if action_space == "computer_13":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
else:
raise ValueError("Invalid action space: " + action_space)
elif observation_type == "som":
if action_space == "computer_13":
raise ValueError("Invalid action space: " + action_space)
elif action_space == "pyautogui":
self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
else:
raise ValueError("Invalid action space: " + action_space)
else:
raise ValueError("Invalid experiment type: " + observation_type)
def _create_response(self, **kwargs: Any) -> Dict[str, Any]:
"""Create a response from the OpenAI API.
Args:
**kwargs: Additional arguments to pass to the API
Returns:
The API response as a dictionary
        Raises:
            Exception: If the API call still fails after 3 retries
"""
retry_count = 0
while retry_count < 3:
try:
from openai import OpenAI
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY_CUA"))
response = client.responses.create(
model=self.model,
input=self.cua_messages,
tools=self.tools,
reasoning={
"generate_summary": "concise",
},
truncation="auto",
)
logger.debug(f"Received successful response from OpenAI API")
logger.info(f"Response: {response}")
return response
except Exception as e:
logger.error(f"OpenAI API error: {str(e)}")
                # Refresh the screenshot in the last computer_call_output so the
                # retried request reflects the current screen state.
                new_screenshot = self.env._get_obs()
new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
retry_count += 1
time.sleep(1)
raise Exception("Failed to make OpenAI API call after 3 retries")
def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
"""Parse a response item from the OpenAI API.
Args:
item: The response item to parse
Returns:
The parsed item as either a string message or a dictionary containing action information,
or None if the item couldn't be parsed
"""
if item.type == "message":
if item.content is not None:
response = item.content[0] if isinstance(item.content, list) else item.content
response_type = response.type
response_text = response.text
logger.info(f"Received response text: {response_type} - {response_text}")
if response_type == "output_text":
return response_text
return None
return None
if item.type == "function_call":
return None
if item.type == "reasoning":
reasoning = item.summary
if isinstance(reasoning, list):
reasoning_item = reasoning[0]
reasoning_text = reasoning_item.text
reasoning_type = reasoning_item.type
if reasoning_type == "summary_text":
return reasoning_text
return None
return None
if item.type == "computer_call":
action = item.action
action_type = action.type
# Convert object attributes to dictionary
action_args = {}
for attr in dir(action):
if attr.startswith('_') or attr == 'type':
continue
try:
action_args[attr] = getattr(action, attr)
except AttributeError:
pass
logger.warning(f"Original Action: {action}")
result_code = self._convert_cua_action_to_pyautogui_action(action_type, action_args)
if result_code:
return {
"action_space": "pyautogui",
"action": result_code,
"pending_checks": item.pending_safety_checks,
"call_id": item.call_id
}
return None
def _convert_cua_action_to_pyautogui_action(self, action_type, args):
"""Convert a CUA action to a pyautogui action format
This function converts OpenAI CUA actions to pyautogui commands
for the Computer Agent Arena
Args:
action_type: Type of the CUA action
args: Arguments for the action
Returns:
String with pyautogui command code or None if the action can't be converted
"""
if not action_type:
logger.warning("Empty CUA action received")
return None
key_mapping = {
"/": "/",
"\\": "\\",
"alt": "alt",
"arrowdown": "down",
"arrowleft": "left",
"arrowright": "right",
"arrowup": "up",
"backspace": "backspace",
"capslock": "capslock",
"cmd": "command",
"ctrl": "ctrl",
"delete": "delete",
"end": "end",
"enter": "enter",
"esc": "esc",
"home": "home",
"insert": "insert",
"option": "option",
"pagedown": "pagedown",
"pageup": "pageup",
"shift": "shift",
"space": "space",
"super": "super",
"tab": "tab",
"win": "win",
}
try:
if action_type == "click":
x = args.get("x")
y = args.get("y")
button = args.get("button", "left")
# Validate coordinates
if x is None or y is None:
logger.warning(f"Invalid click coordinates: x={x}, y={y}")
return None
# Validate button
if button not in ["left", "middle", "right"]:
logger.warning(f"Invalid click button: {button}, defaulting to 'left'")
button = "left"
return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.click(button='{button}')"
elif action_type == "double_click":
x = args.get("x")
y = args.get("y")
# Validate coordinates
if x is None or y is None:
logger.warning(f"Invalid double_click coordinates: x={x}, y={y}")
return None
return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.doubleClick()"
elif action_type == "type":
text = args.get("text", "")
if not text:
logger.warning("Empty text for type action")
return "import pyautogui\n# Empty text, no action taken"
pattern = r"(?<!\\)'"
text = re.sub(pattern, r"\\'", text)
# 使用三重引号来确保字符串内容不会破坏格式
pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
logger.info(f"Pyautogui code: {pyautogui_code}")
return pyautogui_code
elif action_type == "keypress":
keys = args.get("keys", [])
if not keys:
logger.warning("Empty keys for keypress action")
return None
# Map to pyautogui keys and normalize
mapped_keys = []
for key in keys:
if isinstance(key, str):
# For Linux compatibility, handle the key mapping more thoroughly
mapped_key = key_mapping.get(key, key).lower()
# Also try lowercase version if not found
if mapped_key == key and key.lower() != key:
mapped_key = key_mapping.get(key.lower(), key)
mapped_keys.append(mapped_key)
if not mapped_keys:
return None
# Format for pyautogui.hotkey
keys_str = ", ".join([f"'{k}'" for k in mapped_keys])
return f"import pyautogui\npyautogui.hotkey({keys_str})"
elif action_type == "scroll":
x = args.get("x", None)
y = args.get("y", None)
scroll_x = args.get("scroll_x", 0)
scroll_y = args.get("scroll_y", 0)
# Normalize scroll values (Linux might use different scaling)
scroll_y = int(scroll_y) if scroll_y else 0
scroll_x = int(scroll_x) if scroll_x else 0
# Default to current mouse position if coordinates not provided
position_str = ""
if x is not None and y is not None:
position_str = f", x={x}, y={y}"
# Handle scroll direction
if scroll_y != 0:
# Convert to clicks - normalize the amount
clicks = scroll_y
return f"import pyautogui\npyautogui.scroll({clicks * (-1)}{position_str})"
elif scroll_x != 0:
# Convert to clicks - normalize the amount
clicks = scroll_x
return f"import pyautogui\npyautogui.hscroll({clicks * (-1)}{position_str})"
else:
logger.warning("Scroll action with zero scrolling amount")
return None
elif action_type == "move":
x = args.get("x")
y = args.get("y")
# Validate coordinates
if x is None or y is None:
logger.warning(f"Invalid move coordinates: x={x}, y={y}")
return None
return f"import pyautogui\npyautogui.moveTo({x}, {y})"
elif action_type == "drag":
if isinstance(args, dict):
path = args.get("path", None)
else:
path = args.path
if not path or len(path) < 2:
logger.warning("Drag path must have at least two points")
return None
# Extract start and end points
start = path[0]
end = path[-1]
# Validate path coordinates - handle different object formats
valid_path = True
for point in path:
if isinstance(point, (list, tuple)) and len(point) == 2:
continue
elif isinstance(point, dict) and 'x' in point and 'y' in point:
continue
elif hasattr(point, 'x') and hasattr(point, 'y'):
continue
else:
valid_path = False
break
if not valid_path:
logger.warning("Invalid path format for drag action")
return None
if len(path) == 2:
# Extract coordinates, handling different formats
if isinstance(start, (list, tuple)):
start_x, start_y = start
elif isinstance(start, dict):
start_x, start_y = start.get('x'), start.get('y')
else: # object with attributes
start_x, start_y = start.x, start.y
if isinstance(end, (list, tuple)):
end_x, end_y = end
elif isinstance(end, dict):
end_x, end_y = end.get('x'), end.get('y')
else: # object with attributes
end_x, end_y = end.x, end.y
return (
f"import pyautogui\n"
f"pyautogui.moveTo({start_x}, {start_y})\n"
f"pyautogui.dragTo({end_x}, {end_y}, duration=0.5, button='left')"
)
# For complex paths with multiple points
else:
actions = []
# Handle first point
if isinstance(path[0], (list, tuple)):
first_x, first_y = path[0]
elif isinstance(path[0], dict):
first_x, first_y = path[0].get('x'), path[0].get('y')
else: # object with attributes
first_x, first_y = path[0].x, path[0].y
actions.append(f"import pyautogui\npyautogui.moveTo({first_x}, {first_y})")
for i in range(1, len(path)):
if isinstance(path[i], (list, tuple)):
x, y = path[i]
elif isinstance(path[i], dict):
x, y = path[i].get('x'), path[i].get('y')
else: # object with attributes
x, y = path[i].x, path[i].y
actions.append(f"pyautogui.dragTo({x}, {y}, duration=0.2, button='left')")
return "\n".join(actions)
elif action_type == "wait":
ms = args.get("ms", 1000) # Default to 1000ms (1 second)
seconds = max(0.1, ms / 1000) # Ensure minimum wait time
return f"import time\ntime.sleep({seconds})"
elif action_type == "screenshot":
# Just return a wait action, as screenshots are handled automatically
return "import time\ntime.sleep(0.1) # Screenshot requested, no direct action needed"
else:
logger.warning(f"Unknown action type: {action_type}")
return None
except Exception as e:
logger.exception(f"Error converting CUA action to agent action: {e}")
return None
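    # Illustrative inputs and outputs of the converter above:
    #   ("click",    {"x": 100, "y": 200, "button": "left"})
    #     -> "import pyautogui\npyautogui.moveTo(100, 200)\npyautogui.click(button='left')"
    #   ("keypress", {"keys": ["ctrl", "c"]})
    #     -> "import pyautogui\npyautogui.hotkey('ctrl', 'c')"
    #   ("scroll",   {"x": 500, "y": 400, "scroll_y": 3})
    #     -> "import pyautogui\npyautogui.scroll(-3, x=500, y=400)"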
    def predict(self, instruction: str, obs: Dict) -> Tuple[Dict, List]:
        """
        Predict the next action(s) based on the current observation.
        Returns (predict_info, actions): predict_info carries model usage,
        the concatenated text responses and a state_correct flag; actions is
        a list of pyautogui action dicts.
        """
base64_image = encode_image(obs["screenshot"])
if self.cua_messages == []:
self.cua_messages.append({
"role": "user",
"content": [
{
"type": "input_image",
"image_url": f"data:image/png;base64,{base64_image}",
},
{
"type": "input_text",
"text": instruction
}
]
})
with Timer() as model_timer:
response = self._create_response()
self.cua_messages += response.output
actions = []
responses = []
action_exit = False
thought_exit = False
message_exit = False
for item in response.output:
parsed_item = self._handle_item(item)
if isinstance(parsed_item, dict) and parsed_item.get("action_space", None) == "pyautogui":
actions.append(parsed_item)
else:
responses.append(parsed_item)
if item.type == "computer_call":
action_exit = True
if item.type == "reasoning" and item.summary and item.summary[0].type == "summary_text":
thought_exit = True
if item.type == "message" and item.content and item.content[0].type == "output_text":
message_exit = True
responses = [item for item in responses if item is not None]
logger.info(f"Actions: {actions}")
logger.info(f"Responses: {responses}")
state_correct = False
# if action_exit and thought_exit:
# state_correct = True
if action_exit and not message_exit:
state_correct = True
if not state_correct:
logger.warning("The state of the agent is not correct, action_exit: %s, thought_exit: %s, message_exit: %s", action_exit, thought_exit, message_exit)
predict_info = {
"model_usage": {
"model_time": model_timer.duration,
"prompt_tokens": response.usage.input_tokens,
"completion_tokens": response.usage.output_tokens,
},
"messages": self.cua_messages,
"response": "\n".join(responses) if isinstance(responses, list) and all(isinstance(item, str) for item in responses) else "",
"state_correct": state_correct,
}
return predict_info, actions
    def reset(self, _logger=None):
        # Rebind the module-level logger so all agent logging goes to the
        # per-example runtime logger set up by the harness.
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
self.thoughts = []
self.actions = []
self.observations = []
self.cua_messages = []
    def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict, Dict]:
        """Execute an action in the environment.
        Args:
            action: The action to execute
        Returns:
            Tuple of (obs, reward, terminated, info, step_info), where
            step_info records the step duration and the executed action
        Raises:
            StepError: If the step execution fails
        """
try:
            if not action:
                logger.warning("Empty action received, terminating episode")
                # Match the normal return arity so the caller can always unpack
                # (obs, reward, terminated, info, step_info).
                return None, 0.0, True, {}, {}
logger.info(f"Executing action: {action.get('action_space', 'unknown')} - {action.get('action', '')[:50]}...")
with Timer() as step_timer:
# Convert the action to an Action object
step_action = Action(action.get("action", ""), self.action_space)
# Execute the action in the environment
obs, reward, terminated, info = self.env.step(step_action.get_action())
screenshot_base64 = encode_image(obs["screenshot"])
self.cua_messages.append({
"type": "computer_call_output",
"call_id": action["call_id"],
"acknowledged_safety_checks": action["pending_checks"],
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
})
logger.debug(f"Action completed in {step_timer.duration:.2f}s")
if terminated:
logger.info("Environment signaled termination")
return obs, reward, terminated, info, {
"step_time": step_timer.duration,
"action": action
}
except Exception as e:
logger.exception(f"Environment step failed: {str(e)}")
raise StepError(f"Failed to execute step: {str(e)}")
class StepError(Exception):
"""Exception raised when a step in the agent fails."""
pass
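# Shape of the Responses-API conversation the agent maintains, distilled from
# predict() and step() above (a sketch; the call_id value is hypothetical):
#
#     cua_messages = [
#         {"role": "user", "content": [
#             {"type": "input_image", "image_url": "data:image/png;base64,..."},
#             {"type": "input_text", "text": instruction},
#         ]},
#         # ...model output items (reasoning / computer_call) appended verbatim...
#         {"type": "computer_call_output",
#          "call_id": "call_abc123",
#          "acknowledged_safety_checks": [],
#          "output": {"type": "input_image",
#                     "image_url": "data:image/png;base64,..."}},
#     ]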

run_multienv_openaicua.py (new file)

@@ -0,0 +1,357 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.openai_cua_agent import OpenAICUAAgent
# import wandb
# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
os.makedirs("logs", exist_ok=True)  # FileHandler fails if the directory is missing
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
# environment config
parser.add_argument("--path_to_vm", type=str, default=None)
parser.add_argument(
"--headless", action="store_true", help="Run in headless machine"
)
parser.add_argument(
"--action_space", type=str, default="pyautogui", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="screenshot",
help="Observation type",
)
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15)
# agent config
parser.add_argument("--max_trajectory_length", type=int, default=3)
parser.add_argument(
"--test_config_base_dir", type=str, default="evaluation_examples"
)
# lm config
parser.add_argument("--model", type=str, default="gpt-4o")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--max_tokens", type=int, default=1500)
parser.add_argument("--stop_token", type=str, default=None)
# example config
parser.add_argument("--domain", type=str, default="all")
parser.add_argument(
"--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
)
# logging related
parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
# aws config
parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM"
)
args = parser.parse_args()
return args
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
"""Distribute tasks evenly across environments."""
# Flatten the tasks into a single list
all_tasks = []
for domain, examples in test_all_meta.items():
for example_id in examples:
all_tasks.append((domain, example_id))
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
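# Example split (illustrative):
#   distribute_tasks({"chrome": ["t1", "t2", "t3"]}, num_envs=2)
#   -> [{"chrome": ["t1", "t2"]}, {"chrome": ["t3"]}]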
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment."""
env = DesktopEnv(
path_to_vm=args.path_to_vm,
action_space=args.action_space,
provider_name="aws",
region="us-east-1",
snapshot_name="ami-05e7d7bd279ea4f14",
screen_size=(args.screen_width, args.screen_height),
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
)
agent = OpenAICUAAgent(
env=env,
model=args.model,
max_tokens=args.max_tokens,
top_p=args.top_p,
temperature=args.temperature,
action_space=args.action_space,
observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length,
)
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example_openaicua(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close()
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
logger.info("All environments are ready. Starting parallel task execution...")
# Create a shared list for scores across processes
with Manager() as manager:
shared_scores = manager.list()
# Create and start processes for each environment
processes = []
for env_idx, env_tasks in enumerate(distributed_tasks):
p = Process(
target=run_env_tasks,
args=(env_idx, env_tasks, args, shared_scores)
)
processes.append(p)
p.start()
# Wait for all processes to complete
for p in processes:
p.join()
# Convert shared list to regular list
scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
action_space, use_model, observation_type, result_dir, total_file_json
):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
return total_file_json
finished = {}
for domain in os.listdir(target_dir):
finished[domain] = []
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
if example_id == "onboard":
continue
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" not in os.listdir(example_path):
# empty all files under example_id
for file in os.listdir(example_path):
os.remove(os.path.join(example_path, file))
else:
finished[domain].append(example_id)
if not finished:
return total_file_json
for domain, examples in finished.items():
if domain in total_file_json:
total_file_json[domain] = [
x for x in total_file_json[domain] if x not in examples
]
return total_file_json
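# Resume semantics (illustrative): if results/pyautogui/screenshot/<model>/
# chrome contains examples "a" (with result.txt) and "b" (without), "a" is
# dropped from the remaining task list and the partial files under "b" are
# deleted so that example reruns from scratch.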
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
if not os.path.exists(target_dir):
print("New experiment, no result yet.")
return None
all_result = []
for domain in os.listdir(target_dir):
domain_path = os.path.join(target_dir, domain)
if os.path.isdir(domain_path):
for example_id in os.listdir(domain_path):
example_path = os.path.join(domain_path, example_id)
if os.path.isdir(example_path):
if "result.txt" in os.listdir(example_path):
# empty all files under example_id
try:
all_result.append(
float(
open(
os.path.join(example_path, "result.txt"), "r"
).read()
)
)
                        except (ValueError, OSError):
                            # Treat an unreadable or malformed result file as 0.
                            all_result.append(0.0)
if not all_result:
print("New experiment, no result yet.")
return None
else:
print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
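# Example invocation (hypothetical values):
#   python run_multienv_openaicua.py --model computer-use-preview \
#       --observation_type screenshot --max_steps 50 --num_envs 4 \
#       --result_dir ./results --region us-east-1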