""" Script to run EvoCUA native agent model on OSWorld tasks. export AWS_ACCESS_KEY_ID="xx" export AWS_SECRET_ACCESS_KEY="xx" export AWS_REGION="xx" export AWS_SECURITY_GROUP_ID="xx" export AWS_SUBNET_ID="xx" export OPENAI_API_KEY="xxxx" export OPENAI_BASE_URL="xxxx" Example Usage (S2): python3 run_multienv_evocua.py \ --headless \ --provider_name aws \ --observation_type screenshot \ --model EvoCUA-S2 \ --result_dir ./evocua_s2 \ --test_all_meta_path evaluation_examples/test_nogdrive.json \ --max_steps 50 \ --num_envs 30 \ --temperature 0.01 \ --max_history_turns 4 \ --coordinate_type relative \ --resize_factor 32 \ --prompt_style S2 Example Usage (S1): python3 run_multienv_evocua.py \ --headless \ --provider_name aws \ --observation_type screenshot \ --model EvoCUA-S1 \ --result_dir ./evocua_s1 \ --test_all_meta_path evaluation_examples/test_nogdrive.json \ --max_steps 50 \ --num_envs 30 \ --max_history_turns 3 \ --coordinate_type qwen25 \ --max_tokens 10240 \ --resize_factor 28 \ --prompt_style S1 """ from __future__ import annotations import argparse import datetime import json import logging import os import sys import signal import time from typing import List from multiprocessing import Process, Manager, Queue from multiprocessing import current_process import lib_run_single from desktop_env.desktop_env import DesktopEnv from mm_agents.evocua.evocua_agent import EvoCUAAgent # Global variables for signal handling active_environments = [] processes = [] is_terminating = False # Thread-local storage for task context (works per-process in multiprocessing) import threading _task_context = threading.local() def get_task_context(): """Get current task context from thread-local storage.""" return getattr(_task_context, 'context', {'domain': None, 'example_id': None}) def set_task_context(domain: str, example_id: str): """Set current task context in thread-local storage.""" _task_context.context = {'domain': domain, 'example_id': example_id} def clear_task_context(): """Clear current task context.""" if hasattr(_task_context, 'context'): delattr(_task_context, 'context') class TaskContextFilter(logging.Filter): """Filter to add domain and example_id to log records.""" def filter(self, record): ctx = get_task_context() domain = ctx.get('domain') example_id = ctx.get('example_id') if domain and example_id: record.domain = domain record.example_id = example_id # Add prefix to message if hasattr(record, 'msg') and isinstance(record.msg, str): if not record.msg.startswith(f"[{domain}/{example_id}]"): record.msg = f"[{domain}/{example_id}] {record.msg}" else: record.domain = domain or "N/A" record.example_id = example_id or "N/A" return True # load the environment variables from .env file if os.path.exists(".env"): from dotenv import load_dotenv load_dotenv() def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation with EvoCUAAgent" ) # environment config parser.add_argument("--path_to_vm", type=str, default=None) parser.add_argument( "--headless", action="store_true", help="Run in headless machine" ) parser.add_argument( "--action_space", type=str, default="pyautogui", help="Action type" ) parser.add_argument( "--observation_type", choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], default="screenshot", help="Observation type", ) parser.add_argument("--sleep_after_execution", type=float, default=5.0) parser.add_argument("--max_steps", type=int, default=50) # evaluation config parser.add_argument( "--test_config_base_dir", 
type=str, default="evaluation_examples" ) # lm config parser.add_argument("--model", type=str, default="evocua", help="Model name.") parser.add_argument("--temperature", type=float, default=0.0) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=32768) parser.add_argument("--stop_token", type=str, default=None) parser.add_argument("--prompt_style", type=str, default="S2", choices=["S1", "S2"], help="Prompt style: 'S1' (structured reasoning) or 'S2' (tool calling)") parser.add_argument("--history_type", type=str, default="action_history", help="[S1] History type") parser.add_argument("--coordinate_type", type=str, default="relative", help="Coordinate type: relative, absolute, qwen25") parser.add_argument("--password", type=str, default="osworld-public-evaluation", help="VM Password") # Unified History Parameter parser.add_argument("--max_history_turns", type=int, default=3, help="Number of history turns to include") parser.add_argument("--resize_factor", type=int, default=32, help="Image resize factor (S1: 28, S2: 32)") # example config parser.add_argument("--domain", type=str, default="all") parser.add_argument( "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json" ) # logging related parser.add_argument("--result_dir", type=str, default="./results") parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO', help="Set the logging level") # aws config parser.add_argument( "--region", type=str, default="us-east-1", help="AWS region for the VM" ) parser.add_argument( "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" ) parser.add_argument( "--client_password", type=str, default="", help="Client password" ) parser.add_argument( "--screen_width", type=int, default=1920, help="Screen width" ) parser.add_argument( "--screen_height", type=int, default=1080, help="Screen height" ) args = parser.parse_args() return args args = config() logger = logging.getLogger() log_level = getattr(logging, args.log_level.upper()) logger.setLevel(log_level) datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") file_handler = logging.FileHandler( os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" ) debug_handler = logging.FileHandler( os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" ) stdout_handler = logging.StreamHandler(sys.stdout) file_handler.setLevel(logging.INFO) debug_handler.setLevel(logging.DEBUG) stdout_handler.setLevel(log_level) formatter = logging.Formatter( fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" ) file_handler.setFormatter(formatter) debug_handler.setFormatter(formatter) stdout_handler.setFormatter(formatter) # Add task context filter to all handlers task_filter = TaskContextFilter() file_handler.addFilter(task_filter) debug_handler.addFilter(task_filter) stdout_handler.addFilter(task_filter) stdout_handler.addFilter(logging.Filter("desktopenv")) logger.addHandler(file_handler) logger.addHandler(debug_handler) logger.addHandler(stdout_handler) logger = logging.getLogger("desktopenv.experiment") def distribute_tasks(test_all_meta: dict) -> List[tuple]: all_tasks = [] for domain, examples in test_all_meta.items(): for 
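# Example (illustrative): distribute_tasks flattens the meta JSON into
# (domain, example_id) tuples, e.g.
#     {"chrome": ["task-1", "task-2"], "gimp": ["task-3"]}
# becomes
#     [("chrome", "task-1"), ("chrome", "task-2"), ("gimp", "task-3")]
# (the task ids here are made-up placeholders).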
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    active_environments = []
    env = None
    try:
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)

        # Determine the snapshot based on the provider
        snapshot_name = "init_state"
        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            # Fall back to the 1920x1080 AMI if the requested resolution has no image
            ami_id = IMAGE_ID_MAP[REGION].get(
                screen_size, IMAGE_ID_MAP[REGION].get((1920, 1080))
            )
            snapshot_name = ami_id

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=snapshot_name,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type
            in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password,
        )
        active_environments.append(env)
        logger.info(f"Process {current_process().name} started.")

        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                # Queue drained (or the manager is gone): this worker is done.
                break
            domain, example_id = item
            set_task_context(domain, example_id)
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)

                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")

                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)

                # Initialize EvoCUAAgent
                agent = EvoCUAAgent(
                    model=args.model,
                    max_tokens=args.max_tokens,
                    top_p=args.top_p,
                    temperature=args.temperature,
                    action_space=args.action_space,
                    observation_type=args.observation_type,
                    max_steps=args.max_steps,
                    prompt_style=args.prompt_style,
                    max_history_turns=args.max_history_turns,
                    screen_size=(args.screen_width, args.screen_height),
                    coordinate_type=args.coordinate_type,
                    password=args.password,
                    resize_factor=args.resize_factor,
                )

                try:
                    lib_run_single.run_single_example_evocua(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(
                        f"Exception in {current_process().name} {domain}/{example_id}: {e}"
                    )
                    logger.error(traceback.format_exc())
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
            finally:
                clear_task_context()
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")
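# Example (illustrative): run_env_tasks above resolves each task config to
#     <test_config_base_dir>/examples/<domain>/<example_id>.json
# and expects it to contain at least an "instruction" field, e.g.
#     {"instruction": "Open the browser and search for ...", ...}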
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shut down environments."""
    global is_terminating, active_environments, processes
    if is_terminating:
        return
    is_terminating = True

    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    for env in active_environments:
        try:
            logger.info("Closing environment...")
            env.close()
            logger.info("Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")

    time.sleep(1)

    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                os.kill(p.pid, signal.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
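# Note: shutdown escalates in two phases: p.terminate() sends SIGTERM first,
# then any worker still alive after about one second is killed with SIGKILL.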
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    global processes

    logger.info("Args: %s", args)

    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")

    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)

        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}",
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")

        try:
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}",
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1

                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)

            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise

        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # Unfinished run: clear its partial artifacts so it reruns cleanly
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            with open(os.path.join(example_path, "result.txt"), "r") as f:
                                all_result.append(float(f.read()))
                        except Exception:
                            # Unreadable or non-numeric result counts as failure
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result
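# Example (illustrative): get_unfinished/get_result walk a results tree laid out as
#     <result_dir>/<action_space>/<observation_type>/<model>/<domain>/<example_id>/result.txt
# e.g.
#     ./evocua_s2/pyautogui/screenshot/EvoCUA-S2/chrome/task-1/result.txt
# where each result.txt holds a single float score ("task-1" is a made-up id).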
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        args = config()

        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")

        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )

        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        signal_handler(signal.SIGTERM, None)
    finally:
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        for p in processes:
            if p is not None and p.is_alive():
                try:
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        time.sleep(1)

        for p in processes:
            if p is not None and p.is_alive():
                try:
                    os.kill(p.pid, signal.SIGKILL)
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")