From 54a14cbc07647dfbcfe06e21f85ec445f949729d Mon Sep 17 00:00:00 2001
From: hanyullai <54882356+hanyullai@users.noreply.github.com>
Date: Sat, 30 Aug 2025 11:10:53 +0800
Subject: [PATCH] fix multienv bug (#327)

---
 run_autoglm.py          |   3 +-
 run_multienv_autoglm.py | 858 ++++++++++++++++++++--------------
 2 files changed, 440 insertions(+), 421 deletions(-)

diff --git a/run_autoglm.py b/run_autoglm.py
index ef514a4..8b161c7 100644
--- a/run_autoglm.py
+++ b/run_autoglm.py
@@ -19,8 +19,9 @@ from requests.exceptions import SSLError
 from tqdm import tqdm
 
 import lib_run_single
-from desktop_env.desktop_env import DesktopEnv as DesktopEnvBase
+from desktop_env.desktop_env import MAX_RETRIES, DesktopEnv as DesktopEnvBase
 from mm_agents.autoglm import AutoGLMAgent
+from typing import Optional, Dict, Any
 
 
 # Almost deprecated since it's not multi-env, use run_multienv_*.py instead
diff --git a/run_multienv_autoglm.py b/run_multienv_autoglm.py
index f6d74c6..5923b67 100644
--- a/run_multienv_autoglm.py
+++ b/run_multienv_autoglm.py
@@ -8,13 +8,9 @@ import json
 import logging
 import os
 import sys
-import signal
+import math
+import ast
 import time
-from typing import List
-from multiprocessing import Process, Manager, current_process
-import lib_run_single
-from run_autoglm import DesktopEnv
-from mm_agents.autoglm import AutoGLMAgent
 
 import backoff
 import httpx
@@ -22,52 +18,75 @@ from openai import APIConnectionError, APIError, OpenAI, RateLimitError
 from requests.exceptions import SSLError
 from tqdm import tqdm
 
-# Global variables for signal handling
-active_environments = []
-processes = []
-is_terminating = False
-
-# .env
-from dotenv import load_dotenv
-load_dotenv()
+import lib_run_single
+from desktop_env.desktop_env import MAX_RETRIES, DesktopEnv as DesktopEnvBase
+from mm_agents.autoglm import AutoGLMAgent
+from typing import Optional, Dict, Any
+from multiprocessing import Pool
 
 # Logger Configs {{{ #
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+file_handler = logging.FileHandler(os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8")
+debug_handler = logging.FileHandler(os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8")
+stdout_handler = logging.StreamHandler(sys.stdout)
+sdebug_handler = logging.FileHandler(os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8")
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(logging.INFO)
+sdebug_handler.setLevel(logging.DEBUG)
+
+formatter = logging.Formatter(
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+sdebug_handler.setFormatter(formatter)
+
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+sdebug_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+logger.addHandler(sdebug_handler)
+# }}} Logger Configs #
+
+logger = logging.getLogger("desktopenv.experiment")
+
+
 def config() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(
-        description="Run end-to-end evaluation on the benchmark"
-    )
+    parser = argparse.ArgumentParser(description="Run end-to-end evaluation on the benchmark")
 
     # environment config
     parser.add_argument("--path_to_vm", type=str)
     parser.add_argument(
-        "--headless", action="store_true", default=True, help="Run in headless machine"
-    )
-    parser.add_argument(
-        "--action_space", type=str, default="autoglm_computer_use", help="Action type"
+        "--provider_name",
+        type=str,
+        default="docker",
+        help="Virtualization provider (vmware, docker, aws, azure, gcp, virtualbox)",
     )
+    parser.add_argument("--headless", action="store_true", default=True, help="Run in headless machine")
+    parser.add_argument("--action_space", type=str, default="autoglm_computer_use", help="Action type")
     parser.add_argument(
         "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
         default="a11y_tree",
         help="Observation type",
     )
-    parser.add_argument(
-        "--provider_name", type=str, default="docker", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
-    )
-    parser.add_argument(
-        "--screen_width", type=int, default=1920, help="Screen width"
-    )
-    parser.add_argument(
-        "--screen_height", type=int, default=1080, help="Screen height"
-    )
+    parser.add_argument("--screen_width", type=int, default=1920)
+    parser.add_argument("--screen_height", type=int, default=1080)
     parser.add_argument("--sleep_after_execution", type=float, default=1.0)
     parser.add_argument("--max_steps", type=int, default=50)
 
     # agent config
     parser.add_argument("--max_trajectory_length", type=int, default=3)
-    parser.add_argument(
-        "--test_config_base_dir", type=str, default="evaluation_examples"
-    )
+    parser.add_argument("--test_config_base_dir", type=str, default="evaluation_examples")
 
     # lm config
     parser.add_argument("--model", type=str, default="autoglm-os")
@@ -78,331 +97,255 @@ def config() -> argparse.Namespace:
 
     # example config
     parser.add_argument("--domain", type=str, default="all")
-    parser.add_argument(
-        "--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json"
-    )
+    parser.add_argument("--test_all_meta_path", type=str, default="evaluation_examples/test_nogdrive.json")
 
     # aws config
     parser.add_argument(
         "--region", type=str, default="us-east-1", help="AWS region for the VM"
     )
-    parser.add_argument(
-        "--client_password", type=str, default="", help="Client password"
-    )
+    parser.add_argument("--client_password", type=str, default="", help="Client password")
 
     # logging related
     parser.add_argument("--result_dir", type=str, default="./results")
-    parser.add_argument("--num_envs", type=int, default=20, help="Number of environments to run in parallel")
-    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='INFO', help="Set the logging level")
+
+    # parallel number
+    parser.add_argument("--num_workers", type=int, default=20, help="Number of parallel workers")
     args = parser.parse_args()
+
     return args
 
-args = config()  # Get command line arguments first
-if args.client_password == "":
-    if args.provider_name == "aws":
-        args.client_password = "osworld-public-evaluation"
-    else:
-        args.client_password = "password"
-else:
-    args.client_password = args.client_password
 
-logger = logging.getLogger()
-log_level = getattr(logging, args.log_level.upper())
-logger.setLevel(log_level)
+class DesktopEnv(DesktopEnvBase):
+    def step(self, action, pause=2):
+        self._step_no += 1
+        self.action_history.append(action)
+
+        # Mark environment as used when step is called
+        self.is_environment_used = True
 
-datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+        reward = 0  # todo: Define reward calculation for each example
+        done = False  # todo: Define episode termination condition for each example
+        info = {}
+        logger.info(f"Step {self._step_no} in trajectory {self._traj_no} with action: {action}")
 
-file_handler = logging.FileHandler(
-    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-debug_handler = logging.FileHandler(
-    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-stdout_handler = logging.StreamHandler(sys.stdout)
-
-file_handler.setLevel(logging.INFO)
-debug_handler.setLevel(logging.DEBUG)
-stdout_handler.setLevel(log_level)
-
-formatter = logging.Formatter(
-    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
-)
-file_handler.setFormatter(formatter)
-debug_handler.setFormatter(formatter)
-stdout_handler.setFormatter(formatter)
-
-stdout_handler.addFilter(logging.Filter("desktopenv"))
-
-logger.addHandler(file_handler)
-logger.addHandler(debug_handler)
-logger.addHandler(stdout_handler)
-# }}} Logger Configs #
-
-logger = logging.getLogger("desktopenv.experiment")
-
-
-def distribute_tasks(test_all_meta: dict) -> List[tuple]:
-    """Distribute tasks evenly across environments."""
-    # Flatten the tasks into a single list
-    all_tasks = []
-    for domain, examples in test_all_meta.items():
-        for example_id in examples:
-            all_tasks.append((domain, example_id))
-
-    return all_tasks
-
-
-def process_signal_handler(signum, frame, env_idx):
-    """Signal handler for child processes to gracefully shut down their environments."""
-    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
-
-    # Get the active_environments from the caller's frame
-    local_vars = frame.f_locals
-    active_environments = local_vars.get('active_environments', [])
-
-    # Close environment in the current process context
-    for env in active_environments:
-        if env is not None:
+        # handle the special actions
+        if action in ['WAIT', 'FAIL', 'DONE']:
+            if action == 'WAIT':
+                time.sleep(pause)
+                exe_result = 'Wait ' + str(pause) + ' seconds'
+            elif action == 'FAIL':
+                done = True
+                info = {"fail": True}
+                exe_result = 'Finish: fail'
+            elif action == 'DONE':
+                done = True
+                info = {"done": True}
+                exe_result = 'Finish: success'
+        elif type(action) == dict:
+            if action['action_type'] == 'OPEN_APP':
+                self.setup_controller._launch_setup(action['parameters']['launch_app_command'], shell=True)
+                exe_result = 'Open ' + action['parameters']['app_name']
+            elif action['action_type'] == 'OPEN_CHROME_TAB':
+                self.setup_controller._chrome_open_tabs_setup(action['parameters']['urls_to_open'])
+                exe_result = 'Open ' + str(action['parameters']['urls_to_open']) + ' in Chrome successfully'
+        else:
+            # the set of all possible python commands inside `pyautogui`
+            result = self.controller.execute_python_command(action)
             try:
-                logger.info(f"Process {env_idx + 1} closing environment...")
-                env.close()
-                logger.info(f"Process {env_idx + 1} environment closed successfully")
+                if result['error']:
+                    exe_result = result['error'].strip()
+                else:
+                    exe_result = result['output'].strip()
             except Exception as e:
-                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
-
-    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
-    sys.exit(0)
+                exe_result = 'Error Action: ' + action
+                logger.error(f"Error executing action: {e}")
+        time.sleep(pause)
+        observation = self._get_obs()
+        observation['exe_result'] = exe_result
+
+        return observation, reward, done, info
 
-def run_env_tasks(task_queue, args, shared_scores):
-    """Run tasks for a single environment."""
-    active_environments = []
-    env = None
-    try:
-        @backoff.on_exception(
-            backoff.constant,
-            (RateLimitError, APIConnectionError),
-            interval=0.1,
-        )
-        def call_llm(messages):
-            logger.info("Calling LLM...")
-            # set api_key and base_url by environment variables
-            engine = OpenAI(timeout=60.0)
-            response = engine.chat.completions.create(
-                model=args.model,
-                messages=messages,
-                max_tokens=args.max_tokens,
-                temperature=args.temperature,
-                top_p=args.top_p,
-            )
-            logger.info("LLM called successfully.")
-            return response.choices[0].message.content
+    def reset(self, task_config: Optional[Dict[str, Any]] = None, seed=None, options=None) -> Dict[str, Any]:
+        # Reset to certain task in OSWorld
+        logger.info("Resetting environment...")
+        logger.info("Switching task...")
+        logger.info("Setting counters...")
+        self._traj_no += 1
+        self._step_no = 0
+        self.action_history.clear()
 
-        env = DesktopEnv(
-            provider_name=args.provider_name,
-            region=args.region,
-            client_password=args.client_password,
-            path_to_vm=args.path_to_vm,
-            action_space=args.action_space,
-            screen_size=(args.screen_width, args.screen_height),
-            headless=args.headless,
-            os_type="Ubuntu",
-            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
-        )
-        active_environments.append(env)
-        agent = AutoGLMAgent(
-            action_space=args.action_space,
-            observation_type=args.observation_type,
-            max_trajectory_length=args.max_trajectory_length,
-            client_password=args.client_password,
-            gen_func=call_llm,
-        )
-        logger.info(f"Process {current_process().name} started.")
-        while True:
-            try:
-                item = task_queue.get(timeout=5)
-            except Exception:
-                break
-            domain, example_id = item
-            try:
-                config_file = os.path.join(
-                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
-                )
-                with open(config_file, "r", encoding="utf-8") as f:
-                    example = json.load(f)
-                logger.info(f"[{current_process().name}][Domain]: {domain}")
-                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
-                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
-                example_result_dir = os.path.join(
-                    args.result_dir,
-                    args.action_space,
-                    args.observation_type,
-                    args.model,
-                    domain,
-                    example_id,
-                )
-                os.makedirs(example_result_dir, exist_ok=True)
-                try:
-                    lib_run_single.run_single_example_autoglm(
-                        agent,
-                        env,
-                        example,
-                        args.max_steps,
-                        example["instruction"],
-                        args,
-                        example_result_dir,
-                        shared_scores,
+        for attempt in range(MAX_RETRIES):
+            # Only revert to snapshot if environment has been used (step/setup)
+            # This optimization is especially important for cloud providers like AWS
+            # where unnecessary snapshot operations are costly and time-consuming
+
+            if task_config is not None:
+                # Only consider task proxy requirement if proxy is enabled at system level
+                task_use_proxy = task_config.get("proxy", False) and self.enable_proxy
+                if not self.enable_proxy and task_config.get("proxy", False):
+                    logger.info("Task requires proxy but proxy is disabled at system level, ignoring proxy requirement.")
+
+                if task_use_proxy != self.current_use_proxy:
+                    # keep because get_info_from_website depend on this
+                    self.current_use_proxy = task_use_proxy
+
+            if self.is_environment_used:
+                logger.info("Environment has been used, reverting to snapshot {}...".format(self.snapshot_name))
+                self._revert_to_snapshot()
+                logger.info("Starting emulator...")
+                self._start_emulator()
+                logger.info("Emulator started.")
+                # Reset the usage flag after reverting
+                self.is_environment_used = False
+            else:
+                logger.info("Environment is clean, skipping snapshot revert (provider: {}).".format(self.provider_name))
+
+            if task_config is not None:
+                if task_config.get("proxy", False) and self.enable_proxy:
+                    # If using proxy and proxy is enabled, set up the proxy configuration
+                    self.setup_controller._proxy_setup(self.client_password)
+                self._set_task_info(task_config)
+                self.setup_controller.reset_cache_dir(self.cache_dir)
+                logger.info("Setting up environment...")
+                success = self.setup_controller.setup(self.config, task_config.get("proxy", False) and self.enable_proxy)
+                if success:
+                    # Mark environment as used when setup is successfully executed
+                    if self.config:  # Only mark as used if there were actual setup operations
+                        self.is_environment_used = True
+                    break
+                else:
+                    logger.error(
+                        "Environment setup failed, retrying (%d/%d)...",
+                        attempt + 1,
+                        MAX_RETRIES,
                     )
-                except Exception as e:
-                    import traceback
-                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
-                    logger.error(traceback.format_exc())
-                    try:
-                        env.controller.end_recording(
-                            os.path.join(example_result_dir, "recording.mp4")
-                        )
-                    except Exception as rec_e:
-                        logger.error(f"Failed to end recording: {rec_e}")
-                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
-                        f.write(
-                            json.dumps(
-                                {"Error": f"{domain}/{example_id} - {e}"}
-                            )
-                        )
-                        f.write("\n")
-            except Exception as e:
-                logger.error(f"Task-level error in {current_process().name}: {e}")
-                import traceback
-                logger.error(traceback.format_exc())
-    except Exception as e:
-        logger.error(f"Process-level error in {current_process().name}: {e}")
-        import traceback
-        logger.error(traceback.format_exc())
-    finally:
-        logger.info(f"{current_process().name} cleaning up environment...")
-        try:
-            if env:
-                env.close()
-                logger.info(f"{current_process().name} environment closed successfully")
-        except Exception as e:
-            logger.error(f"{current_process().name} error during environment cleanup: {e}")
+                    time.sleep(5)
+            else:
+                break
+
+        logger.info("Environment setup complete.")
+        # Upload tools from autoglm package
+        import mm_agents.autoglm
+        tool_dir = os.path.join(os.path.dirname(mm_agents.autoglm.__file__), 'tools', 'package')
+        for file in os.listdir(tool_dir):
+            if os.path.isdir(os.path.join(tool_dir, file)):
+                continue
+            self.setup_controller._upload_file_setup([{
+                "local_path": os.path.join(tool_dir, file),
+                "path": os.path.join('~', file)
+            }])
 
-def signal_handler(signum, frame):
-    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
-    global is_terminating, active_environments, processes
-
-    # Avoid duplicate handling
-    if is_terminating:
-        return
-
-    is_terminating = True
-    logger.info(f"Received signal {signum}. Gracefully shutting down...")
-
-    # Close all registered environments in the main process
-    for env in active_environments:
-        try:
-            logger.info(f"Closing environment...")
-            env.close()
-            logger.info(f"Environment closed successfully")
-        except Exception as e:
-            logger.error(f"Error closing environment: {e}")
-
-    # Send termination signal to all child processes first
-    for p in processes:
-        if p.is_alive():
+        # start soffice service for office tools
+        self.setup_controller._launch_setup('soffice --accept="socket,host=localhost,port=2002;urp;" --norestore --nologo --nodefault', shell=True)
+        time.sleep(5)
+
+        observation = self._get_obs()
+        return observation
+
+    def get_current_apps(self):
+        apps_code = r"""import subprocess;
+command = "wmctrl -xl";
+apps = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip().split('\n');
+print(apps);"""
+        window_code = r"""import subprocess;
+command = "wmctrl -a :ACTIVE: -v 2>&1 | grep 'Using window' | awk '{print $3}'";
+window_id = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip();
+print(window_id);"""
+
+        apps = self.controller.execute_python_command(apps_code)['output'].strip()
+        apps = ast.literal_eval(apps)
+        app_list = {}
+
+        for app in apps:
+            parts = app.split(maxsplit=4)
+            if len(parts) < 4:
+                continue
+            if parts[1] != '0':
+                continue
+            window_id = parts[0]
+            app_name = '.'.join(parts[2].split('.')[-(math.ceil(parts[2].count('.') / 2)):])
+            title = parts[3]
+            app_list[window_id] = {
+                'app_name': app_name,
+                'title': title
+            }
+
+        cur_id = self.controller.execute_python_command(window_code)['output'].strip()
+
+        return app_list, cur_id
+
+    def maximize_window(self):
+        window_state = r"""import subprocess;
+command = "xprop -id $(xprop -root _NET_ACTIVE_WINDOW | awk -F' ' '{print $5}') _NET_WM_STATE"
+output = subprocess.run(command, shell=True, capture_output=True, text=True).stdout.strip();
+print(output);"""
+        for _ in range(5):
             try:
-                logger.info(f"Sending termination signal to process {p.name}...")
-                p.terminate()
+                self.setup_controller._launch_setup('wmctrl -r :ACTIVE: -b add,maximized_vert,maximized_horz', shell=True)
+                time.sleep(2)
+                output = self.controller.execute_python_command(window_state)['output'].strip()
+                if '_NET_WM_STATE_FOCUSED' not in output or '_NET_WM_STATE_SKIP_TASKBAR' in output or '_NET_WM_STATE_MODAL' in output or '_NET_WM_STATE_MAXIMIZED' in output:  # no window, a popup, a modal window, or already maximized
+                    return
             except Exception as e:
-                logger.error(f"Error sending termination signal to process: {e}")
-
-    # Allow a short time for processes to handle their own cleanup
-    time.sleep(1)
-
-    # Forcefully terminate any processes that didn't exit
-    for p in processes:
-        if p.is_alive():
+                logger.error(f"Failed to maximize window: {e}")
+            time.sleep(1)
+
+    def _get_obs(self):
+        tool_list = {
+            "libreoffice_calc": "CalcTools",
+            "libreoffice_impress": "ImpressTools",
+            "libreoffice_writer": "WriterTools",
+            "code": "CodeTools",
+            "vlc": "VLCTools",
+            "google_chrome": "BrowserTools"
+        }
+
+        self.maximize_window()
+
+        for i in range(3):
             try:
-                logger.info(f"Forcefully terminating process {p.name}...")
-                import signal as sig
-                os.kill(p.pid, sig.SIGKILL)
+                app_list, cur_id = self.get_current_apps()
+                break
             except Exception as e:
-                logger.error(f"Error forcefully terminating process: {e}")
-
-    logger.info("Shutdown complete. Exiting.")
-    sys.exit(0)
+                if i == 2:
+                    raise e
+                logger.error(f"Failed to get current apps: {e}")
+                time.sleep(1)
+
+        if cur_id in app_list:
+            cur_app = app_list[cur_id]['app_name']
+            tool_name = cur_app.strip().lower().replace('-', '_')
+            if tool_name in tool_list:
+                class_name = tool_list[tool_name]
+                command = f"from {tool_name} import *; "
+                command += f"{class_name}.env_info(); "
+                command += f"{class_name}.print_result();"
+                app_info = self.controller.execute_python_command(command)['output'].strip()
+            else:
+                app_info = None
+        else:
+            cur_app = None
+            app_info = None
+
+        tree = self.controller.get_accessibility_tree()
+        screenshot = self.controller.get_screenshot()
+        if screenshot is None:
+            logger.error("Failed to get screenshot.")
+            screenshot = b''
 
-def test(args: argparse.Namespace, test_all_meta: dict) -> None:
-    global processes
-    logger.info("Args: %s", args)
-    all_tasks = distribute_tasks(test_all_meta)
-    logger.info(f"Total tasks: {len(all_tasks)}")
-    with Manager() as manager:
-        shared_scores = manager.list()
-        task_queue = manager.Queue()
-        for item in all_tasks:
-            task_queue.put(item)
-        num_envs = args.num_envs
-        processes = []
-        for i in range(num_envs):
-            p = Process(
-                target=run_env_tasks,
-                args=(task_queue, args, shared_scores),
-                name=f"EnvProcess-{i+1}"
-            )
-            p.daemon = True
-            p.start()
-            processes.append(p)
-            logger.info(f"Started process {p.name} with PID {p.pid}")
-        try:
-            while True:
-                alive_count = 0
-                for idx, p in enumerate(processes):
-                    if not p.is_alive():
-                        logger.warning(f"Process {p.name} died, restarting...")
-                        new_p = Process(
-                            target=run_env_tasks,
-                            args=(task_queue, args, shared_scores),
-                            name=f"EnvProcess-Restart-{idx+1}"
-                        )
-                        new_p.daemon = True
-                        new_p.start()
-                        processes[idx] = new_p
-                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
-                    else:
-                        alive_count += 1
-                if task_queue.empty():
-                    logger.info("All tasks finished.")
-                    break
-                if alive_count == 0:
-                    logger.error("All processes died, exiting.")
-                    break
-                time.sleep(5)
-            for p in processes:
-                p.join()
-        except KeyboardInterrupt:
-            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
-            raise
-        except Exception as e:
-            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
-            for p in processes:
-                if p.is_alive():
-                    try:
-                        logger.info(f"Terminating process {p.name} due to error...")
-                        p.terminate()
-                    except Exception as term_e:
-                        logger.error(f"Error terminating process {p.name}: {term_e}")
-            raise
-    scores = list(shared_scores)
-    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+        return {
+            "screenshot": screenshot,
+            "accessibility_tree": tree,
+            "instruction": self.instruction,
+            "apps": app_list,
+            "cur_window_id": cur_id,
+            "cur_app": cur_app,
+            "app_info": app_info,
+        }
 
-
-def get_unfinished(
-    action_space, use_model, observation_type, result_dir, total_file_json
-):
+def get_unfinished(action_space, use_model, observation_type, result_dir, total_file_json):
     target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
 
     if not os.path.exists(target_dir):
@@ -430,9 +373,7 @@ def get_unfinished(
 
     for domain, examples in finished.items():
         if domain in total_file_json:
-            total_file_json[domain] = [
-                x for x in total_file_json[domain] if x not in examples
-            ]
+            total_file_json[domain] = [x for x in total_file_json[domain] if x not in examples]
 
     return total_file_json
 
@@ -454,13 +395,7 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
                 if "result.txt" in os.listdir(example_path):
                     # empty all files under example_id
                     try:
-                        all_result.append(
-                            float(
-                                open(
-                                    os.path.join(example_path, "result.txt"), "r"
-                                ).read()
-                            )
-                        )
+                        all_result.append(float(open(os.path.join(example_path, "result.txt"), "r").read()))
                     except:
                         all_result.append(0.0)
 
@@ -471,93 +406,176 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file
         print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
         return all_result
 
+def _worker_run(task):
+    import json, os, logging, backoff
+    from openai import OpenAI, RateLimitError, APIConnectionError
+    domain, example_id, args = task  # args is an argparse.Namespace
+    logger = logging.getLogger("desktopenv.experiment")
+    try:
+        config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
+        with open(config_file, "r", encoding="utf-8") as f:
+            example = json.load(f)
+        instruction = example["instruction"]
+
+        @backoff.on_exception(backoff.constant, (RateLimitError, APIConnectionError), interval=0.1)
+        def call_llm(messages):
+            logger.info("Calling LLM...")
+            # set api_key and base_url by environment variables
+            engine = OpenAI(timeout=60.0)
+            response = engine.chat.completions.create(
+                model=args.model,
+                messages=messages,
+                max_tokens=args.max_tokens,
+                temperature=args.temperature,
+                top_p=args.top_p,
+            )
+            logger.info("LLM called successfully.")
+            return response.choices[0].message.content
+
+        env = DesktopEnv(
+            provider_name=args.provider_name,
+            region=args.region,
+            client_password=args.client_password,
+            path_to_vm=args.path_to_vm,
+            action_space=args.action_space,
+            screen_size=(args.screen_width, args.screen_height),
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+        )
+        agent = AutoGLMAgent(
+            action_space=args.action_space,
+            observation_type=args.observation_type,
+            max_trajectory_length=args.max_trajectory_length,
+            client_password=args.client_password,
+            gen_func=call_llm,
+        )
+
+        example_result_dir = os.path.join(
+            args.result_dir,
+            args.action_space,
+            args.observation_type,
+            args.model,
+            domain,
+            example_id,
+        )
+        os.makedirs(example_result_dir, exist_ok=True)
+
+        local_scores = []
+        try:
+            lib_run_single.run_single_example_autoglm(
+                agent,
+                env,
+                example,
+                args.max_steps,
+                instruction,
+                args,
+                example_result_dir,
+                local_scores,
+            )
+        except Exception as e:
+            logger.error(f"[Parallel task exception] {domain}/{example_id}: {e}")
+            if hasattr(env, "controller") and env.controller is not None:
+                try:
+                    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
+                except Exception:
+                    pass
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({"Error": f"Exception in {domain}/{example_id}: {str(e)}"}) + "\n")
+        finally:
+            try:
+                env.close()
+            except Exception:
+                pass
+
+        score = None
+        result_path = os.path.join(example_result_dir, "result.txt")
+        if os.path.exists(result_path):
+            try:
+                with open(result_path, "r") as rf:
+                    score = float(rf.read().strip())
+            except Exception:
+                score = 0.0
+        else:
+            score = 0.0
+        logger.info(f"[Finish] {domain}/{example_id} score={score}")
+        return (domain, example_id, score)
+    except Exception as e:
+        logger.error(f"[Initialization failed] {domain}/{example_id}: {e}")
+        return (domain, example_id, 0.0)
+
+def test_parallel(args: argparse.Namespace, test_all_meta: dict):
+    tasks = []
+    for domain in test_all_meta:
+        for example_id in test_all_meta[domain]:
+            tasks.append((domain, example_id, args))
+    if not tasks:
+        logger.info("No pending tasks")
+        return
+    logger.info(f"Starting parallel execution: {args.num_workers} processes, {len(tasks)} tasks total")
+
+    results = []
+    with Pool(processes=args.num_workers) as pool:
+        for res in tqdm(pool.imap_unordered(_worker_run, tasks), total=len(tasks), desc="Parallel execution"):
+            results.append(res)
+
+    scores = [s for (_, _, s) in results if s is not None]
+    if scores:
+        avg = sum(scores) / len(scores)
+        logger.info(f"Parallel execution completed. Average score: {avg}")
+    else:
+        logger.info("No scores obtained.")
 
 if __name__ == "__main__":
     ####### The complete version of the list of examples #######
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
-
-    # Register signal handlers for graceful termination
-    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
-    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal
-
-    try:
-        # args already defined globally above
-
-        # save args to json in result_dir/action_space/observation_type/model/args.json
-        path_to_args = os.path.join(
-            args.result_dir,
-            args.action_space,
-            args.observation_type,
-            args.model,
-            "args.json",
-        )
-        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
-        with open(path_to_args, "w", encoding="utf-8") as f:
-            json.dump(vars(args), f, indent=4)
+    args = config()
+    if args.client_password == "":
+        if args.provider_name == "aws":
+            args.client_password = "osworld-public-evaluation"
+        else:
+            args.client_password = "password"
 
-        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
-            test_all_meta = json.load(f)
+    # save args to json in result_dir/action_space/observation_type/model/args.json
+    path_to_args = os.path.join(
+        args.result_dir,
+        args.action_space,
+        args.observation_type,
+        args.model,
+        "args.json",
+    )
+    os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+    with open(path_to_args, "w", encoding="utf-8") as f:
+        json.dump(vars(args), f, indent=4)
 
-        if args.domain != "all":
-            test_all_meta = {args.domain: test_all_meta[args.domain]}
+    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+        test_all_meta = json.load(f)
 
-        test_file_list = get_unfinished(
-            args.action_space,
-            args.model,
-            args.observation_type,
-            args.result_dir,
-            test_all_meta,
-        )
-        left_info = ""
-        for domain in test_file_list:
-            left_info += f"{domain}: {len(test_file_list[domain])}\n"
-        logger.info(f"Left tasks:\n{left_info}")
+    if args.domain != "all":
+        test_all_meta = {args.domain: test_all_meta[args.domain]}
 
-        get_result(
-            args.action_space,
-            args.model,
-            args.observation_type,
-            args.result_dir,
-            test_all_meta,
-        )
-        test(args, test_file_list)
-    except KeyboardInterrupt:
-        logger.info("Main process received KeyboardInterrupt.")
-        # Signal handler will take care of cleanup
-    except Exception as e:
-        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
-        # Also trigger cleanup for unhandled exceptions
-        signal_handler(signal.SIGTERM, None)
-    finally:
-        # Final cleanup in case any environments or processes remain
-        logger.info("Main process final cleanup...")
-        for env in active_environments:
-            if env is not None:
-                try:
-                    logger.info(f"Closing environment in final cleanup...")
-                    env.close()
-                    logger.info(f"Environment closed successfully in final cleanup")
-                except Exception as e:
-                    logger.error(f"Error during final environment cleanup: {e}")
-
-        # First try gentle termination
-        for p in processes:
-            if p is not None and p.is_alive():
-                try:
-                    logger.info(f"Terminating process {p.name}...")
-                    p.terminate()
-                except Exception as e:
-                    logger.error(f"Error terminating process: {e}")
-
-        # Wait a moment for processes to terminate
-        time.sleep(1)
-
-        # Then force kill if needed
-        for p in processes:
-            if p is not None and p.is_alive():
-                try:
-                    logger.info(f"Force killing process {p.name}...")
-                    os.kill(p.pid, signal.SIGKILL)
-                    logger.info(f"Process {p.name} force killed")
-                except Exception as e:
-                    logger.error(f"Error force killing process: {e}")
\ No newline at end of file
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
+
+    get_result(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    test_parallel(args, test_file_list)
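
The core of this patch is the switch from hand-managed worker Processes, a Manager queue, and signal handlers to a multiprocessing.Pool driven by imap_unordered: each worker owns one DesktopEnv for exactly one task and returns a (domain, example_id, score) tuple. A minimal, self-contained sketch of that pattern — hypothetical run_one worker and task list, not the benchmark code itself:

    # Pool fan-out as in test_parallel/_worker_run: one environment per task,
    # cleanup in finally, failures become a 0.0 score instead of a dead worker.
    from multiprocessing import Pool

    def run_one(task):
        domain, example_id = task
        try:
            score = 1.0  # stand-in for: create env, run the example, read result.txt
        except Exception:
            score = 0.0
        finally:
            pass  # env.close() belongs here so crashes still release the VM
        return (domain, example_id, score)

    if __name__ == "__main__":
        tasks = [("chrome", "ex1"), ("vlc", "ex2"), ("os", "ex3")]
        with Pool(processes=2) as pool:
            # imap_unordered yields results as workers finish, not in task order
            results = list(pool.imap_unordered(run_one, tasks))
        print(sum(s for _, _, s in results) / len(results))

Because imap_unordered yields as tasks complete, the tqdm bar advances per finished example, and a crashed task surfaces as a zero score rather than a dead process that the old loop had to detect and restart.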
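The overridden reset() likewise gates snapshot reverts on the new is_environment_used flag: the VM is reverted only when the previous task actually touched it, via step() or a non-empty task setup. Reduced to the flag logic, with hypothetical callables standing in for the provider operations:

    # Dirty-flag gating of snapshot reverts, mirroring DesktopEnv.reset above.
    # Reverts are costly on cloud providers such as AWS, so a clean
    # environment is reused instead of being reverted unconditionally.
    class SnapshotGate:
        def __init__(self):
            self.is_environment_used = False  # set by step() and by real setup ops

        def reset(self, revert_to_snapshot, start_emulator, has_setup_ops):
            if self.is_environment_used:
                revert_to_snapshot()  # pay the cost only when needed
                start_emulator()
                self.is_environment_used = False
            if has_setup_ops:  # mirrors `if self.config:` in the patch
                self.is_environment_used = True

    gate = SnapshotGate()
    gate.reset(lambda: print("revert"), lambda: print("start"), has_setup_ops=True)
    gate.reset(lambda: print("revert"), lambda: print("start"), has_setup_ops=False)  # reverts only once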
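One non-obvious line in get_current_apps is the app_name expression. The first three columns of `wmctrl -xl` are window id, desktop, and WM_CLASS, and WM_CLASS has the form instance.Class, where both halves may themselves contain dots; keeping the trailing ceil(n_dots / 2) dot-separated fields recovers the class half. A worked example on a hypothetical output line:

    import math

    line = "0x03800003  0 gnome-terminal-server.Gnome-terminal  myhost Terminal"
    parts = line.split(maxsplit=4)
    wm_class = parts[2]  # "gnome-terminal-server.Gnome-terminal"
    keep = math.ceil(wm_class.count(".") / 2)  # one dot -> keep the last field
    app_name = ".".join(wm_class.split(".")[-keep:])
    print(app_name)  # -> "Gnome-terminal"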