Wxy/opencua (#256)
* OpenCUA Agent code base * update url * debug, modify url input
This commit is contained in:
@@ -146,6 +146,61 @@ def run_single_example_openaicua(agent, env, example, max_steps, instruction, ar
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
|
||||
def run_single_example_opencua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
    """Run one benchmark task with an OpenCUA agent and record the trajectory.

    Resets the agent and environment, then alternates predict/step for up to
    `max_steps` steps, saving a per-action screenshot and a JSONL trajectory
    record, a screen recording, and the final evaluation score.

    Args:
        agent: agent exposing reset(logger) and predict(instruction, obs).
        env: desktop environment with reset/step/evaluate/_get_obs and a
            controller that supports start_recording/end_recording.
        example: task configuration passed to env.reset.
        max_steps: step budget for the episode.
        instruction: natural-language task instruction.
        args: CLI namespace (not read here; kept for interface parity with
            the sibling run_single_example_* functions).
        example_result_dir: directory for logs, screenshots and results.
        scores: shared list; the final score is appended to it.
    """
    runtime_logger = setup_logger(example, example_result_dir)
    agent.reset(runtime_logger)
    env.reset(task_config=example)
    time.sleep(60)  # Wait for the environment to be ready
    obs = env._get_obs()  # Get the initial observation
    done = False
    step_idx = 0
    env.controller.start_recording()
    while not done and step_idx < max_steps:
        response, actions, info_dict = agent.predict(instruction, obs)

        logger.info(f"Got Action: {actions}")
        # Stop the episode early when the model produced no usable action.
        if not actions or len(actions)==0 or actions[0]=="" or actions[0].lower().startswith("error"): # TODO: new added
            break

        for action in actions:
            # Capture the timestamp before executing the action
            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
            logger.info("Step %d: %s", step_idx + 1, action)

            obs, reward, done, info = env.step(action)
            time.sleep(3)  # give the UI a moment to settle before re-observing
            obs = env._get_obs()

            logger.info(f"Action {action} executed, reward: {reward}, done: {done}")
            # Save screenshot and trajectory information
            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
                      "wb") as _f:
                _f.write(obs['screenshot'])

            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                f.write(json.dumps({
                    "step_num": step_idx + 1,
                    "action_timestamp": action_timestamp,
                    "action": action,
                    "response": response,
                    "reward": reward,
                    "done": done,
                    "info": info,
                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
                }))
                f.write("\n")
            if done:
                logger.info("The episode is done.")
                break
            step_idx += 1

    result = env.evaluate()
    logger.info("Result: %.2f", result)
    scores.append(result)
    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
        f.write(f"{result}\n")
    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
725
mm_agents/opencua_agent.py
Normal file
725
mm_agents/opencua_agent.py
Normal file
@@ -0,0 +1,725 @@
|
||||
import base64
|
||||
from loguru import logger
|
||||
import re
|
||||
import time
|
||||
import math
|
||||
import httpx
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
import backoff
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
# System prompts for the four chain-of-thought levels selected by the agent's
# `cot_level` option ("l0".."l3"). Each level adds one more reasoning section
# to the required response format (L0: action only; L1: +Action description;
# L2: +Thought; L3: +Observation). NOTE: the "AGNET" spelling (sic, for
# "AGENT") is kept because these names are referenced verbatim elsewhere.

# L1: response contains an Action description plus the executable call.
AGNET_SYS_PROMPT_L1 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}""".strip()

# L2: adds a Thought section (progress assessment + next-action analysis)
# before the Action section.
AGNET_SYS_PROMPT_L2 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}".strip()

# L3: adds an Observation section (detailed screen-state description)
# before Thought and Action.
AGNET_SYS_PROMPT_L3 = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.\n\nFor each step, provide your response in this format:\n\nObservation:\n - Describe the current computer state based on the full screenshot in detail. \n - Application Context:\n - The active application\n - The active window or page\n - Overall layout and visible interface\n - Key Elements:\n - Menu items and toolbars \n - Buttons and controls\n - Text fields and content\n - Dialog boxes or popups\n - Error messages or notifications\n - Loading states\n - Other key elements\n - Describe any content, elements, options, information or clues that are possibly relevant to achieving the task goal, including their name, content, or shape (if possible).\n\nThought:\n - Step by Step Progress Assessment:\n - Analyze completed task parts and their contribution to the overall goal\n - Reflect on potential errors, unexpected results, or obstacles\n - If previous action was incorrect, predict a logical recovery step\n - Next Action Analysis:\n - List possible next actions based on current state\n - Evaluate options considering current state and previous actions\n - Propose most logical next action\n - Anticipate consequences of the proposed action\n - For Text Input Actions:\n - Note current cursor position\n - Consolidate repetitive actions (specify count for multiple keypresses)\n - Describe expected final text outcome\n - Use first-person perspective in reasoning\n\nAction:\n Provide clear, concise, and actionable instructions:\n - If the action involves interacting with a specific target:\n - Describe target explicitly without using coordinates\n - Specify element names when possible (use original language if non-English)\n - Describe features (shape, color, position) if name unavailable\n - For window control buttons, identify correctly (minimize \"—\", maximize \"□\", close \"X\")\n - if the action involves keyboard actions like 'press', 'write', 'hotkey':\n - Consolidate repetitive keypresses with count\n - Specify expected text outcome for typing actions\n\nFinally, output the action as PyAutoGUI code or the following functions:\n- {\"name\": \"computer.triple_click\", \"description\": \"Triple click on the screen\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"number\", \"description\": \"The x coordinate of the triple click\"}, \"y\": {\"type\": \"number\", \"description\": \"The y coordinate of the triple click\"}}, \"required\": [\"x\", \"y\"]}}\n- {\"name\": \"computer.terminate\", \"description\": \"Terminate the current task and report its completion status\", \"parameters\": {\"type\": \"object\", \"properties\": {\"status\": {\"type\": \"string\", \"enum\": [\"success\", \"fail\"], \"description\": \"The status of the task\"}}, \"required\": [\"status\"]}}\n".strip()

# L0: bare action output — no reasoning sections at all.
# NOTE(review): this prompt's terminate enum is ["success", "failure"] while
# L1-L3 use ["success", "fail"]; the parser only substring-matches "fail",
# so both spellings are handled, but the inconsistency looks unintentional.
AGNET_SYS_PROMPT_L0 = """You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.

