From 82c3cdd5900a7dd3cfefad885ed50cd854d0ec9c Mon Sep 17 00:00:00 2001 From: yuanmengqi Date: Tue, 22 Jul 2025 19:46:42 +0000 Subject: [PATCH] feat: refactor run_multienv_qwen25vl.py and qwen25vl_agent.py for improved logging and task management - Introduced signal handling for graceful shutdown of environments and processes. - Enhanced logging configuration to support dynamic log levels and structured output. - Updated argument parsing to include new parameters for model selection and task execution. - Refactored task distribution logic to streamline environment task management. - Improved error handling during task execution and environment cleanup. - Adjusted Qwen25VLAgent initialization to support new model and thought prefix options. - Reduced max tries for LLM calls to optimize performance. --- mm_agents/qwen25vl_agent.py | 29 +- run_multienv_qwen25vl.py | 561 +++++++++++++++++++++++------------- 2 files changed, 383 insertions(+), 207 deletions(-) diff --git a/mm_agents/qwen25vl_agent.py b/mm_agents/qwen25vl_agent.py index 20d30bc..dddad37 100644 --- a/mm_agents/qwen25vl_agent.py +++ b/mm_agents/qwen25vl_agent.py @@ -66,25 +66,24 @@ class Qwen25VLAgent: def __init__( self, platform="ubuntu", - planner_model="gpt-4o", - executor_model="qwen2.5vl", + model="qwen2.5-vl-72b-instruct", max_tokens=1500, top_p=0.9, temperature=0.5, action_space="pyautogui", observation_type="screenshot", history_n=4, # Number of previous interactions to include in full detail + add_thought_prefix=False, ): self.platform = platform - self.planner_model = planner_model - self.executor_model = executor_model - assert self.executor_model is not None, "Executor model cannot be None" + self.model = model self.max_tokens = max_tokens self.top_p = top_p self.temperature = temperature self.action_space = action_space self.observation_type = observation_type self.history_n = history_n # Control how many previous interactions to include + self.add_thought_prefix = add_thought_prefix assert action_space in ["pyautogui"], "Invalid action space" assert observation_type in ["screenshot"], "Invalid observation type" self.thoughts = [] @@ -277,19 +276,20 @@ Previous actions: }) # append_text = f"""Step {current_step+1}: Thought:""" - append_text = f"""Thought:""" - messages.append({"role": "assistant", "content": [{"type": "text", "text": append_text}]}) + if self.add_thought_prefix: + append_text = f"""Thought:""" + messages.append({"role": "assistant", "content": [{"type": "text", "text": append_text}]}) # Call the LLM response = self.call_llm( { - "model": self.executor_model, + "model": self.model, "messages": messages, "max_tokens": self.max_tokens, "top_p": self.top_p, "temperature": self.temperature, }, - self.executor_model, + self.model, ) logger.info(f"Qwen25VL Output: {response}") @@ -483,10 +483,10 @@ Previous actions: continue # Handle lines inside tool call markers - if line.startswith(""): + if line.startswith("") or line.startswith("⚗") or line.startswith("📐"): # Yeah, it's a bug during data processing inside_tool_call = True continue - elif line.startswith(""): + elif line.startswith("") or line.startswith("⚗") or line.startswith("📐"): # Yeah, it's a bug during data processing if current_tool_call: # Process the collected tool call process_tool_call("\n".join(current_tool_call)) @@ -540,12 +540,13 @@ Previous actions: # todo: check ), interval=30, - max_tries=10, + max_tries=5, ) def call_llm(self, payload, model): messages = payload["messages"] - base_url = "your_base_url" - api_key = "your_api_key" + + base_url = os.getenv('DASHSCOPE_BASE_URL', "https://dashscope.aliyuncs.com/compatible-mode/v1") + api_key = os.getenv('DASHSCOPE_API_KEY', "sk-123") client = openai.OpenAI( base_url=base_url, diff --git a/run_multienv_qwen25vl.py b/run_multienv_qwen25vl.py index 5db9dc8..3d6f44c 100644 --- a/run_multienv_qwen25vl.py +++ b/run_multienv_qwen25vl.py @@ -1,66 +1,34 @@ -"""Script to run end-to-end evaluation on the benchmark. -Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. -""" - +from __future__ import annotations import argparse import datetime import json import logging import os import sys +import signal +import time from typing import List, Dict import math from tqdm import tqdm from multiprocessing import Process, Manager +from multiprocessing import current_process import lib_run_single from desktop_env.desktop_env import DesktopEnv from mm_agents.qwen25vl_agent import Qwen25VLAgent +# Global variables for signal handling +active_environments = [] +processes = [] +is_terminating = False + # import wandb +# load the environment variables from .env file +if os.path.exists(".env"): + from dotenv import load_dotenv + load_dotenv() # Logger Configs {{{ # -logger = logging.getLogger() -logger.setLevel(logging.DEBUG) - -datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - -file_handler = logging.FileHandler( - os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" -) -debug_handler = logging.FileHandler( - os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" -) -stdout_handler = logging.StreamHandler(sys.stdout) -sdebug_handler = logging.FileHandler( - os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" -) - -file_handler.setLevel(logging.INFO) -debug_handler.setLevel(logging.DEBUG) -stdout_handler.setLevel(logging.INFO) -sdebug_handler.setLevel(logging.DEBUG) - -formatter = logging.Formatter( - fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" -) -file_handler.setFormatter(formatter) -debug_handler.setFormatter(formatter) -stdout_handler.setFormatter(formatter) -sdebug_handler.setFormatter(formatter) - -stdout_handler.addFilter(logging.Filter("desktopenv")) -sdebug_handler.addFilter(logging.Filter("desktopenv")) - -logger.addHandler(file_handler) -logger.addHandler(debug_handler) -logger.addHandler(stdout_handler) -logger.addHandler(sdebug_handler) -# }}} Logger Configs # - -logger = logging.getLogger("desktopenv.experiment") - - def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" @@ -80,24 +48,23 @@ def config() -> argparse.Namespace: default="screenshot", help="Observation type", ) - parser.add_argument("--screen_width", type=int, default=1920) - parser.add_argument("--screen_height", type=int, default=1080) - parser.add_argument("--sleep_after_execution", type=float, default=2.0) - parser.add_argument("--max_steps", type=int, default=20) + parser.add_argument("--sleep_after_execution", type=float, default=0.0) + parser.add_argument("--max_steps", type=int, default=15) # agent config + parser.add_argument("--max_trajectory_length", type=int, default=3) parser.add_argument( "--test_config_base_dir", type=str, default="evaluation_examples" ) # lm config - parser.add_argument("--planner_model", type=str, default=None) - parser.add_argument("--executor_model", type=str, default="aguvis-s1-s2-agentnet0105-mo5") - parser.add_argument("--temperature", type=float, default=0) + parser.add_argument("--model", type=str, default="gpt-4o") + parser.add_argument("--temperature", type=float, default=1.0) parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--stop_token", type=str, default=None) - + parser.add_argument("--add_thought_prefix", action="store_true", help="Add thought prefix to the response") + # example config parser.add_argument("--domain", type=str, default="all") parser.add_argument( @@ -106,152 +73,304 @@ def config() -> argparse.Namespace: # logging related parser.add_argument("--result_dir", type=str, default="./results") - parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") - + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) args = parser.parse_args() return args +args = config() # Get command line arguments first -def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: - """Distribute tasks evenly across environments.""" - # Flatten the tasks into a single list +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: all_tasks = [] for domain, examples in test_all_meta.items(): for example_id in examples: all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") - # Calculate tasks per environment - tasks_per_env = math.ceil(len(all_tasks) / num_envs) + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) - # Distribute tasks - distributed_tasks = [] - for i in range(num_envs): - env_tasks = {} - start_idx = i * tasks_per_env - end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) - - for domain, example_id in all_tasks[start_idx:end_idx]: - if domain not in env_tasks: - env_tasks[domain] = [] - env_tasks[domain].append(example_id) - - distributed_tasks.append(env_tasks) - - return distributed_tasks - - - -def run_env_tasks(env_idx: int, env: DesktopEnv, agent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): - """Run tasks for a single environment.""" - logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") - - for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): - for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): - config_file = os.path.join( - args.test_config_base_dir, f"examples/{domain}/{example_id}.json" - ) - with open(config_file, "r", encoding="utf-8") as f: - example = json.load(f) - - logger.info(f"[Env {env_idx+1}][Domain]: {domain}") - logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") - logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") - - example_result_dir = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model), - domain, - example_id, - ) - os.makedirs(example_result_dir, exist_ok=True) - + # Close environment in the current process context + for env in active_environments: + if env is not None: try: - lib_run_single.run_single_example( - agent, - env, - example, - args.max_steps, - example["instruction"], - args, - example_result_dir, - shared_scores, - ) + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") except Exception as e: - logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") - ) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write( - json.dumps( - {"Error": f"Time limit exceeded in {domain}/{example_id}"} - ) - ) - f.write("\n") + logger.error(f"Process {env_idx + 1} error closing environment: {e}") - env.close() + logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") + sys.exit(0) -def test(args: argparse.Namespace, test_all_meta: dict) -> None: - logger.info("Args: %s", args) - - distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) - - # First, set up all environments - logger.info("Setting up all environments...") - envs = [] - agents = [] - - for env_idx in range(args.num_envs): - logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") - +def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) agent = Qwen25VLAgent( - planner_model=args.planner_model, - executor_model=args.executor_model, + model=args.model, max_tokens=args.max_tokens, top_p=args.top_p, temperature=args.temperature, action_space=args.action_space, + add_thought_prefix=args.add_thought_prefix, ) - agents.append(agent) + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") - env = DesktopEnv( - path_to_vm=args.path_to_vm, - action_space=agent.action_space, - screen_size=(args.screen_width, args.screen_height), - headless=args.headless, - os_type="Ubuntu", - require_a11y_tree=args.observation_type - in ["a11y_tree", "screenshot_a11y_tree", "som"], - provider_name="docker" - ) - envs.append(env) + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes - logger.info("All environments are ready. Starting parallel task execution...") + # Avoid duplicate handling + if is_terminating: + return - # Create a shared list for scores across processes + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") with Manager() as manager: shared_scores = manager.list() - - # Create and start processes for each environment + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs processes = [] - for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): + for i in range(num_envs): p = Process( target=run_env_tasks, - args=(env_idx, env, agent, env_tasks, args, shared_scores) + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" ) - processes.append(p) + p.daemon = True p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - # Convert shared list to regular list + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise scores = list(shared_scores) - logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") @@ -330,33 +449,89 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file if __name__ == "__main__": ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" - args = config() + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) - with open(args.test_all_meta_path, "r", encoding="utf-8") as f: - test_all_meta = json.load(f) + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) - if args.domain != "all": - test_all_meta = {args.domain: test_all_meta[args.domain]} + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} - exp_name = "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model) + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") - test_file_list = get_unfinished( - args.action_space, - exp_name, - args.observation_type, - args.result_dir, - test_all_meta, - ) - left_info = "" - for domain in test_file_list: - left_info += f"{domain}: {len(test_file_list[domain])}\n" - logger.info(f"Left tasks:\n{left_info}") - - get_result( - args.action_space, - exp_name, - args.observation_type, - args.result_dir, - test_all_meta, - ) - test(args, test_file_list) + get_result( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}")