feat&fix: enhance error handling during environment initialization and VM allocation

This commit is contained in:
adlsdztony
2025-06-03 13:38:47 +00:00
parent e363da2fd7
commit 8d54d4302f
3 changed files with 277 additions and 95 deletions

View File

@@ -71,29 +71,39 @@ class DesktopEnv(gym.Env):
else: else:
self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region) self.path_to_vm = self.manager.get_vm_path(os_type=self.os_type, region=region)
try:
self.snapshot_name = snapshot_name
self.cache_dir_base: str = cache_dir
# todo: add the logic to get the screen size from the VM
self.headless = headless
self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal
self.snapshot_name = snapshot_name # Initialize emulator and controller
self.cache_dir_base: str = cache_dir if provider_name != "docker": # Check if this is applicable to other VM providers
# todo: add the logic to get the screen size from the VM logger.info("Initializing...")
self.headless = headless self._start_emulator()
self.require_a11y_tree = require_a11y_tree
self.require_terminal = require_terminal
# Initialize emulator and controller # mode: human or machine
if provider_name != "docker": # Check if this is applicable to other VM providers self.instruction = None
logger.info("Initializing...") assert action_space in ["computer_13", "pyautogui"]
self._start_emulator() self.action_space = action_space # todo: refactor it to the ActType
# mode: human or machine # episodic stuffs, like counters, will be updated or reset
self.instruction = None # when calling self.reset()
assert action_space in ["computer_13", "pyautogui"] self._traj_no: int = -1
self.action_space = action_space # todo: refactor it to the ActType self._step_no: int = 0
self.action_history: List[Dict[str, any]] = []
# episodic stuffs, like counters, will be updated or reset except Exception as e:
# when calling self.reset() logger.error(f"Failed to initialize DesktopEnv: {e}")
self._traj_no: int = -1 # If initialization fails, we should clean up the VM
self._step_no: int = 0 try:
self.action_history: List[Dict[str, any]] = [] self.close()
self.manager.delete_vm(self.path_to_vm, self.region)
logger.info(f"Cleaned up VM {self.path_to_vm}.")
except Exception as cleanup_error:
logger.error(f"Failed to clean up VM {self.path_to_vm}: {cleanup_error}")
raise
def _start_emulator(self): def _start_emulator(self):
# Power on the virtual machine # Power on the virtual machine

View File

@@ -57,11 +57,22 @@ def _allocate_vm(region=DEFAULT_REGION):
} }
ec2_client = boto3.client('ec2', region_name=region) ec2_client = boto3.client('ec2', region_name=region)
response = ec2_client.run_instances(**run_instances_params) try:
instance_id = response['Instances'][0]['InstanceId'] response = ec2_client.run_instances(**run_instances_params)
logger.info(f"Waiting for instance {instance_id} to be running...") instance_id = response['Instances'][0]['InstanceId']
ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id]) logger.info(f"Waiting for instance {instance_id} to be running...")
logger.info(f"Instance {instance_id} is ready.") ec2_client.get_waiter('instance_running').wait(InstanceIds=[instance_id])
logger.info(f"Instance {instance_id} is ready.")
except Exception as e:
logger.error(f"Failed to allocate VM in region {region}: {str(e)}")
# try to clean up any resources that were created
try:
if 'InstanceId' in response['Instances'][0]:
ec2_client.terminate_instances(InstanceIds=[instance_id])
logger.info(f"Terminated instance {instance_id} due to allocation failure.")
except Exception as cleanup_error:
logger.error(f"May fail to clean up instance {instance_id}: {str(cleanup_error)}")
raise
return instance_id return instance_id

View File

