From 46c9407879b1e02d314efb17c84152f214eb8a12 Mon Sep 17 00:00:00 2001 From: yuanmengqi Date: Sun, 20 Jul 2025 07:57:27 +0000 Subject: [PATCH] Clean elder version of opencua experiment runner --- run_multienv_openaicua_old.py | 533 ---------------------------------- 1 file changed, 533 deletions(-) delete mode 100644 run_multienv_openaicua_old.py diff --git a/run_multienv_openaicua_old.py b/run_multienv_openaicua_old.py deleted file mode 100644 index c4eb18c..0000000 --- a/run_multienv_openaicua_old.py +++ /dev/null @@ -1,533 +0,0 @@ -from __future__ import annotations -import argparse -import datetime -import json -import logging -import os -import sys -import signal -import time -from typing import List, Dict -import math -from tqdm import tqdm -from multiprocessing import Process, Manager -import lib_run_single -from desktop_env.desktop_env import DesktopEnv -from mm_agents.openai_cua_agent import OpenAICUAAgent - -# Global variables for signal handling -active_environments = [] -processes = [] -is_terminating = False - -# import wandb - -# load the environment variables from .env file -if os.path.exists(".env"): - from dotenv import load_dotenv - load_dotenv() - -# Logger Configs {{{ # -def config() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Run end-to-end evaluation on the benchmark" - ) - - # environment config - parser.add_argument("--path_to_vm", type=str, default=None) - parser.add_argument( - "--headless", action="store_true", help="Run in headless machine" - ) - parser.add_argument( - "--action_space", type=str, default="pyautogui", help="Action type" - ) - parser.add_argument( - "--observation_type", - choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], - default="screenshot", - help="Observation type", - ) - parser.add_argument("--sleep_after_execution", type=float, default=0.0) - parser.add_argument("--max_steps", type=int, default=15) - - # agent config - parser.add_argument("--max_trajectory_length", type=int, default=3) - parser.add_argument( - "--test_config_base_dir", type=str, default="evaluation_examples" - ) - - # lm config - parser.add_argument("--model", type=str, default="gpt-4o") - parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--max_tokens", type=int, default=1500) - parser.add_argument("--stop_token", type=str, default=None) - - # example config - parser.add_argument("--domain", type=str, default="all") - parser.add_argument( - "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" - ) - - # logging related - parser.add_argument("--result_dir", type=str, default="./results") - parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") - parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], - default='INFO', help="Set the logging level") - # aws config - parser.add_argument( - "--region", type=str, default="us-east-1", help="AWS region for the VM" - ) - parser.add_argument( - "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" - ) - parser.add_argument( - "--client_password", type=str, default="", help="Client password" - ) - parser.add_argument( - "--screen_width", type=int, default=1920, help="Screen width" - ) - parser.add_argument( - "--screen_height", type=int, default=1080, help="Screen height" - ) - args = parser.parse_args() - return args - -args = config() # Get command line arguments first - -logger = logging.getLogger() -log_level = getattr(logging, args.log_level.upper()) -logger.setLevel(log_level) - -datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - -file_handler = logging.FileHandler( - os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" -) -debug_handler = logging.FileHandler( - os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" -) -stdout_handler = logging.StreamHandler(sys.stdout) - -file_handler.setLevel(logging.INFO) -debug_handler.setLevel(logging.DEBUG) -stdout_handler.setLevel(log_level) - -formatter = logging.Formatter( - fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" -) -file_handler.setFormatter(formatter) -debug_handler.setFormatter(formatter) -stdout_handler.setFormatter(formatter) - -stdout_handler.addFilter(logging.Filter("desktopenv")) - -logger.addHandler(file_handler) -logger.addHandler(debug_handler) -logger.addHandler(stdout_handler) -# }}} Logger Configs # - -logger = logging.getLogger("desktopenv.experiment") - - -def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: - """Distribute tasks evenly across environments.""" - # Flatten the tasks into a single list - all_tasks = [] - for domain, examples in test_all_meta.items(): - for example_id in examples: - all_tasks.append((domain, example_id)) - - # Calculate tasks per environment - tasks_per_env = math.ceil(len(all_tasks) / num_envs) - - # Distribute tasks - distributed_tasks = [] - for i in range(num_envs): - env_tasks = {} - start_idx = i * tasks_per_env - end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) - - for domain, example_id in all_tasks[start_idx:end_idx]: - if domain not in env_tasks: - env_tasks[domain] = [] - env_tasks[domain].append(example_id) - - distributed_tasks.append(env_tasks) - - return distributed_tasks - - -def process_signal_handler(signum, frame, env_idx): - """Signal handler for child processes to gracefully shut down their environments.""" - logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") - - # Get the active_environments from the caller's frame - local_vars = frame.f_locals - active_environments = local_vars.get('active_environments', []) - - # Close environment in the current process context - for env in active_environments: - if env is not None: - try: - logger.info(f"Process {env_idx + 1} closing environment...") - env.close() - logger.info(f"Process {env_idx + 1} environment closed successfully") - except Exception as e: - logger.error(f"Process {env_idx + 1} error closing environment: {e}") - - logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") - sys.exit(0) - - -def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list): - """Run tasks for a single environment.""" - # Each process has its own list of active environments - active_environments = [] - env = None - - # Setup signal handlers for this process too - signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) - signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) - - from desktop_env.providers.aws.manager import IMAGE_ID_MAP - REGION = args.region - screen_size = (args.screen_width, args.screen_height) - ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) - env = DesktopEnv( - path_to_vm=args.path_to_vm, - action_space=args.action_space, - provider_name=args.provider_name, - region=REGION, - snapshot_name=ami_id, - screen_size=screen_size, - headless=args.headless, - os_type="Ubuntu", - require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], - enable_proxy=True, - client_password=args.client_password - ) - active_environments.append(env) - agent = OpenAICUAAgent( - env=env, - model=args.model, - max_tokens=args.max_tokens, - top_p=args.top_p, - temperature=args.temperature, - action_space=args.action_space, - observation_type=args.observation_type, - max_trajectory_length=args.max_trajectory_length, - client_password=args.client_password, - provider_name=args.provider_name, - screen_width=args.screen_width, - screen_height=args.screen_height - ) - logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") - - try: - for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): - for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): - config_file = os.path.join( - args.test_config_base_dir, f"examples/{domain}/{example_id}.json" - ) - with open(config_file, "r", encoding="utf-8") as f: - example = json.load(f) - - logger.info(f"[Env {env_idx+1}][Domain]: {domain}") - logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") - logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") - - example_result_dir = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - domain, - example_id, - ) - os.makedirs(example_result_dir, exist_ok=True) - - try: - lib_run_single.run_single_example_openaicua( - agent, - env, - example, - args.max_steps, - example["instruction"], - args, - example_result_dir, - shared_scores, - ) - except Exception as e: - import traceback - logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") - logger.error(traceback.format_exc()) - try: - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") - ) - except Exception as rec_e: - logger.error(f"Failed to end recording: {rec_e}") - - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write( - json.dumps( - {"Error": f"{domain}/{example_id} - {e}"} - ) - ) - f.write("\n") - finally: - # This ensures the environment is closed even if there's an exception - logger.info(f"Process {env_idx + 1} cleaning up environment...") - try: - env.close() - logger.info(f"Process {env_idx + 1} environment closed successfully") - except Exception as e: - logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}") - - -def signal_handler(signum, frame): - """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" - global is_terminating, active_environments, processes - - # Avoid duplicate handling - if is_terminating: - return - - is_terminating = True - logger.info(f"Received signal {signum}. Gracefully shutting down...") - - # Close all registered environments in the main process - for env in active_environments: - try: - logger.info(f"Closing environment...") - env.close() - logger.info(f"Environment closed successfully") - except Exception as e: - logger.error(f"Error closing environment: {e}") - - # Send termination signal to all child processes first - for p in processes: - if p.is_alive(): - try: - logger.info(f"Sending termination signal to process {p.name}...") - p.terminate() - except Exception as e: - logger.error(f"Error sending termination signal to process: {e}") - - # Allow a short time for processes to handle their own cleanup - time.sleep(1) - - # Forcefully terminate any processes that didn't exit - for p in processes: - if p.is_alive(): - try: - logger.info(f"Forcefully terminating process {p.name}...") - import signal - os.kill(p.pid, signal.SIGKILL) - except Exception as e: - logger.error(f"Error forcefully terminating process: {e}") - - logger.info("Shutdown complete. Exiting.") - sys.exit(0) - - -def test(args: argparse.Namespace, test_all_meta: dict) -> None: - global processes - logger.info("Args: %s", args) - - distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) - - logger.info("All environments are ready. Starting parallel task execution...") - - # Create a shared list for scores across processes - with Manager() as manager: - shared_scores = manager.list() - - # Create and start processes for each environment - processes = [] - for env_idx, env_tasks in enumerate(distributed_tasks): - p = Process( - target=run_env_tasks, - args=(env_idx, env_tasks, args, shared_scores) - ) - processes.append(p) - p.start() - logger.info(f"Started process {p.name} with PID {p.pid}") - - try: - # Wait for all processes to complete - for p in processes: - p.join() - logger.info(f"Process {p.name} completed") - except KeyboardInterrupt: - logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...") - # Let the signal handler do the cleanup - raise - except Exception as e: - logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) - # Ensure cleanup happens - for p in processes: - if p.is_alive(): - try: - logger.info(f"Terminating process {p.name} due to error...") - p.terminate() - except Exception as term_e: - logger.error(f"Error terminating process {p.name}: {term_e}") - raise - - # Convert shared list to regular list - scores = list(shared_scores) - - logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") - - -def get_unfinished( - action_space, use_model, observation_type, result_dir, total_file_json -): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) - - if not os.path.exists(target_dir): - return total_file_json - - finished = {} - for domain in os.listdir(target_dir): - finished[domain] = [] - domain_path = os.path.join(target_dir, domain) - if os.path.isdir(domain_path): - for example_id in os.listdir(domain_path): - if example_id == "onboard": - continue - example_path = os.path.join(domain_path, example_id) - if os.path.isdir(example_path): - if "result.txt" not in os.listdir(example_path): - # empty all files under example_id - for file in os.listdir(example_path): - os.remove(os.path.join(example_path, file)) - else: - finished[domain].append(example_id) - - if not finished: - return total_file_json - - for domain, examples in finished.items(): - if domain in total_file_json: - total_file_json[domain] = [ - x for x in total_file_json[domain] if x not in examples - ] - - return total_file_json - - -def get_result(action_space, use_model, observation_type, result_dir, total_file_json): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) - if not os.path.exists(target_dir): - print("New experiment, no result yet.") - return None - - all_result = [] - - for domain in os.listdir(target_dir): - domain_path = os.path.join(target_dir, domain) - if os.path.isdir(domain_path): - for example_id in os.listdir(domain_path): - example_path = os.path.join(domain_path, example_id) - if os.path.isdir(example_path): - if "result.txt" in os.listdir(example_path): - # empty all files under example_id - try: - all_result.append( - float( - open( - os.path.join(example_path, "result.txt"), "r" - ).read() - ) - ) - except: - all_result.append(0.0) - - if not all_result: - print("New experiment, no result yet.") - return None - else: - print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") - return all_result - - -if __name__ == "__main__": - ####### The complete version of the list of examples ####### - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - # Register signal handlers for graceful termination - signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C - signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal - - try: - args = config() - - with open(args.test_all_meta_path, "r", encoding="utf-8") as f: - test_all_meta = json.load(f) - - if args.domain != "all": - test_all_meta = {args.domain: test_all_meta[args.domain]} - - test_file_list = get_unfinished( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - left_info = "" - for domain in test_file_list: - left_info += f"{domain}: {len(test_file_list[domain])}\n" - logger.info(f"Left tasks:\n{left_info}") - - get_result( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - test(args, test_file_list) - except KeyboardInterrupt: - logger.info("Main process received KeyboardInterrupt.") - # Signal handler will take care of cleanup - except Exception as e: - logger.error(f"Unexpected error in main process: {e}", exc_info=True) - # Also trigger cleanup for unhandled exceptions - signal_handler(signal.SIGTERM, None) - finally: - # Final cleanup in case any environments or processes remain - logger.info("Main process final cleanup...") - for env in active_environments: - if env is not None: - try: - logger.info(f"Closing environment in final cleanup...") - env.close() - logger.info(f"Environment closed successfully in final cleanup") - except Exception as e: - logger.error(f"Error during final environment cleanup: {e}") - - # First try gentle termination - for p in processes: - if p is not None and p.is_alive(): - try: - logger.info(f"Terminating process {p.name}...") - p.terminate() - except Exception as e: - logger.error(f"Error terminating process: {e}") - - # Wait a moment for processes to terminate - time.sleep(1) - - # Then force kill if needed - for p in processes: - if p is not None and p.is_alive(): - try: - logger.info(f"Force killing process {p.name}...") - os.kill(p.pid, signal.SIGKILL) - logger.info(f"Process {p.name} force killed") - except Exception as e: - logger.error(f"Error force killing process: {e}")