From 8d54d4302f8f5d4450642a4d4703dcfc02c46100 Mon Sep 17 00:00:00 2001 From: adlsdztony Date: Tue, 3 Jun 2025 13:38:47 +0000 Subject: [PATCH 1/2] feat&fix: enhance error handling during environment initialization and VM allocation --- desktop_env/desktop_env.py | 50 +++-- desktop_env/providers/aws/manager.py | 21 +- run_multienv_openaicua.py | 301 ++++++++++++++++++++------- 3 files changed, 277 insertions(+), 95 deletions(-) diff --git a/desktop_env/desktop_env.py b/desktop_env/desktop_env.py index 6cb58e7..4a24a55 100644 --- a/desktop_env/desktop_env.py +++ b/desktop_env/desktop_env.py @@ -71,29 +71,39 @@ class DesktopEnv(gym.Env): else: self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region) + try: + self.snapshot_name = snapshot_name + self.cache_dir_base: str = cache_dir + # todo: add the logic to get the screen size from the VM + self.headless = headless + self.require_a11y_tree = require_a11y_tree + self.require_terminal = require_terminal - self.snapshot_name = snapshot_name - self.cache_dir_base: str = cache_dir - # todo: add the logic to get the screen size from the VM - self.headless = headless - self.require_a11y_tree = require_a11y_tree - self.require_terminal = require_terminal + # Initialize emulator and controller + if provider_name != "docker": # Check if this is applicable to other VM providers + logger.info("Initializing...") + self._start_emulator() - # Initialize emulator and controller - if provider_name != "docker": # Check if this is applicable to other VM providers - logger.info("Initializing...") - self._start_emulator() + # mode: human or machine + self.instruction = None + assert action_space in ["computer_13", "pyautogui"] + self.action_space = action_space # todo: refactor it to the ActType - # mode: human or machine - self.instruction = None - assert action_space in ["computer_13", "pyautogui"] - self.action_space = action_space # todo: refactor it to the ActType - - # episodic stuffs, like counters, will be updated or reset - # when calling self.reset() - self._traj_no: int = -1 - self._step_no: int = 0 - self.action_history: List[Dict[str, any]] = [] + # episodic stuffs, like counters, will be updated or reset + # when calling self.reset() + self._traj_no: int = -1 + self._step_no: int = 0 + self.action_history: List[Dict[str, any]] = [] + except Exception as e: + logger.error(f"Failed to initialize DesktopEnv: {e}") + # If initialization fails, we should clean up the VM + try: + self.close() + self.manager.delete_vm(self.path_to_vm, self.region) + logger.info(f"Cleaned up VM {self.path_to_vm}.") + except Exception as cleanup_error: + logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}") + raise def _start_emulator(self): # Power on the virtual machine diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index a9fbf3f..7e45720 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -57,11 +57,22 @@ def _allocate_vm(region=DEFAULT_REGION): } ec2_client = boto3.client('ec2', region_name=region) - response = ec2_client.run_instances(**run_instances_params) - instance_id = response['Instances'][0]['InstanceId'] - logger.info(f"Waiting for instance {instance_id} to be running...") - ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) - logger.info(f"Instance {instance_id} is ready.") + try: + response = ec2_client.run_instances(**run_instances_params) + instance_id = response['Instances'][0]['InstanceId'] + 
logger.info(f"Waiting for instance {instance_id} to be running...") + ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) + logger.info(f"Instance {instance_id} is ready.") + except Exception as e: + logger.error(f"Failed to allocate VM in region {region}: {str(e)}") + # try to clean up any resources that were created + try: + if 'InstanceId' in response['Instances'][0]: + ec2_client.terminate_instances(InstanceIds=[instance_id]) + logger.info(f"Terminated instance {instance_id} due to allocation failure.") + except Exception as cleanup_error: + logger.error(f"May fail to clean up instance {instance_id}: {str(cleanup_error)}") + raise return instance_id diff --git a/run_multienv_openaicua.py b/run_multienv_openaicua.py index 57ca019..fb8cda2 100644 --- a/run_multienv_openaicua.py +++ b/run_multienv_openaicua.py @@ -2,12 +2,15 @@ Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. """ +from __future__ import annotations import argparse import datetime import json import logging import os import sys +import signal +import time from typing import List, Dict import math from tqdm import tqdm @@ -16,6 +19,11 @@ import lib_run_single from desktop_env.desktop_env import DesktopEnv from mm_agents.openai_cua_agent import OpenAICUAAgent +# Global variables for signal handling +active_environments = [] +processes = [] +is_terminating = False + # import wandb # load the environment variables from .env file @@ -147,8 +155,39 @@ def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: return distributed_tasks +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + print(f"Local variables in process {env_idx + 1}: {local_vars}") + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. 
Exiting.") + sys.exit(0) + + def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list): """Run tasks for a single environment.""" + # Each process has its own list of active environments + active_environments = [] + env = None + + # Setup signal handlers for this process too + signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) + signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) + # ami-05e7d7bd279ea4f14 env = DesktopEnv( path_to_vm=args.path_to_vm, @@ -163,6 +202,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share os_type="Ubuntu", require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], ) + active_environments.append(env) agent = OpenAICUAAgent( env=env, model=args.model, @@ -175,56 +215,114 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share ) logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") - for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): - for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): - config_file = os.path.join( - args.test_config_base_dir, f"examples/{domain}/{example_id}.json" - ) - with open(config_file, "r", encoding="utf-8") as f: - example = json.load(f) - - logger.info(f"[Env {env_idx+1}][Domain]: {domain}") - logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") - logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") - - example_result_dir = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - domain, - example_id, - ) - os.makedirs(example_result_dir, exist_ok=True) - - try: - lib_run_single.run_single_example_openaicua( - agent, - env, - example, - args.max_steps, - example["instruction"], - args, - example_result_dir, - shared_scores, + try: + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" ) - except Exception as e: - logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, ) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write( - json.dumps( - {"Error": f"Time limit exceeded in {domain}/{example_id}"} - ) + os.makedirs(example_result_dir, exist_ok=True) + + try: + lib_run_single.run_single_example_openaicua( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, ) - f.write("\n") + except Exception as e: + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + 
json.dumps( + {"Error": f"Time limit exceeded in {domain}/{example_id}"} + ) + ) + f.write("\n") + finally: + # This ensures the environment is closed even if there's an exception + logger.info(f"Process {env_idx + 1} cleaning up environment...") + try: + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes - env.close() + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal + os.kill(p.pid, signal.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. Exiting.") + sys.exit(0) def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes logger.info("Args: %s", args) distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) @@ -244,10 +342,28 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: ) processes.append(p) p.start() + logger.info(f"Started process {p.name} with PID {p.pid}") - # Wait for all processes to complete - for p in processes: - p.join() + try: + # Wait for all processes to complete + for p in processes: + p.join() + logger.info(f"Process {p.name} completed") + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. 
Initiating graceful shutdown...") + # Let the signal handler do the cleanup + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + # Ensure cleanup happens + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise # Convert shared list to regular list scores = list(shared_scores) @@ -331,31 +447,76 @@ if __name__ == "__main__": ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" - args = config() + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() - with open(args.test_all_meta_path, "r", encoding="utf-8") as f: - test_all_meta = json.load(f) + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) - if args.domain != "all": - test_all_meta = {args.domain: test_all_meta[args.domain]} + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} - test_file_list = get_unfinished( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - left_info = "" - for domain in test_file_list: - left_info += f"{domain}: {len(test_file_list[domain])}\n" - logger.info(f"Left tasks:\n{left_info}") + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") - get_result( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - test(args, test_file_list) + get_result( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}") From 10153ffff654b8a2092a212a771a9e4faca9ed57 Mon Sep 17 00:00:00 2001 
From: adlsdztony Date: Wed, 4 Jun 2025 03:15:30 +0000 Subject: [PATCH 2/2] feat&fix: add signal handling for VM allocation and improve cleanup on termination --- desktop_env/providers/aws/manager.py | 43 +++++++++++++++++++++++++++- run_multienv_openaicua.py | 1 - 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/desktop_env/providers/aws/manager.py b/desktop_env/providers/aws/manager.py index 7e45720..b9925a0 100644 --- a/desktop_env/providers/aws/manager.py +++ b/desktop_env/providers/aws/manager.py @@ -4,6 +4,7 @@ import boto3 import psutil import logging import dotenv +import signal # Load environment variables from .env file dotenv.load_dotenv() @@ -57,22 +58,62 @@ def _allocate_vm(region=DEFAULT_REGION): } ec2_client = boto3.client('ec2', region_name=region) + instance_id = None + original_sigint_handler = signal.getsignal(signal.SIGINT) + original_sigterm_handler = signal.getsignal(signal.SIGTERM) + + def signal_handler(sig, frame): + if instance_id: + signal_name = "SIGINT" if sig == signal.SIGINT else "SIGTERM" + logger.warning(f"Received {signal_name} signal, terminating instance {instance_id}...") + try: + ec2_client.terminate_instances(InstanceIds=[instance_id]) + logger.info(f"Successfully terminated instance {instance_id} after {signal_name}.") + except Exception as cleanup_error: + logger.error(f"Failed to terminate instance {instance_id} after {signal_name}: {str(cleanup_error)}") + + # Restore original signal handlers + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) + + # Raise appropriate exception based on signal type + if sig == signal.SIGINT: + raise KeyboardInterrupt + else: + # For SIGTERM, exit gracefully + import sys + sys.exit(0) + try: + # Set up signal handlers for both SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + response = ec2_client.run_instances(**run_instances_params) instance_id = response['Instances'][0]['InstanceId'] logger.info(f"Waiting for instance {instance_id} to be running...") ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) logger.info(f"Instance {instance_id} is ready.") + except KeyboardInterrupt: + logger.warning("VM allocation interrupted by user (SIGINT).") + raise + except SystemExit: + logger.warning("VM allocation terminated by parent process (SIGTERM).") + raise except Exception as e: logger.error(f"Failed to allocate VM in region {region}: {str(e)}") # try to clean up any resources that were created try: - if 'InstanceId' in response['Instances'][0]: + if instance_id: ec2_client.terminate_instances(InstanceIds=[instance_id]) logger.info(f"Terminated instance {instance_id} due to allocation failure.") except Exception as cleanup_error: logger.error(f"May fail to clean up instance {instance_id}: {str(cleanup_error)}") raise + finally: + # Restore original signal handlers + signal.signal(signal.SIGINT, original_sigint_handler) + signal.signal(signal.SIGTERM, original_sigterm_handler) return instance_id diff --git a/run_multienv_openaicua.py b/run_multienv_openaicua.py index fb8cda2..816ee5c 100644 --- a/run_multienv_openaicua.py +++ b/run_multienv_openaicua.py @@ -161,7 +161,6 @@ def process_signal_handler(signum, frame, env_idx): # Get the active_environments from the caller's frame local_vars = frame.f_locals - print(f"Local variables in process {env_idx + 1}: {local_vars}") active_environments = local_vars.get('active_environments', []) # Close environment in the 
current process context
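
The series leans on three recurring patterns: release the VM when initialization fails partway through, make resource allocation signal-aware, and shut worker processes down in two stages. The Python sketches below restate each pattern in isolation, using illustrative names rather than the repository's real APIs.

The desktop_env.py hunk in the first patch wraps the tail of DesktopEnv.__init__ in a try/except so that an already-provisioned VM is released when a later step fails. A minimal sketch of that cleanup-on-failure shape, where Manager, allocate and release are hypothetical stand-ins for the provider manager interface:

import logging

logger = logging.getLogger(__name__)


class Manager:
    """Illustrative stand-in for a VM provider manager (not the real desktop_env API)."""

    def allocate(self) -> str:
        return "vm-1234"

    def release(self, vm_id: str) -> None:
        logger.info("Released %s", vm_id)


class Env:
    def __init__(self, manager: Manager):
        self.manager = manager
        self.vm_id = manager.allocate()           # resource acquired before the guarded block
        try:
            self._start_emulator()                # may raise after the VM already exists
        except Exception:
            logger.exception("Initialization failed, releasing %s", self.vm_id)
            try:
                self.manager.release(self.vm_id)  # release before propagating
            except Exception:
                logger.exception("Cleanup of %s also failed", self.vm_id)
            raise                                 # re-raise so the caller still sees the root cause

    def _start_emulator(self) -> None:
        raise RuntimeError("emulator did not come up")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        Env(Manager())
    except RuntimeError as exc:
        print("init failed as expected:", exc)

The patch applies the same shape inside DesktopEnv.__init__, with self.close() plus manager.delete_vm() as the release step.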
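The aws/manager.py changes, extended by the second patch, guard the run_instances/waiter sequence so that ordinary exceptions as well as SIGINT/SIGTERM terminate a half-created instance, and the previously installed signal handlers are put back afterwards. A condensed sketch of that save/install/restore pattern, assuming only standard boto3 EC2 calls (run_instances, get_waiter('instance_running'), terminate_instances) and a caller-supplied run_instances_params dict:

import logging
import signal
import sys

import boto3

logger = logging.getLogger(__name__)


def allocate_vm(region: str, run_instances_params: dict) -> str:
    ec2 = boto3.client("ec2", region_name=region)
    instance_id = None

    def _terminate_and_exit(sig, frame):
        if instance_id:
            logger.warning("Signal %s received, terminating %s", sig, instance_id)
            ec2.terminate_instances(InstanceIds=[instance_id])
        # mirror the patch: SIGINT becomes KeyboardInterrupt, SIGTERM exits cleanly
        if sig == signal.SIGINT:
            raise KeyboardInterrupt
        sys.exit(0)

    # remember whatever handlers were installed before
    old_int = signal.getsignal(signal.SIGINT)
    old_term = signal.getsignal(signal.SIGTERM)
    signal.signal(signal.SIGINT, _terminate_and_exit)
    signal.signal(signal.SIGTERM, _terminate_and_exit)
    try:
        resp = ec2.run_instances(**run_instances_params)
        instance_id = resp["Instances"][0]["InstanceId"]
        ec2.get_waiter("instance_running").wait(InstanceIds=[instance_id])
        return instance_id
    except Exception:
        logger.exception("Failed to allocate VM in region %s", region)
        if instance_id:  # best-effort cleanup of a half-created instance
            try:
                ec2.terminate_instances(InstanceIds=[instance_id])
            except Exception:
                logger.exception("Could not terminate %s during cleanup", instance_id)
        raise
    finally:
        # always hand the original handlers back
        signal.signal(signal.SIGINT, old_int)
        signal.signal(signal.SIGTERM, old_term)

Restoring the handlers in the finally block matters here because the allocation runs inside worker processes that install their own handlers; leaving the temporary handler in place would quietly change how later signals are treated.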
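run_multienv_openaicua.py shuts its worker processes down in two stages: terminate() first, then SIGKILL for anything still alive after a grace period. The same escalation can be written with Process.join(timeout) and Process.kill(); the worker below is a placeholder for run_env_tasks, not the benchmark runner itself.

import multiprocessing as mp
import signal
import sys
import time

processes: list[mp.Process] = []


def worker(idx: int) -> None:
    # placeholder for run_env_tasks; drop the parent's handlers so the child
    # responds to terminate()/SIGTERM with the default behaviour
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    while True:
        time.sleep(1)


def shutdown(signum, frame):
    # stage 1: polite terminate() (SIGTERM on POSIX)
    for p in processes:
        if p.is_alive():
            p.terminate()
    # stage 2: grace period, then SIGKILL whatever is left
    deadline = time.time() + 5
    for p in processes:
        p.join(timeout=max(0.0, deadline - time.time()))
        if p.is_alive():
            p.kill()
            p.join()
    sys.exit(0)


if __name__ == "__main__":
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    for i in range(3):
        p = mp.Process(target=worker, args=(i,), name=f"env-{i + 1}")
        p.start()
        processes.append(p)
    for p in processes:
        p.join()

Using join(timeout) in place of a fixed time.sleep(1) lets fast-exiting workers finish the shutdown early while still bounding how long a stuck worker can delay it.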
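process_signal_handler recovers active_environments from frame.f_locals of whatever frame the signal interrupted, which only works while run_env_tasks happens to be on top of the stack. A sketch of the same per-process cleanup driven by a module-level registry instead; the names here are illustrative, not the script's actual API.

import logging
import signal
import sys

logger = logging.getLogger(__name__)

# per-process registry, filled in by the task runner below
_active_environments: list = []


def _process_signal_handler(signum, frame):
    for env in _active_environments:
        try:
            env.close()
        except Exception:
            logger.exception("Error closing environment on signal %s", signum)
    sys.exit(0)


def run_env_tasks_sketch(env_factory):
    # illustrative counterpart of run_env_tasks: register handlers, create the
    # environment, record it, and always close it on the way out
    signal.signal(signal.SIGINT, _process_signal_handler)
    signal.signal(signal.SIGTERM, _process_signal_handler)
    env = env_factory()
    _active_environments.append(env)
    try:
        pass  # run the benchmark examples here
    finally:
        env.close()
        _active_environments.remove(env)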