For each step, output the action as PyAutoGUI code or the following functions:
- {"name": "computer.triple_click", "description": "Triple click on the screen", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The x coordinate of the triple click"}, "y": {"type": "number", "description": "The y coordinate of the triple click"}}, "required": ["x", "y"]}}
- {"name": "computer.terminate", "description": "Terminate the current task and report its completion status", "parameters": {"type": "object", "properties": {"status": {"type": "string", "enum": ["success", "failure"], "description": "The status of the task"}}, "required": ["status"]}}
""".strip()
||||
|
||||
# User-turn template carrying the task instruction.
# NOTE: the "INSTRUTION" typo (sic, for "INSTRUCTION") is kept — renaming the
# constant would break any external references to it.
INSTRUTION_TEMPLATE = "# Task Instruction:\n{instruction}\n\nPlease generate the next move according to the screenshot, task instruction and previous steps (if provided).\n"

# Header prepended to each replayed history step.
STEP_TEMPLATE = "# Step {step_num}:\n"
# History renderings of increasing verbosity; which one is used is chosen by
# the agent's `history_type` option in OpenCUAAgent.__init__.
ACTION_HISTORY_TEMPLATE = "## Action:\n{action}\n"
THOUGHT_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n"
OBSERVATION_HISTORY_TEMPLATE = "## Observation:\n{observation}\n\n## Thought:\n{thought}\n\n## Action:\n{action}\n"
# Used for the most recent `detail_history_length` steps: also replays the
# executed code block.
DETAIL_HISTORY_TEMPLATE = "## Thought:\n{thought}\n\n## Action:\n{action}\n\n## Code:\n{code}\n"
||||
|
||||
def encode_image(image_content):
    """Return the base64 text form (UTF-8 string) of raw image bytes."""
    encoded = base64.b64encode(image_content)
    return encoded.decode('utf-8')
||||
|
||||
def parse_response_to_cot_and_action(input_string, screen_size, coordinate_type) -> Tuple[str, List[str], dict]:
    """Parse a model response into (action text, [executable code], sections).

    Extracts the optional Observation / Thought / Action markdown sections and
    the final fenced code block, fixes common pyautogui argument mistakes, and
    projects coordinates to absolute pixels for `screen_size`.

    Returns:
        Tuple of (action description, single-element list with executable
        code, dict of extracted sections). Special cases: ("DONE", ["DONE"],
        {}) / ("FAIL", ["FAIL"], {}) when the model terminates the task, and
        (None, None, {}) when no code block could be extracted or parsing
        raised.
    """
    try:
        sections = {}

        # Termination is signalled by a computer.terminate call; its status
        # argument decides success ("DONE") vs failure ("FAIL").
        if "computer.terminate" in input_string.lower():
            code_blocks = re.findall(r'```(?:code)?\s*(.*?)\s*```', input_string, re.DOTALL | re.IGNORECASE)
            if code_blocks:
                last_code = code_blocks[-1].strip().lower()
                if "fail" in last_code:
                    return "FAIL", ["FAIL"], {}
                elif "success" in last_code:
                    return "DONE", ["DONE"], {}

            # No explicit status found: assume success.
            return "DONE", ["DONE"], {}

        # "## Observation:" up to the next section header (or end of text).
        obs_match = re.search(r'^##\s*Observation\s*:?[\n\r]+(.*?)(?=^##\s*Thought:|^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if obs_match:
            sections['observation'] = obs_match.group(1).strip()
        # logger.warning(f"Extracted Observation: {sections.get('observation', 'None')}")

        # "## Thought:" up to the next section header.
        thought_match = re.search(r'^##\s*Thought\s*:?[\n\r]+(.*?)(?=^##\s*Action:|^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if thought_match:
            sections['thought'] = thought_match.group(1).strip()
        # logger.warning(f"Extracted Thought: {sections.get('thought', 'None')}")

        # "## Action:" — the human-readable action description.
        action_match = re.search(r'^##\s*Action\s*:?[\n\r]+(.*?)(?=^##|\Z)', input_string, re.DOTALL | re.MULTILINE)
        if action_match:
            action = action_match.group(1).strip()
            sections['action'] = action.strip()
        # logger.warning(f"Extracted Action: {sections.get('action', 'None')}")

        # The LAST fenced code block is taken as the executable step.
        code_blocks = re.findall(r'```(?:python)?\s*(.*?)\s*```', input_string, re.DOTALL)
        if code_blocks:
            code = code_blocks[-1].strip()
            # Keep the untouched code (re-fenced) for history replay.
            sections['original_code'] = transform_agnet_action_to_code_block(code)
            # Fix misnamed keyword arguments, then rescale coordinates.
            corrected_code = correct_pyautogui_arguments(code)
            sections['code'] = corrected_code
            sections['code'] = project_coordinate_to_absolute_scale(corrected_code, screen_width=screen_size[0], screen_height=screen_size[1], coordinate_type=coordinate_type)
        # logger.warning(f"Extracted Code: {sections.get('code', 'None')}")

        # Without executable code the step is unusable.
        if 'code' not in sections:
            logger.error("Missing required action or code section")
            return None, None, {}

        # A missing Action section is tolerated (e.g. at cot_level l0).
        if 'action' not in sections: # TODO: new added
            sections['action'] = ""

        return sections['action'], [sections['code']], sections

    except Exception as e:
        logger.exception(f"Error parsing response: {str(e)}\nInput string: {input_string}")
        return None, None, {}
|
||||
|
||||
|
||||
def correct_pyautogui_arguments(code: str) -> str:
    """Repair commonly misnamed keyword arguments in generated pyautogui calls.

    Models sometimes emit ``pyautogui.write(text=...)`` or
    ``pyautogui.press(key=...)``; pyautogui expects ``message=`` for write and
    positional values for press/hotkey. Each line that is a single
    ``pyautogui.<func>(...)`` call gets its arguments rewritten; all other
    lines pass through (stripped) unchanged.
    """
    # func -> which keyword names are wrong, and what (if anything) they
    # should be renamed to; None means "turn the value positional".
    rewrites = {
        'write': {
            'incorrect_args': ['text', 'content'],
            'correct_args': [],
            'keyword_arg': 'message'
        },
        'press': {
            'incorrect_args': ['key', 'button'],
            'correct_args': [],
            'keyword_arg': None
        },
        'hotkey': {
            'incorrect_args': ['key1', 'key2', 'keys'],
            'correct_args': [],
            'keyword_arg': None
        },
    }

    out_lines = []
    for raw_line in code.strip().split('\n'):
        stripped = raw_line.strip()
        call = re.match(r'(pyautogui\.(\w+))\((.*)\)', stripped)
        if call is None:
            out_lines.append(stripped)
            continue

        prefix, func_name, args_str = call.group(1), call.group(2), call.group(3)
        spec = rewrites.get(func_name)
        if spec is None:
            # Not one of the functions we rewrite — keep the line as-is.
            out_lines.append(stripped)
            continue

        fixed_args = []
        for raw_arg in split_args(args_str):
            raw_arg = raw_arg.strip()
            kv = re.match(r'(\w+)\s*=\s*(.*)', raw_arg)
            if kv is None:
                # Positional argument: nothing to fix.
                fixed_args.append(raw_arg)
                continue
            name, value = kv.group(1), kv.group(2)
            if name not in spec['incorrect_args']:
                fixed_args.append(f'{name}={value}')
            elif spec['keyword_arg']:
                # Rename the keyword (e.g. text= -> message=).
                fixed_args.append(f"{spec['keyword_arg']}={value}")
            else:
                # Drop the bogus keyword and pass the value positionally.
                fixed_args.append(value)

        out_lines.append(f"{prefix}({', '.join(fixed_args)})")

    return '\n'.join(out_lines)
|
||||
|
||||
def split_args(args_str: str) -> List[str]:
    """Split an argument string on top-level commas, honouring quoted strings.

    Commas inside single- or double-quoted literals (with backslash escapes)
    do not split; the quote characters themselves are preserved in the output
    pieces, and surrounding whitespace is NOT stripped.
    """
    pieces: List[str] = []
    buf = ''
    in_string = False
    quote_char = ''
    last_char = ''
    for ch in args_str:
        # Track whether we are inside a string literal.
        if ch in ('"', "'"):
            if not in_string:
                in_string = True
                quote_char = ch
            elif last_char != '\\' and ch == quote_char:
                in_string = False
        if ch == ',' and not in_string:
            # Top-level comma: close off the current piece.
            pieces.append(buf)
            buf = ''
        else:
            buf += ch
        last_char = ch
    if buf:
        pieces.append(buf)
    return pieces
|
||||
|
||||
def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
    max_aspect_ratio_allowed: Optional[float] = None,
    size_can_be_smaller_than_factor: bool = False,
):
    """Compute a rescaled (height, width) satisfying three constraints.

    1. Both output dimensions are divisible by `factor`.
    2. The output pixel count lies within [`min_pixels`, `max_pixels`].
    3. The aspect ratio is preserved as closely as possible.

    Raises:
        ValueError: if a dimension is below `factor` while
            `size_can_be_smaller_than_factor` is False, or if the aspect
            ratio exceeds `max_aspect_ratio_allowed` (when given).
    """
    if not size_can_be_smaller_than_factor and (height < factor or width < factor):
        raise ValueError(
            f"height:{height} or width:{width} must be larger than factor:{factor} "
            f"(when size_can_be_smaller_than_factor is False)"
        )
    elif max_aspect_ratio_allowed is not None and max(height, width) / min(height, width) > max_aspect_ratio_allowed:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {max_aspect_ratio_allowed}, "
            f"got {max(height, width) / min(height, width)}"
            f"(when max_aspect_ratio_allowed is not None)"
        )

    # Round each side to the nearest multiple of `factor` (at least one).
    out_h = max(1, round(height / factor)) * factor
    out_w = max(1, round(width / factor)) * factor
    area = out_h * out_w

    if area > max_pixels:
        # Shrink uniformly, rounding DOWN so we stay under the cap.
        scale = math.sqrt((height * width) / max_pixels)
        out_h = max(1, math.floor(height / scale / factor)) * factor
        out_w = max(1, math.floor(width / scale / factor)) * factor
    elif area < min_pixels:
        # Grow uniformly, rounding UP so we reach the floor.
        scale = math.sqrt(min_pixels / (height * width))
        out_h = math.ceil(height * scale / factor) * factor
        out_w = math.ceil(width * scale / factor) * factor

    return out_h, out_w
|
||||
|
||||
def _coordinate_projection(x, y, screen_width, screen_height, coordinate_type):
|
||||
if coordinate_type == "relative":
|
||||
return int(round(x * screen_width)), int(round(y * screen_height))
|
||||
elif coordinate_type == "absolute":
|
||||
return x, y
|
||||
elif coordinate_type == "qwen25":
|
||||
if 0 <= x <= 1 and 0 <= y <= 1:
|
||||
# If already normalized, treat like "relative"
|
||||
return int(round(x * screen_width)), int(round(y * screen_height))
|
||||
|
||||
height, width = smart_resize(
|
||||
height=screen_height,
|
||||
width=screen_width,
|
||||
factor=28,
|
||||
min_pixels=3136,
|
||||
max_pixels=12845056
|
||||
)
|
||||
return int(x / width * screen_width), int(y / height * screen_height)
|
||||
elif coordinate_type == "relative1000":
|
||||
if screen_width == 0 or screen_height == 0:
|
||||
raise ValueError("Screen width and height must be greater than zero for relative1000 coordinates.")
|
||||
x_abs = int(round(x * screen_width / 1000))
|
||||
y_abs = int(round(y * screen_height / 1000))
|
||||
return x_abs, y_abs
|
||||
else:
|
||||
raise ValueError(f"Unsupported coordinate type: {coordinate_type}")
|
||||
|
||||
def project_coordinate_to_absolute_scale(pyautogui_code_relative_coordinates, screen_width, screen_height, coordinate_type="relative"):
    """
    Convert the relative coordinates in the pyautogui code to absolute coordinates based on the logical screen size.

    Finds every ``pyautogui.<func>(...)`` call in the code, parses its
    arguments with `ast`, projects x/y (or xOffset/yOffset) via
    `_coordinate_projection`, and splices the rebuilt call back into the
    code string. On any parse failure the input is returned unchanged.
    """
    import re
    import ast

    if coordinate_type not in ["relative", "relative1000", "absolute", "qwen25"]:
        raise ValueError(f"Invalid coordinate type: {coordinate_type}. Expected one of ['relative', 'relative1000', 'absolute', 'qwen25'].")

    # Match calls without nested parentheses inside the argument list.
    pattern = r'(pyautogui\.\w+\([^\)]*\))'
    matches = re.findall(pattern, pyautogui_code_relative_coordinates)

    new_code = pyautogui_code_relative_coordinates

    for full_call in matches:
        func_name_pattern = r'(pyautogui\.\w+)\((.*)\)'
        func_match = re.match(func_name_pattern, full_call, re.DOTALL)
        if not func_match:
            continue

        func_name = func_match.group(1)
        args_str = func_match.group(2)

        # Parse the argument list as a dummy call so literals and keywords
        # are handled robustly; bail out untouched on malformed syntax.
        try:
            parsed = ast.parse(f"func({args_str})").body[0].value
            parsed_args = parsed.args
            parsed_keywords = parsed.keywords
        except SyntaxError:
            return pyautogui_code_relative_coordinates

        # Positional-parameter names of the pyautogui functions that take
        # coordinates; other functions get an empty list and are left alone.
        function_parameters = {
            'click': ['x', 'y', 'clicks', 'interval', 'button', 'duration', 'pause'],
            'moveTo': ['x', 'y', 'duration', 'tween', 'pause'],
            'moveRel': ['xOffset', 'yOffset', 'duration', 'tween', 'pause'],
            'dragTo': ['x', 'y', 'duration', 'button', 'mouseDownUp', 'pause'],
            'dragRel': ['xOffset', 'yOffset', 'duration', 'button', 'mouseDownUp', 'pause'],
            'doubleClick': ['x', 'y', 'interval', 'button', 'duration', 'pause'],
        }

        func_base_name = func_name.split('.')[-1]

        param_names = function_parameters.get(func_base_name, [])

        # Map positional arguments onto their declared names.
        args = {}
        for idx, arg in enumerate(parsed_args):
            if idx < len(param_names):
                param_name = param_names[idx]
                arg_value = ast.literal_eval(arg)
                args[param_name] = arg_value

        # Keyword arguments override/extend the positional mapping.
        try:
            for kw in parsed_keywords:
                param_name = kw.arg
                arg_value = ast.literal_eval(kw.value)
                args[param_name] = arg_value
        except Exception as e:
            logger.error(f"Error parsing keyword arguments: {e}")
            return pyautogui_code_relative_coordinates

        updated = False
        # Project absolute-position coordinates (click/moveTo/dragTo/...).
        if 'x' in args and 'y' in args:
            try:
                x_rel = float(args['x'])
                y_rel = float(args['y'])
                x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
                logger.warning(f"Projecting coordinates: ({x_rel}, {y_rel}) to ({x_abs}, {y_abs}) using {coordinate_type} projection.")
                args['x'] = x_abs
                args['y'] = y_abs
                updated = True
            except ValueError:
                pass

        # Project relative-offset coordinates (moveRel/dragRel).
        # NOTE(review): offsets are projected with the same absolute-position
        # formula; for "relative" inputs that is a plain scale, which looks
        # intended, but confirm for "qwen25".
        if 'xOffset' in args and 'yOffset' in args:
            try:
                x_rel = float(args['xOffset'])
                y_rel = float(args['yOffset'])
                x_abs, y_abs = _coordinate_projection(x_rel, y_rel, screen_width, screen_height, coordinate_type)
                args['xOffset'] = x_abs
                args['yOffset'] = y_abs
                updated = True
            except ValueError:
                pass

        if updated:
            # Rebuild the call: positional args in declared order until the
            # first gap (a missing parameter stops the positional run) ...
            reconstructed_args = []
            for idx, param_name in enumerate(param_names):
                if param_name in args:
                    arg_value = args[param_name]
                    if isinstance(arg_value, str):
                        arg_repr = f"'{arg_value}'"
                    else:
                        arg_repr = str(arg_value)
                    reconstructed_args.append(arg_repr)
                else:
                    break

            # ... then re-emit original keywords not already covered.
            # NOTE(review): keywords whose names are not in param_names but
            # follow the positional run are re-emitted here; a keyword that
            # WAS consumed positionally is skipped — verify no argument can
            # be silently dropped for unusual call shapes.
            used_params = set(param_names[:len(reconstructed_args)])
            for kw in parsed_keywords:
                if kw.arg not in used_params:
                    arg_value = args[kw.arg]
                    if isinstance(arg_value, str):
                        arg_repr = f"{kw.arg}='{arg_value}'"
                    else:
                        arg_repr = f"{kw.arg}={arg_value}"
                    reconstructed_args.append(arg_repr)

            new_args_str = ', '.join(reconstructed_args)
            new_full_call = f"{func_name}({new_args_str})"
            new_code = new_code.replace(full_call, new_full_call)

    return new_code
|
||||
|
||||
def extract_positions_and_instructions(code, action) -> list[dict]:
    """
    Extracts all `(x, y)` coordinates (both positional and keyword arguments)
    and their associated preceding comments as instructions from Python code.
    If there are no comments, use the corresponding action instead.

    Args:
        code (str): The Python code as a string.
        action (str): The low-level action as a string.

    Returns:
        list[dict]: A list of dictionaries with extracted positions and instructions.
            - function (str): The pyautogui function name.
            - x (int or float): The x-coordinate.
            - y (int or float): The y-coordinate.
            - instruction (str): The preceding comment as an instruction.
    """
    def _num(text):
        # Preserve the int-vs-float distinction of the source literal.
        return float(text) if '.' in text else int(text)

    lines = code.splitlines()
    extracted = []
    # Fallback instruction when no comment precedes a coordinate call.
    preceding_comment = action

    for line in lines:
        stripped = line.strip()
        # A comment line becomes the instruction for the NEXT coordinate call.
        # Bug fix: the original reassigned `preceding_comment = action` at the
        # top of every iteration, clobbering the comment before the following
        # call line could use it — comments were effectively ignored.
        if stripped.startswith("#"):
            preceding_comment = stripped.lstrip("#").strip()  # Clean the comment
            continue

        # Match pyautogui functions with positional arguments.
        # Groups: 1=function, 2=x, 3=x fraction, 4=y, 5=y fraction.
        match_positional = re.match(r"(pyautogui\.\w+)\((\d+(\.\d+)?),\s*(\d+(\.\d+)?).*?\)", line)
        if match_positional:
            extracted.append({
                "function": match_positional.group(1),  # pyautogui function name
                "x": _num(match_positional.group(2)),   # x-coordinate
                # Bug fix: y previously fell back to int(group(3)) — the
                # FRACTIONAL part of x (or None) — instead of group(4).
                "y": _num(match_positional.group(4)),   # y-coordinate
                "instruction": preceding_comment,       # preceding comment or action
            })
            preceding_comment = action  # comment consumed; restore fallback
            continue

        # Match pyautogui functions with keyword arguments (x=..., y=...).
        match_keyword = re.match(r"(pyautogui\.\w+)\(.*?x=(\d+(\.\d+)?),\s*y=(\d+(\.\d+)?).*?\)", line)
        if match_keyword:
            extracted.append({
                "function": match_keyword.group(1),     # pyautogui function name
                "x": _num(match_keyword.group(2)),      # x-coordinate
                # Bug fix: same group(3)-vs-group(4) mix-up as above.
                "y": _num(match_keyword.group(4)),      # y-coordinate
                "instruction": preceding_comment,       # preceding comment or action
            })
            preceding_comment = action  # comment consumed; restore fallback

    logger.info(f"Grounding extracted:\n{extracted}")
    return extracted
|
||||
|
||||
def update_code_with_new_coordinates(code, updated_positions):
    """
    Replaces old `(x, y)` coordinates (both positional and keyword arguments)
    with updated ones in the code, handling multiple occurrences correctly.

    Args:
        code (str): The original Python code as a string.
        updated_positions (list): A list of dictionaries with updated positions.

    Returns:
        str: The updated Python code.
    """
    # TODO: the matching logics in 'update_code_with_new_coordinates'
    # and 'extract_positions_and_instructions' are not exactly the same
    rewritten = []
    cursor = 0  # index of the next entry in updated_positions to apply

    for line in code.splitlines():
        if cursor < len(updated_positions):
            upd = updated_positions[cursor]
            positional_pat = rf"{upd['function']}\(\d+(\.\d+)?, \d+(\.\d+)?"
            keyword_pat = rf"{upd['function']}\(.*?x=\d+(\.\d+)?, y=\d+(\.\d+)?"

            if re.search(positional_pat, line):
                # Swap in the new positional coordinates.
                line = re.sub(
                    positional_pat,
                    f"{upd['function']}({upd['x']}, {upd['y']}",
                    line,
                    count=1
                )
                cursor += 1
            elif re.search(keyword_pat, line):
                # Swap in the new keyword coordinates.
                line = re.sub(
                    keyword_pat,
                    f"{upd['function']}(x={upd['x']}, y={upd['y']}",
                    line,
                    count=1
                )
                cursor += 1

        rewritten.append(line)

    return "\n".join(rewritten)
|
||||
|
||||
def transform_agnet_action_to_code_block(action):
    """Wrap an action string in a fenced block for history replay.

    Control-flow calls (terminate / browser helpers) get a ``code`` fence;
    anything else is treated as executable Python and fenced ``python``.
    """
    control_markers = ("computer.terminate", "browser.select_option", "browser.clear")
    fence = "code" if any(marker in action for marker in control_markers) else "python"
    return f"```{fence}\n{action}\n```"
|
||||
|
||||
class OpenCUAAgent:
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
history_type: str,
|
||||
max_image_history_length: int,
|
||||
|
||||
platform="ubuntu",
|
||||
|
||||
max_tokens=1500,
|
||||
top_p=0.9,
|
||||
temperature=0,
|
||||
action_space="pyautogui",
|
||||
observation_type="screenshot",
|
||||
cot_level: str = "l2",
|
||||
|
||||
screen_size=(1920, 1080),
|
||||
coordinate_type: str = "relative", # relative or qwen25
|
||||
|
||||
detail_history_length: int = 0,
|
||||
**kwargs
|
||||
):
|
||||
self.platform = platform
|
||||
self.model = model
|
||||
assert self.model is not None, "Executor model cannot be None"
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.action_space = action_space
|
||||
self.observation_type = observation_type
|
||||
self.history_type = history_type
|
||||
self.coordinate_type = coordinate_type
|
||||
assert coordinate_type in ["relative", "relative1000", "absolute", "qwen25"]
|
||||
assert action_space in ["pyautogui"], "Invalid action space"
|
||||
assert observation_type in ["screenshot"], "Invalid observation type"
|
||||
assert history_type in ["action_history", "thought_history", "observation_history"]
|
||||
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.cots = []
|
||||
|
||||
self.cot_level = cot_level
|
||||
self.screen_size = screen_size
|
||||
self.max_image_history_length = max_image_history_length
|
||||
self.detail_history_length = detail_history_length
|
||||
|
||||
if history_type == "action_history":
|
||||
self.HISTORY_TEMPLATE = ACTION_HISTORY_TEMPLATE
|
||||
elif history_type == "thought_history":
|
||||
self.HISTORY_TEMPLATE = THOUGHT_HISTORY_TEMPLATE
|
||||
elif history_type == "observation_history":
|
||||
self.HISTORY_TEMPLATE = OBSERVATION_HISTORY_TEMPLATE
|
||||
else:
|
||||
raise ValueError(f"Invalid history type: {history_type}")
|
||||
|
||||
def reset(self, _logger=None):
|
||||
global logger
|
||||
logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
|
||||
|
||||
self.observations = []
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.image_summaries = []
|
||||
|
||||
def _scale_scroll_for_windows(self, code: str, factor: int = 50) -> str:
|
||||
""" pyautogui.scroll has a different scale on Ubuntu and Windows, multiple 'factor' when scrolling on Windows system"""
|
||||
if self.platform.lower() != "windows":
|
||||
return code
|
||||
|
||||
pattern_pos = re.compile(r'(pyautogui\.scroll\()\s*([-+]?\d+)\s*\)')
|
||||
code = pattern_pos.sub(lambda m: f"{m.group(1)}{int(m.group(2))*factor})", code)
|
||||
return code
|
||||
|
||||
def predict(self, instruction: str, obs: Dict, **kwargs) -> List:
    """
    Predict the next action(s) based on the current observation.

    Builds a chat request containing: a system prompt chosen by CoT level,
    the rolling step history (recent steps with screenshots, older steps as
    collapsed text), and the current screenshot plus instruction. Sends it
    to the model, parses the reply into pyautogui code, and records this
    step into the history buffers.

    Returns:
        (raw_response, pyautogui_actions, info_dict) — `pyautogui_actions`
        is empty when the model gave no response or no parseable action.
    """
    if "step_idx" in kwargs:
        logger.info(f"========= {self.model} Step {kwargs['step_idx']} =======")
    else:
        logger.info(f"========================== {self.model} ===================================")
    logger.info(f"Instruction: \n{instruction}")

    # The actual screenshot resolution may differ from the logical screen
    # size used for coordinate scaling; print both for debugging.
    image_bytes = BytesIO(obs['screenshot'])
    with Image.open(image_bytes) as img:
        print("Actual screen size", img.size)
        print("Logical screen size", self.screen_size)

    messages = []

    # Select the system prompt matching the configured chain-of-thought
    # level (l0 = bare action ... l3 = full observation/thought/action).
    if self.cot_level == "l3":
        messages.append({
            "role": "system",
            "content": AGNET_SYS_PROMPT_L3
        })
    elif self.cot_level == "l2":
        messages.append({
            "role": "system",
            "content": AGNET_SYS_PROMPT_L2
        })
    elif self.cot_level == "l1":
        messages.append({
            "role": "system",
            "content": AGNET_SYS_PROMPT_L1
        })
    elif self.cot_level == "l0":
        messages.append({
            "role": "system",
            "content": AGNET_SYS_PROMPT_L0
        })
    else:
        raise ValueError(f"Invalid COT level: {self.cot_level}")

    instruction_prompt = INSTRUTION_TEMPLATE.format(instruction=instruction)

    # Replay history: the last (max_image_history_length - 1) steps get
    # their screenshot plus a per-step assistant message; all older steps
    # are collapsed into a single text-only assistant message emitted at
    # the boundary index.
    history_step_texts = []
    for i in range(len(self.actions)):
        if i > len(self.actions) - self.max_image_history_length:
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}"}
                    }
                ]
            })

            # Most recent `detail_history_length` steps additionally include
            # the generated code via the detailed template.
            if self.detail_history_length > 0 and i >= len(self.actions) - self.detail_history_length:
                history_content = STEP_TEMPLATE.format(step_num=i+1) + DETAIL_HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i]['action'],
                    code=self.cots[i]['original_code']
                )
            else:
                history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                    observation=self.cots[i].get('observation'),
                    thought=self.cots[i].get('thought'),
                    action=self.cots[i]['action']
                )

            messages.append({
                "role": "assistant",
                "content": history_content
            })
        else:
            # Older step: accumulate text only; flush once at the boundary.
            history_content = STEP_TEMPLATE.format(step_num=i+1) + self.HISTORY_TEMPLATE.format(
                observation=self.cots[i].get('observation'),
                thought=self.cots[i].get('thought'),
                action=self.cots[i]['action']
            )
            history_step_texts.append(history_content)
            if i == len(self.actions) - self.max_image_history_length:
                messages.append({
                    "role":"assistant",
                    "content": "\n".join(history_step_texts)
                })

    # Current turn: latest screenshot plus the formatted instruction.
    messages.append({
        "role": "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{encode_image(obs['screenshot'])}"}
            },
            {
                "type": "text",
                "text": instruction_prompt
            }
        ]
    })

    # (debug) the message structure can be dumped here with image payloads
    # replaced by placeholders before logging.

    response = self.call_llm({
        "model": self.model,
        "messages": messages,
        "max_tokens": self.max_tokens,
        "top_p": self.top_p,
        "temperature": self.temperature
    }, self.model)

    logger.info(f"Model Output: \n\n{response}")
    if not response:
        logger.error("No response found in the response.")
        return response, [], {}

    low_level_instruction, pyautogui_actions, other_cot = parse_response_to_cot_and_action(response, self.screen_size, self.coordinate_type)
    if not pyautogui_actions:
        logger.error("No pyautogui actions found in the response.")
        return response, [], {}

    # Compensate for the Windows scroll-unit difference (no-op elsewhere).
    pyautogui_actions = [
        self._scale_scroll_for_windows(code) for code in pyautogui_actions
    ]

    # Record this step only after a successful parse, so the history
    # buffers (observations / actions / cots) stay index-aligned.
    self.observations.append(obs)
    logger.info(f"Parsed Low-level Action: \n{low_level_instruction}")
    logger.info(f"Parsed pyautogui Action: \n{pyautogui_actions}")

    self.actions.append(low_level_instruction)
    self.cots.append(other_cot)

    return response, pyautogui_actions, {}
||||
@backoff.on_exception(
    backoff.constant,
    # here you should add more model exceptions as you want,
    # but you are forbidden to add "Exception", that is, a common type of exception
    # because we want to catch this kind of Exception in the outside to ensure
    # each example won't exceed the time limit
    (
        Exception
    ),
    interval=30,
    max_tries=10
)
def call_llm(self, payload, model):
    """POST `payload` to the OpenCUA endpoint and return the completion text.

    Args:
        payload: Chat-completion request body (model, messages, sampling params).
        model: Model name; kept for interface parity with other agents.

    Returns:
        The assistant message content on success, or None when all 30
        HTTP attempts return a non-200 status.

    Raises:
        Exception: network/JSON errors propagate and are retried by the
            backoff decorator (constant 30s interval, up to 10 tries).
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.environ['OPENCUA_API_KEY']}"
    }

    for _ in range(30):
        response = httpx.post(
            os.environ['OPENCUA_URL'],
            headers=headers,
            json=payload,
            timeout=500,
            # SECURITY NOTE: verify=False disables TLS certificate checking.
            # Acceptable only for a trusted internal endpoint — confirm.
            verify=False
        )

        if response.status_code != 200:
            logger.error("Failed to call LLM: " + response.text)
            logger.error("Retrying...")
            time.sleep(5)
        else:
            return response.json()['choices'][0]['message']['content']

    # Bug fix: previously the function fell off the end of the loop and
    # implicitly returned None with no diagnostic. Make the exhausted-retries
    # outcome explicit; the caller (predict) already treats a falsy response
    # as "no action".
    logger.error("Failed to call LLM after 30 attempts; giving up.")
    return None