@@ -2,12 +2,15 @@
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
""" """
from __future__ import annotations
import argparse import argparse
import datetime import datetime
import json import json
import logging import logging
import os import os
import sys import sys
import signal
import time
from typing import List, Dict from typing import List, Dict
import math import math
from tqdm import tqdm from tqdm import tqdm
@@ -16,6 +19,11 @@ import lib_run_single
from desktop_env.desktop_env import DesktopEnv from desktop_env.desktop_env import DesktopEnv
from mm_agents.openai_cua_agent import OpenAICUAAgent from mm_agents.openai_cua_agent import OpenAICUAAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# import wandb # import wandb
# load the environment variables from .env file # load the environment variables from .env file
@@ -147,8 +155,39 @@ def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
return distributed_tasks return distributed_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
# Get the active_environments from the caller's frame
local_vars = frame.f_locals
print(f"Local variables in process {env_idx + 1}: {local_vars}")
active_environments = local_vars.get('active_environments', [])
# Close environment in the current process context
for env in active_environments:
if env is not None:
try:
logger.info(f"Process {env_idx + 1} closing environment...")
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error closing environment: {e}")
logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list): def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
"""Run tasks for a single environment.""" """Run tasks for a single environment."""
# Each process has its own list of active environments
active_environments = []
env = None
# Setup signal handlers for this process too
signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx))
# ami-05e7d7bd279ea4f14 # ami-05e7d7bd279ea4f14
env = DesktopEnv( env = DesktopEnv(
path_to_vm=args.path_to_vm, path_to_vm=args.path_to_vm,
@@ -163,6 +202,7 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
os_type="Ubuntu", os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
) )
active_environments.append(env)
agent = OpenAICUAAgent( agent = OpenAICUAAgent(
env=env, env=env,
model=args.model, model=args.model,
@@ -175,56 +215,114 @@ def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, share
) )
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): try:
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
config_file = os.path.join( for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
args.test_config_base_dir, f"examples/{domain}/{example_id}.json" config_file = os.path.join(
) args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example_openaicua(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
) )
except Exception as e: with open(config_file, "r", encoding="utf-8") as f:
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") example = json.load(f)
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4") logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
) )
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: os.makedirs(example_result_dir, exist_ok=True)
f.write(
json.dumps( try:
{"Error": f"Time limit exceeded in {domain}/{example_id}"} lib_run_single.run_single_example_openaicua(
) agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
) )
f.write("\n") except Exception as e:
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
env.close() with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
finally:
# This ensures the environment is closed even if there's an exception
logger.info(f"Process {env_idx + 1} cleaning up environment...")
try:
env.close()
logger.info(f"Process {env_idx + 1} environment closed successfully")
except Exception as e:
logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}")
def signal_handler(signum, frame):
"""Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
global is_terminating, active_environments, processes
# Avoid duplicate handling
if is_terminating:
return
is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal
os.kill(p.pid, signal.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None: def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args) logger.info("Args: %s", args)
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)
@@ -244,10 +342,28 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
) )
processes.append(p) processes.append(p)
p.start() p.start()
logger.info(f"Started process {p.name} with PID {p.pid}")
# Wait for all processes to complete try:
for p in processes: # Wait for all processes to complete
p.join() for p in processes:
p.join()
logger.info(f"Process {p.name} completed")
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
# Let the signal handler do the cleanup
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
# Ensure cleanup happens
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
# Convert shared list to regular list # Convert shared list to regular list
scores = list(shared_scores) scores = list(shared_scores)
@@ -331,31 +447,76 @@ if __name__ == "__main__":
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() # Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
with open(args.test_all_meta_path, "r", encoding="utf-8") as f: try:
test_all_meta = json.load(f) args = config()
if args.domain != "all": with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = {args.domain: test_all_meta[args.domain]} test_all_meta = json.load(f)
test_file_list = get_unfinished( if args.domain != "all":
args.action_space, test_all_meta = {args.domain: test_all_meta[args.domain]}
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result( test_file_list = get_unfinished(
args.action_space, args.action_space,
args.model, args.model,
args.observation_type, args.observation_type,
args.result_dir, args.result_dir,
test_all_meta, test_all_meta,
) )
test(args, test_file_list) left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")