@@ -62,3 +62,58 @@ def setup_logger(example, example_result_dir):
    runtime_logger.setLevel(logging.DEBUG)
    runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
    return runtime_logger


def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    runtime_logger = setup_logger(example, example_result_dir)
    agent.reset(runtime_logger)
    env.reset(task_config=example)
    time.sleep(60)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation
    done = False
    step_idx = 0
    env.controller.start_recording()
    while not done and step_idx < max_steps:
        response, actions = agent.predict(
            instruction,
            obs
        )

        done = not response.get('state_correct', False)

        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_idx + 1, action)
            obs, reward, done, info, step_info = agent.step(action)

            if not done:
                if not response.get('state_correct', False):
                    done = True

            logger.info("Reward: %.2f", reward)
            logger.info("Done: %s", done)
            # Save screenshot and trajectory information
            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                      "wb") as _f:
                _f.write(obs['screenshot'])
            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                f.write(json.dumps({
                    "step_num": step_idx + 1,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
                }))
                f.write("\n")
            if done:
                logger.info("The episode is done.")
                break
        step_idx += 1
    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
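
# Illustrative usage (hypothetical setup; `agent`, `env`, `args`, and `example`
# must already be constructed, as in run_multienv_openaicua.py):
#   scores = []
#   run_single_example_openaicua(agent, env, example, 15, example["instruction"],
#                                args, "./results/example_0", scores)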

mm_agents/openai_cua_agent.py (new file, 759 lines)
@@ -0,0 +1,759 @@
import ast
import base64
import json
import logging
import os
import re
import tempfile
import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
from groq import Groq
from requests.exceptions import SSLError

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
    SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
    SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
    SYS_PROMPT_IN_SOM_OUT_TAG

logger = logging.getLogger("desktopenv.agent")

pure_text_settings = ['a11y_tree']

attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces are defined in OSWorld; see desktop_env/server/main.py

class Action:
    """Action class for the agent."""

    def __init__(self, raw_action: Union[Dict, str], action_space: str):
        """Initialize the Action class.

        Args:
            raw_action: The raw action
            action_space: The action space
        """
        self._action_space = None
        self._action = None
        self.action_space = action_space
        self.action = raw_action

    @property
    def action(self) -> str:
        return self._action

    @property
    def action_space(self) -> str:
        return self._action_space

    @action_space.setter
    def action_space(self, value: str):
        """
        Set the action space for the agent.
        Currently supports 'pyautogui' and 'claude_computer_use' as valid action spaces.

        Args:
            value (str): The action space to set

        Raises:
            ValueError: If action_space is empty or invalid
        """
        if not value:
            raise ValueError("action_space is required")
        if value not in ["pyautogui", "claude_computer_use"]:
            raise ValueError(
                "Invalid action space. Allowed spaces are: pyautogui, claude_computer_use")
        self._action_space = value

    @action.setter
    def action(self, value: Optional[str]):
        """
        Set the action for the agent.
        For pyautogui action space, accepts special commands (WAIT, FAIL, DONE) or valid Python code.
        For claude_computer_use action space, accepts a dict with keys "name", "input" and "id".

        Args:
            value (str | dict): The action to set

        Raises:
            ValueError: If action is empty or invalid
        """
        if not value:
            raise ValueError("action cannot be empty")

        if self._action_space == "pyautogui":
            self._action = value
            # if value in ["WAIT", "FAIL", "DONE"]:
            #     self._action = value
            # elif self._is_valid_python_code(value):
            #     self._action = value
            # else:
            #     raise ValueError("Invalid action format for pyautogui")
        elif self._action_space == "claude_computer_use":
            self._action = value
            # if self._is_valid_claude_computer_use_action(value):
            #     self._action = value
        else:
            raise ValueError(
                f"Invalid action space: {self._action_space}, allowed spaces are: pyautogui, claude_computer_use")

    def __str__(self) -> str:
        """Return a string representation of the Action instance.

        Returns:
            str: A string showing the action space and action value
        """
        return f"Action(action_space='{self._action_space}', action='{self._action}')"

    def get_action(self) -> Optional[str]:
        """Get the action.

        Returns:
            str: The action
        """
        return self._action

    def to_dict(self) -> Dict[str, Any]:
        """Convert the action to a dictionary.

        Returns:
            dict: The action as a dictionary
        """
        return {"action_space": self._action_space, "action": self._action}

    def _is_valid_python_code(self, code: str) -> bool:
        """
        Validate if the given string is valid Python code syntax.

        Args:
            code (str): The code string to validate

        Returns:
            bool: True if code is valid Python syntax, False otherwise
        """
        try:
            ast.parse(code)
            return True
        except SyntaxError:
            return False

    def _is_valid_claude_computer_use_action(self, action: Dict[str, Any]) -> bool:
        """Validate if the given action is valid for the claude_computer_use action space.

        Args:
            action: The action to validate

        Returns:
            bool: True if action is valid

        Raises:
            ValueError: If the action is malformed
        """
        if not isinstance(action, dict):
            raise ValueError("Invalid action format for claude_computer_use")
        if not (action.get("name") and action.get("input") and action.get("id")):
            raise ValueError(
                "Invalid action format for claude_computer_use, 'name', 'input' and 'id' are required")
        return True
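
# Illustrative usage of Action (hypothetical values):
#   act = Action("import pyautogui\npyautogui.click(100, 200)", "pyautogui")
#   act.get_action()  # -> the raw pyautogui code string
#   act.to_dict()     # -> {"action_space": "pyautogui", "action": "import pyautogui\n..."}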


class Timer:
    """Context manager for timing code blocks."""

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.duration = time.time() - self.start
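
# Illustrative usage of Timer (`do_work` is a hypothetical function):
#   with Timer() as t:
#       do_work()
#   print(t.duration)  # elapsed wall-clock seconds as a float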


# Function to encode the image
def encode_image(image_content):
    return base64.b64encode(image_content).decode('utf-8')


def encoded_img_to_pil_img(data_str):
    base64_str = data_str.replace("data:image/png;base64,", "")
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data))

    return image


def save_to_tmp_img_file(data_str):
    base64_str = data_str.replace("data:image/png;base64,", "")
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data))

    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
    image.save(tmp_img_path)

    return tmp_img_path
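
# Illustrative round trip (hypothetical file "shot.png"): encode raw PNG bytes
# for the API, then decode the data URL back into a PIL image or a temp file.
#   b64 = encode_image(open("shot.png", "rb").read())
#   img = encoded_img_to_pil_img(f"data:image/png;base64,{b64}")
#   path = save_to_tmp_img_file(f"data:image/png;base64,{b64}")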


class OpenAICUAAgent:
    def __init__(
            self,
            env,
            platform="ubuntu",
            model="computer-use-preview",
            max_tokens=1500,
            top_p=0.9,
            temperature=0.5,
            action_space="pyautogui",
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
            max_trajectory_length=100,
            a11y_tree_max_tokens=10000
    ):
        self.env = env
        self.platform = platform
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_tokens = a11y_tree_max_tokens
        self.cua_messages: List[Dict] = []

        self.thoughts = []
        self.actions = []
        self.observations = []

        self.tools = [{
            "type": "computer_use_preview",
            "display_width": 1920,
            "display_height": 1080,
            "environment": "linux" if platform == "ubuntu" else "windows"
        }]
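        # Note: the declared display_width/display_height should match the VM's
        # actual screen size; the model emits coordinates in this declared space,
        # so a mismatch would misalign every click.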

        if observation_type == "screenshot":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "screenshot_a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "som":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        else:
            raise ValueError("Invalid observation type: " + observation_type)

    def _create_response(self, **kwargs: Any) -> Any:
        """Create a response from the OpenAI API.

        Args:
            **kwargs: Additional arguments to pass to the API

        Returns:
            The API response

        Raises:
            Exception: If the API call still fails after 3 retries
        """
        retry_count = 0
        while retry_count < 3:
            try:
                from openai import OpenAI
                client = OpenAI(api_key=os.getenv("OPENAI_API_KEY_CUA"))
                response = client.responses.create(
                    model=self.model,
                    input=self.cua_messages,
                    tools=self.tools,
                    reasoning={
                        "generate_summary": "concise",
                    },
                    truncation="auto",
                )
                logger.debug("Received successful response from OpenAI API")
                logger.info(f"Response: {response}")
                return response
            except Exception as e:
                logger.error(f"OpenAI API error: {str(e)}")
                # Refresh the screenshot attached to the last message before retrying
                new_screenshot = self.env._get_obs()
                new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
                self.cua_messages[-1]["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
                retry_count += 1
                time.sleep(1)
        raise Exception("Failed to make OpenAI API call after 3 retries")

    def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
        """Parse a response item from the OpenAI API.

        Args:
            item: The response item to parse

        Returns:
            The parsed item as either a string message or a dictionary containing action information,
            or None if the item couldn't be parsed
        """
        if item.type == "message":
            if item.content is not None:
                response = item.content[0] if isinstance(item.content, list) else item.content
                response_type = response.type
                response_text = response.text
                logger.info(f"Received response text: {response_type} - {response_text}")
                if response_type == "output_text":
                    return response_text
                return None
            return None

        if item.type == "function_call":
            return None

        if item.type == "reasoning":
            reasoning = item.summary
            if isinstance(reasoning, list):
                reasoning_item = reasoning[0]
                reasoning_text = reasoning_item.text
                reasoning_type = reasoning_item.type
                if reasoning_type == "summary_text":
                    return reasoning_text
                return None
            return None

        if item.type == "computer_call":
            action = item.action
            action_type = action.type
            # Convert object attributes to dictionary
            action_args = {}
            for attr in dir(action):
                if attr.startswith('_') or attr == 'type':
                    continue
                try:
                    action_args[attr] = getattr(action, attr)
                except AttributeError:
                    pass
            logger.warning(f"Original Action: {action}")
            result_code = self._convert_cua_action_to_pyautogui_action(action_type, action_args)
            if result_code:
                return {
                    "action_space": "pyautogui",
                    "action": result_code,
                    "pending_checks": item.pending_safety_checks,
                    "call_id": item.call_id
                }
            return None

    def _convert_cua_action_to_pyautogui_action(self, action_type, args):
        """Convert a CUA action to a pyautogui action format.

        This function converts OpenAI CUA actions to pyautogui commands
        for the Computer Agent Arena.

        Args:
            action_type: Type of the CUA action
            args: Arguments for the action

        Returns:
            String with pyautogui command code, or None if the action can't be converted
        """
        if not action_type:
            logger.warning("Empty CUA action received")
            return None

        key_mapping = {
            "/": "/",
            "\\": "\\",
            "alt": "alt",
            "arrowdown": "down",
            "arrowleft": "left",
            "arrowright": "right",
            "arrowup": "up",
            "backspace": "backspace",
            "capslock": "capslock",
            "cmd": "command",
            "ctrl": "ctrl",
            "delete": "delete",
            "end": "end",
            "enter": "enter",
            "esc": "esc",
            "home": "home",
            "insert": "insert",
            "option": "option",
            "pagedown": "pagedown",
            "pageup": "pageup",
            "shift": "shift",
            "space": "space",
            "super": "super",
            "tab": "tab",
            "win": "win",
        }
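        # Keys absent from key_mapping pass through unchanged (see the
        # `key_mapping.get(key, key)` lookup in the keypress branch below),
        # e.g. plain letters; listed keys are renamed, e.g. "arrowdown" -> "down"
        # and "cmd" -> "command".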

        try:
            if action_type == "click":
                x = args.get("x")
                y = args.get("y")
                button = args.get("button", "left")

                # Validate coordinates
                if x is None or y is None:
                    logger.warning(f"Invalid click coordinates: x={x}, y={y}")
                    return None

                # Validate button
                if button not in ["left", "middle", "right"]:
                    logger.warning(f"Invalid click button: {button}, defaulting to 'left'")
                    button = "left"

                return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.click(button='{button}')"

            elif action_type == "double_click":
                x = args.get("x")
                y = args.get("y")

                # Validate coordinates
                if x is None or y is None:
                    logger.warning(f"Invalid double_click coordinates: x={x}, y={y}")
                    return None

                return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.doubleClick()"

            elif action_type == "type":
                text = args.get("text", "")

                if not text:
                    logger.warning("Empty text for type action")
                    return "import pyautogui\n# Empty text, no action taken"

                pattern = r"(?<!\\)'"
                text = re.sub(pattern, r"\\'", text)

                # Use triple quotes so the string content does not break the format
                pyautogui_code = f"""import pyautogui\npyautogui.typewrite({repr(text)})"""
                logger.info(f"Pyautogui code: {pyautogui_code}")
                return pyautogui_code

            elif action_type == "keypress":
                keys = args.get("keys", [])

                if not keys:
                    logger.warning("Empty keys for keypress action")
                    return None

                # Map to pyautogui keys and normalize
                mapped_keys = []
                for key in keys:
                    if isinstance(key, str):
                        # For Linux compatibility, handle the key mapping more thoroughly
                        mapped_key = key_mapping.get(key, key).lower()
                        # Also try lowercase version if not found
                        if mapped_key == key and key.lower() != key:
                            mapped_key = key_mapping.get(key.lower(), key)
                        mapped_keys.append(mapped_key)

                if not mapped_keys:
                    return None

                # Format for pyautogui.hotkey
                keys_str = ", ".join([f"'{k}'" for k in mapped_keys])

                return f"import pyautogui\npyautogui.hotkey({keys_str})"

            elif action_type == "scroll":
                x = args.get("x", None)
                y = args.get("y", None)
                scroll_x = args.get("scroll_x", 0)
                scroll_y = args.get("scroll_y", 0)

                # Normalize scroll values (Linux might use different scaling)
                scroll_y = int(scroll_y) if scroll_y else 0
                scroll_x = int(scroll_x) if scroll_x else 0

                # Default to current mouse position if coordinates not provided
                position_str = ""
                if x is not None and y is not None:
                    position_str = f", x={x}, y={y}"

                # Handle scroll direction
                if scroll_y != 0:
                    # Convert to clicks - normalize the amount
                    clicks = scroll_y
                    return f"import pyautogui\npyautogui.scroll({clicks * (-1)}{position_str})"
                elif scroll_x != 0:
                    # Convert to clicks - normalize the amount
                    clicks = scroll_x
                    return f"import pyautogui\npyautogui.hscroll({clicks * (-1)}{position_str})"
                else:
                    logger.warning("Scroll action with zero scrolling amount")
                    return None

            elif action_type == "move":
                x = args.get("x")
                y = args.get("y")

                # Validate coordinates
                if x is None or y is None:
                    logger.warning(f"Invalid move coordinates: x={x}, y={y}")
                    return None

                return f"import pyautogui\npyautogui.moveTo({x}, {y})"

            elif action_type == "drag":
                if isinstance(args, dict):
                    path = args.get("path", None)
                else:
                    path = args.path

                if not path or len(path) < 2:
                    logger.warning("Drag path must have at least two points")
                    return None

                # Extract start and end points
                start = path[0]
                end = path[-1]

                # Validate path coordinates - handle different object formats
                valid_path = True
                for point in path:
                    if isinstance(point, (list, tuple)) and len(point) == 2:
                        continue
                    elif isinstance(point, dict) and 'x' in point and 'y' in point:
                        continue
                    elif hasattr(point, 'x') and hasattr(point, 'y'):
                        continue
                    else:
                        valid_path = False
                        break

                if not valid_path:
                    logger.warning("Invalid path format for drag action")
                    return None

                if len(path) == 2:
                    # Extract coordinates, handling different formats
                    if isinstance(start, (list, tuple)):
                        start_x, start_y = start
                    elif isinstance(start, dict):
                        start_x, start_y = start.get('x'), start.get('y')
                    else:  # object with attributes
                        start_x, start_y = start.x, start.y

                    if isinstance(end, (list, tuple)):
                        end_x, end_y = end
                    elif isinstance(end, dict):
                        end_x, end_y = end.get('x'), end.get('y')
                    else:  # object with attributes
                        end_x, end_y = end.x, end.y

                    return (
                        f"import pyautogui\n"
                        f"pyautogui.moveTo({start_x}, {start_y})\n"
                        f"pyautogui.dragTo({end_x}, {end_y}, duration=0.5, button='left')"
                    )
                # For complex paths with multiple points
                else:
                    actions = []
                    # Handle first point
                    if isinstance(path[0], (list, tuple)):
                        first_x, first_y = path[0]
                    elif isinstance(path[0], dict):
                        first_x, first_y = path[0].get('x'), path[0].get('y')
                    else:  # object with attributes
                        first_x, first_y = path[0].x, path[0].y

                    actions.append(f"import pyautogui\npyautogui.moveTo({first_x}, {first_y})")

                    for i in range(1, len(path)):
                        if isinstance(path[i], (list, tuple)):
                            x, y = path[i]
                        elif isinstance(path[i], dict):
                            x, y = path[i].get('x'), path[i].get('y')
                        else:  # object with attributes
                            x, y = path[i].x, path[i].y

                        actions.append(f"pyautogui.dragTo({x}, {y}, duration=0.2, button='left')")

                    return "\n".join(actions)

            elif action_type == "wait":
                ms = args.get("ms", 1000)  # Default to 1000ms (1 second)
                seconds = max(0.1, ms / 1000)  # Ensure minimum wait time

                return f"import time\ntime.sleep({seconds})"

            elif action_type == "screenshot":
                # Just return a wait action, as screenshots are handled automatically
                return "import time\ntime.sleep(0.1)  # Screenshot requested, no direct action needed"

            else:
                logger.warning(f"Unknown action type: {action_type}")
                return None

        except Exception as e:
            logger.exception(f"Error converting CUA action to agent action: {e}")
            return None
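
    # Illustrative conversion (hypothetical coordinates): a CUA action of
    #   {"type": "click", "x": 120, "y": 240, "button": "left"}
    # comes back from _convert_cua_action_to_pyautogui_action as:
    #   import pyautogui
    #   pyautogui.moveTo(120, 240)
    #   pyautogui.click(button='left')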

    def predict(self, instruction: str, obs: Dict) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """
        Predict the next action(s) based on the current observation.
        """

        base64_image = encode_image(obs["screenshot"])
        if self.cua_messages == []:
            self.cua_messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{base64_image}",
                    },
                    {
                        "type": "input_text",
                        "text": instruction
                    }
                ]
            })

        with Timer() as model_timer:
            response = self._create_response()
        self.cua_messages += response.output

        actions = []
        responses = []
        action_exit = False
        thought_exit = False
        message_exit = False
        for item in response.output:
            parsed_item = self._handle_item(item)
            if isinstance(parsed_item, dict) and parsed_item.get("action_space", None) == "pyautogui":
                actions.append(parsed_item)
            else:
                responses.append(parsed_item)
            if item.type == "computer_call":
                action_exit = True
            if item.type == "reasoning" and item.summary and item.summary[0].type == "summary_text":
                thought_exit = True
            if item.type == "message" and item.content and item.content[0].type == "output_text":
                message_exit = True
        responses = [item for item in responses if item is not None]

        logger.info(f"Actions: {actions}")
        logger.info(f"Responses: {responses}")

        state_correct = False
        # if action_exit and thought_exit:
        #     state_correct = True
        if action_exit and not message_exit:
            state_correct = True
        if not state_correct:
            logger.warning("The state of the agent is not correct, action_exit: %s, thought_exit: %s, message_exit: %s", action_exit, thought_exit, message_exit)

        predict_info = {
            "model_usage": {
                "model_time": model_timer.duration,
                "prompt_tokens": response.usage.input_tokens,
                "completion_tokens": response.usage.output_tokens,
            },
            "messages": self.cua_messages,
            "response": "\n".join(responses) if isinstance(responses, list) and all(isinstance(item, str) for item in responses) else "",
            "state_correct": state_correct,
        }

        return predict_info, actions
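
    # predict() thus returns (info_dict, actions). Note that the runner in
    # lib_run_single.py reads info_dict["state_correct"] and ends the episode
    # as soon as the model stops issuing computer calls.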

    def reset(self, _logger=None):
        global logger
        logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")

        self.thoughts = []
        self.actions = []
        self.observations = []
        self.cua_messages = []

    def step(self, action: Dict[str, Any]) -> Tuple[Any, float, bool, Dict[str, Any], Dict[str, Any]]:
        """Execute an action in the environment.

        Args:
            action: The action to execute

        Returns:
            Tuple containing:
                - obs: The new observation
                - reward: The reward from the environment
                - terminated: Whether the episode has terminated
                - info: Information about the step
                - step_info: Timing and action metadata

        Raises:
            StepError: If the step execution fails
        """
        try:
            if not action:
                logger.warning("Empty action received, terminating episode")
                return None, 0.0, True, {}, {}

            logger.info(f"Executing action: {action.get('action_space', 'unknown')} - {action.get('action', '')[:50]}...")

            with Timer() as step_timer:
                # Convert the action to an Action object
                step_action = Action(action.get("action", ""), self.action_space)
                # Execute the action in the environment
                obs, reward, terminated, info = self.env.step(step_action.get_action())

            screenshot_base64 = encode_image(obs["screenshot"])

            self.cua_messages.append({
                "type": "computer_call_output",
                "call_id": action["call_id"],
                "acknowledged_safety_checks": action["pending_checks"],
                "output": {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{screenshot_base64}",
                },
            })
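
            # Every executed computer_call is answered with a computer_call_output
            # message that echoes the call_id, acknowledges any pending safety
            # checks, and attaches the fresh post-action screenshot.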

            logger.debug(f"Action completed in {step_timer.duration:.2f}s")
            if terminated:
                logger.info("Environment signaled termination")

            return obs, reward, terminated, info, {
                "step_time": step_timer.duration,
                "action": action
            }

        except Exception as e:
            logger.exception(f"Environment step failed: {str(e)}")
            raise StepError(f"Failed to execute step: {str(e)}")


class StepError(Exception):
    """Exception raised when a step in the agent fails."""
    pass

run_multienv_openaicua.py (new file, 357 lines)
@@ -0,0 +1,357 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.openai_cua_agent import OpenAICUAAgent

# import wandb


# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# Ensure the log directory exists before attaching file handlers
os.makedirs("logs", exist_ok=True)

file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--max_trajectory_length", type=int, default=3)
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="gpt-4o")
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    args = parser.parse_args()
    return args
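
# Illustrative invocation (hypothetical values; all flags are defined above):
#   python run_multienv_openaicua.py --headless --model computer-use-preview \
#       --observation_type screenshot --max_steps 15 --num_envs 4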


def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
    """Distribute tasks evenly across environments."""
    # Flatten the tasks into a single list
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))

    # Calculate tasks per environment
    tasks_per_env = math.ceil(len(all_tasks) / num_envs)

    # Distribute tasks
    distributed_tasks = []
    for i in range(num_envs):
        env_tasks = {}
        start_idx = i * tasks_per_env
        end_idx = min((i + 1) * tasks_per_env, len(all_tasks))

        for domain, example_id in all_tasks[start_idx:end_idx]:
            if domain not in env_tasks:
                env_tasks[domain] = []
            env_tasks[domain].append(example_id)

        distributed_tasks.append(env_tasks)

    return distributed_tasks
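
# Worked example of the distribution above (hypothetical counts): 10 tasks
# across num_envs=3 gives tasks_per_env = ceil(10/3) = 4, so the three
# environments receive 4, 4, and 2 tasks respectively.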


def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
    """Run tasks for a single environment."""
    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=args.action_space,
        provider_name="aws",
        region="us-east-1",  # NOTE: hardcoded; the snapshot AMI below is region-specific
        snapshot_name="ami-05e7d7bd279ea4f14",
        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
        os_type="Ubuntu",
        require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
    )
    agent = OpenAICUAAgent(
        env=env,
        model=args.model,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        action_space=args.action_space,
        observation_type=args.observation_type,
        max_trajectory_length=args.max_trajectory_length,
    )
    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")

    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                args.model,
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)

            try:
                lib_run_single.run_single_example_openaicua(
                    agent,
                    env,
                    example,
                    args.max_steps,
                    example["instruction"],
                    args,
                    example_result_dir,
                    shared_scores,
                )
            except Exception as e:
                logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
                env.controller.end_recording(
                    os.path.join(example_result_dir, "recording.mp4")
                )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    f.write(
                        json.dumps(
                            {"Error": f"Exception in {domain}/{example_id}: {e}"}
                        )
                    )
                    f.write("\n")

    env.close()


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    logger.info("Args: %s", args)

    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)

    logger.info("All environments are ready. Starting parallel task execution...")

    # Create a shared list for scores across processes
    with Manager() as manager:
        shared_scores = manager.list()

        # Create and start processes for each environment
        processes = []
        for env_idx, env_tasks in enumerate(distributed_tasks):
            p = Process(
                target=run_env_tasks,
                args=(env_idx, env_tasks, args, shared_scores)
            )
            processes.append(p)
            p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

        # Convert shared list to regular list
        scores = list(shared_scores)

        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
        action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # Unfinished run: remove all files under example_id so it reruns cleanly
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json
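
# In short, get_unfinished drops every example that already has a result.txt
# and deletes partial outputs of crashed runs so they can be re-run cleanly.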


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the stored score; fall back to 0.0 if it can't be parsed
                        try:
                            with open(
                                    os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = config()

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    test_file_list = get_unfinished(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    get_result(
        args.action_space,
        args.model,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)