||||
528
run_multienv_opencua.py
Normal file
528
run_multienv_opencua.py
Normal file
@@ -0,0 +1,528 @@
|
||||
from __future__ import annotations
|
||||
import argparse
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import signal
|
||||
import time
|
||||
from typing import List, Dict
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from multiprocessing import Process, Manager
|
||||
import lib_run_single
|
||||
from desktop_env.desktop_env import DesktopEnv
|
||||
from mm_agents.opencua_agent import OpenCUAAgent
|
||||
|
||||
# Global variables for signal handling
active_environments = []  # DesktopEnv instances registered for cleanup on shutdown
processes = []  # child worker processes spawned by test()
is_terminating = False  # guard so signal_handler runs its cleanup only once

# import wandb

# load the environment variables from .env file
# (optional: only when a .env file is present in the working directory)
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()
|
||||
|
||||
# Logger Configs {{{ #
|
||||
def config() -> argparse.Namespace:
    """Parse and return the command-line arguments for the benchmark run.

    Groups: environment (VM/screen), agent (CoT/history/coordinates),
    evaluation paths, LM sampling, example selection, logging, and AWS.
    """
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=0.0)
    parser.add_argument("--max_steps", type=int, default=15)

    # agent config
    parser.add_argument("--cot_level", type=str, default="l2", help="CoT version: l0, l1, l2, l3")
    parser.add_argument("--history_type", type=str, default="action_history", help="History: action history")
    # Bug fix: the agent itself accepts {"relative", "relative1000",
    # "absolute", "qwen25"} (see OpenCUAAgent's assertion); the previous
    # choices list here rejected two valid values. Superset is
    # backward-compatible with existing invocations.
    parser.add_argument("--coordinate_type", type=str, default="relative", help="type of coordinate",
                        choices=["relative", "relative1000", "absolute", "qwen25"])
    parser.add_argument("--max_image_history_length", type=int, default=3)
    parser.add_argument("--detail_history_length", type=int, default=0, help="length of detail history")

    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--model", type=str, default="opencua")
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    args = parser.parse_args()
    return args
|
||||
|
||||
args = config()  # Get command line arguments first

# Root logger fans out to three handlers: an INFO file log, a DEBUG file
# log, and a filtered stdout stream at the user-selected level.
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

# NOTE(review): assumes a ./logs directory already exists — FileHandler
# raises FileNotFoundError otherwise; confirm the runner creates it.
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)

# ANSI-colored format: timestamp, level, module/line-process, message.
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)

# Only "desktopenv.*" loggers reach the console; files capture everything.
stdout_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #

# Module-level logger used by the rest of this script.
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
|
||||
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
    """Split the benchmark tasks into `num_envs` roughly equal chunks.

    Flattens ``{domain: [example_id, ...]}`` into (domain, example_id)
    pairs, slices the flat list into ``ceil(total / num_envs)``-sized
    contiguous chunks, and rebuilds one per-environment mapping per chunk.
    Always returns exactly `num_envs` dicts (trailing ones may be empty).
    """
    flat = [
        (domain, example_id)
        for domain, examples in test_all_meta.items()
        for example_id in examples
    ]

    chunk_size = math.ceil(len(flat) / num_envs)

    per_env: List[Dict] = []
    for idx in range(num_envs):
        lo = idx * chunk_size
        hi = min((idx + 1) * chunk_size, len(flat))
        bucket: Dict = {}
        for domain, example_id in flat[lo:hi]:
            bucket.setdefault(domain, []).append(example_id)
        per_env.append(bucket)

    return per_env
|
||||
|
||||
|
||||
def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments.

    Args:
        signum: The signal number delivered (SIGINT/SIGTERM).
        frame: The interrupted stack frame; its locals are inspected to
            recover the worker's environment list.
        env_idx: Zero-based index of this worker (used for log prefixes).
    """
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    # Get the active_environments from the caller's frame.
    # NOTE(review): this relies on the signal interrupting a frame where
    # a local `active_environments` is in scope (run_env_tasks); if it
    # fires elsewhere the list is empty and nothing is closed — confirm.
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])

    # Close environment in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    # Exit the worker process; cleanup above is best-effort.
    sys.exit(0)
|
||||
|
||||
|
||||
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
    """Run tasks for a single environment.

    Executed in a child process: creates one AWS-backed DesktopEnv and one
    OpenCUAAgent, then runs every (domain, example) assigned to this worker,
    appending each score to the manager-backed `shared_scores` list.

    Args:
        env_idx: Zero-based worker index (for logging).
        env_tasks: {domain: [example_id, ...]} slice produced by distribute_tasks.
        args: Parsed command-line arguments.
        shared_scores: Multiprocessing-shared list collecting example scores.
    """
    # Each process has its own list of active environments (shadows the
    # module-level global on purpose; process_signal_handler reads this
    # local out of the interrupted frame).
    active_environments = []
    env = None

    # Setup signal handlers for this process too
    signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
    signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))

    from desktop_env.providers.aws.manager import IMAGE_ID_MAP
    # NOTE(review): region is hard-coded and ignores args.region — confirm
    # whether --region was meant to be honored here.
    REGION = "us-east-1"
    env = DesktopEnv(
        path_to_vm=args.path_to_vm,
        action_space=args.action_space,

        provider_name="aws",
        region=REGION,
        snapshot_name=IMAGE_ID_MAP[REGION],

        screen_size=(args.screen_width, args.screen_height),
        headless=args.headless,
        os_type="Ubuntu",
        require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
    )
    active_environments.append(env)
    agent = OpenCUAAgent(
        env=env,
        model=args.model,
        max_tokens=args.max_tokens,
        top_p=args.top_p,
        temperature=args.temperature,
        action_space=args.action_space,
        observation_type=args.observation_type,
        cot_level=args.cot_level,
        history_type=args.history_type,
        screen_size=(args.screen_width, args.screen_height),
        coordinate_type=args.coordinate_type,
        max_image_history_length=args.max_image_history_length,
        detail_history_length=args.detail_history_length,
    )
    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")

    try:
        for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
            for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
                logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
                logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")

                # Results land in result_dir/action_space/obs_type/model/domain/example.
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)

                try:
                    lib_run_single.run_single_example_opencua(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    # One failing example must not kill the worker: log the
                    # traceback, stop the screen recording, and record the
                    # error in the trajectory file, then move on.
                    import traceback
                    logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")

                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
    finally:
        # This ensures the environment is closed even if there's an exception
        logger.info(f"Process {env_idx + 1} cleaning up environment...")
        try:
            env.close()
            logger.info(f"Process {env_idx + 1} environment closed successfully")
        except Exception as e:
            logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
|
||||
|
||||
|
||||
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.

    Main-process handler: closes every registered environment, asks each
    child process to terminate, waits briefly, then SIGKILLs stragglers
    and exits. Re-entry is guarded by the module-level `is_terminating`.
    """
    global is_terminating, active_environments, processes

    # Avoid duplicate handling
    if is_terminating:
        return

    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info(f"Closing environment...")
            env.close()
            logger.info(f"Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal
                os.kill(p.pid, signal.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
|
||||
|
||||
|
||||
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Fan the benchmark out across `args.num_envs` worker processes.

    Distributes tasks, starts one run_env_tasks process per slice, waits
    for completion, and logs the average of the scores the workers pushed
    into a manager-backed shared list.
    """
    global processes
    logger.info("Args: %s", args)

    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)

    logger.info("All environments are ready. Starting parallel task execution...")

    # Create a shared list for scores across processes
    with Manager() as manager:
        shared_scores = manager.list()

        # Create and start processes for each environment
        # (stored in the module-level `processes` so signal_handler can
        # terminate them on shutdown).
        processes = []
        for env_idx, env_tasks in enumerate(distributed_tasks):
            p = Process(
                target=run_env_tasks,
                args=(env_idx, env_tasks, args, shared_scores)
            )
            processes.append(p)
            p.start()
            logger.info(f"Started process {p.name} with PID {p.pid}")

        try:
            # Wait for all processes to complete
            for p in processes:
                p.join()
                logger.info(f"Process {p.name} completed")
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            # Let the signal handler do the cleanup
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            # Ensure cleanup happens
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise

        # Convert shared list to regular list before the manager shuts down
        scores = list(shared_scores)

        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
|
||||
|
||||
|
||||
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Drop already-finished examples from `total_file_json` (mutated in place).

    An example counts as finished when its result directory contains a
    ``result.txt``. Partial runs (directory exists but no result.txt) have
    their leftover files deleted so the example reruns from scratch.
    Returns the (possibly filtered) task mapping.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        # Nothing has run yet — everything is still pending.
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            if example_id == "onboard":
                continue
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" in os.listdir(example_path):
                finished[domain].append(example_id)
            else:
                # Incomplete run: wipe partial artifacts so it reruns cleanly.
                for leftover in os.listdir(example_path):
                    os.remove(os.path.join(example_path, leftover))

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json
|
||||
|
||||
|
||||
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect all recorded scores and print the current success rate.

    Walks result_dir/action_space/observation_type/use_model looking for
    per-example ``result.txt`` files; unreadable/unparseable ones count
    as 0.0.

    Returns:
        List of float scores, or None when no results exist yet.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Bug fix: close the file (was leaked via
                        # open(...).read()) and catch only the errors a
                        # score read can raise instead of a bare `except:`
                        # that also swallowed KeyboardInterrupt/SystemExit.
                        try:
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except (OSError, ValueError):
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # Load the full benchmark task list ({domain: [example_id, ...]}).
        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        # Skip examples that already have a result.txt from a prior run.
        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        # Print the success rate accumulated so far before resuming.
        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info(f"Closing environment in final cleanup...")
                    env.close()
                    logger.info(f"Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")
|
||||
Reference in New Issue
Block a user