Files
sci-gui-agent-benchmark/run_multienv_hosted_gbox.py
2025-11-13 13:13:31 +08:00

526 lines
18 KiB
Python

"""Run OSWorld evaluation using hosted GBOX service"""
from __future__ import annotations

import argparse
import datetime
import json
import logging
import os
import queue
import signal
import sys
import time
from multiprocessing import Process, Manager
from multiprocessing import current_process
from typing import List, Optional

import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.hosted_gbox_agent import HostedGboxAgent
# --- Process-wide state consulted by the signal handlers -------------------
# Registry of open environments. NOTE(review): nothing visible appends to this
# global; run_env_tasks keeps its own local list under the same name.
active_environments: list = []
# Worker processes launched by the main process (see the __main__ block).
processes: list = []
# Becomes True once main_signal_handler starts shutting everything down.
is_terminating: bool = False

# --- Optional .env support ---------------------------------------------------
# Populate os.environ from ./.env (when the file exists) before config() reads
# environment-based defaults.
if os.path.exists(".env"):
    from dotenv import load_dotenv

    load_dotenv()
# CLI argument config {{{ #
def config() -> argparse.Namespace:
    """Parse command-line options for the hosted-GBOX OSWorld runner.

    Defaults for the GBOX service URL, its API key and the client password are
    taken from the GBOX_SERVICE_URL, GBOX_SERVICE_API_KEY and CLIENT_PASSWORD
    environment variables when those are set.

    Returns:
        The ``argparse.Namespace`` parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Run OSWorld evaluation with hosted GBOX service"
    )
    add = parser.add_argument

    # Environment / VM
    add("--path_to_vm", default=None, type=str)
    add("--headless", help="Run in headless machine", action="store_true")
    add("--action_space", help="Action type", default="pyautogui", type=str)
    add(
        "--observation_type",
        help="Observation type",
        default="screenshot",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
    )
    add("--sleep_after_execution", default=0.0, type=float)
    add("--max_steps", default=15, type=int)

    # Agent
    add("--max_trajectory_length", default=3, type=int)
    add("--test_config_base_dir", default="evaluation_examples", type=str)

    # Hosted GBOX service
    add(
        "--gbox_service_url",
        help="URL of hosted GBOX service",
        default=os.getenv("GBOX_SERVICE_URL", "http://44.201.221.203:8000"),
        type=str,
    )
    add(
        "--gbox_service_api_key",
        help="API key for hosted GBOX service",
        default=os.getenv("GBOX_SERVICE_API_KEY"),
        type=str,
    )
    add(
        "--model",
        help="Claude model to use (default: Bedrock Sonnet 4.5)",
        default="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
        type=str,
    )
    add("--max_tokens", default=1500, type=int)

    # Example selection
    add("--domain", default="all", type=str)
    add("--test_all_meta_path", default="evaluation_examples/test_all.json", type=str)

    # Output / logging
    add("--result_dir", default="./results_hosted_gbox", type=str)
    add("--num_envs", help="Number of environments to run in parallel", default=1, type=int)
    add(
        "--log_level",
        help="Set the logging level",
        default='INFO',
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        type=str,
    )

    # Cloud provider (AWS)
    add("--region", help="AWS region for the VM", default="us-east-1", type=str)
    add("--provider_name", help="Cloud provider name", default="aws", type=str)
    add("--screen_width", help="Screen width", default=1920, type=int)
    add("--screen_height", help="Screen height", default=1080, type=int)
    add(
        "--client_password",
        help="Client password (default: osworld-public-evaluation)",
        default=os.getenv("CLIENT_PASSWORD", "osworld-public-evaluation"),
        type=str,
    )

    return parser.parse_args()
# }}} CLI argument config #
def setup_logger(env_idx: Optional[int] = None, result_dir: str = "./results_gbox", level: str = 'INFO') -> logging.Logger:
    """Configure and return the logger for the main process or one worker.

    The logger writes to the console and to a timestamped file in *result_dir*
    (``main_<ts>.log`` or ``worker_<idx>_<ts>.log``). Calling this again for
    the same logger name replaces -- and closes -- the handlers attached by the
    previous call.

    Args:
        env_idx: Worker index used in the logger and file name; ``None`` for
            the main process.
        result_dir: Directory for the log file (created if missing).
        level: Logging level name, case-insensitive (e.g. ``"INFO"``).

    Returns:
        The configured ``logging.Logger``.

    Raises:
        ValueError: If *level* is not a known logging level name.
    """
    numeric_level = getattr(logging, level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {level}')

    if env_idx is not None:
        logger_name = f"osworld-worker-{env_idx}"
        log_basename = f"worker_{env_idx}"
    else:
        logger_name = "osworld-main"
        log_basename = "main"

    log = logging.getLogger(logger_name)
    log.setLevel(numeric_level)

    # Detach AND close handlers from any earlier call: only clearing the list
    # would leave the previous log file's descriptor open.
    for old_handler in list(log.handlers):
        log.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    log.addHandler(console_handler)

    # File handler; explicit UTF-8 so non-ASCII task instructions are written
    # regardless of the platform's default encoding.
    os.makedirs(result_dir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(result_dir, f"{log_basename}_{timestamp}.log")
    file_handler = logging.FileHandler(log_file, encoding="utf-8")
    file_handler.setLevel(numeric_level)
    file_handler.setFormatter(formatter)
    log.addHandler(file_handler)

    return log
# Module-level logger. The __main__ block rebinds this name to the result of
# setup_logger(), which returns this same "osworld-main" logger with handlers
# attached.
logger: logging.Logger = logging.getLogger(name="osworld-main")
def check_completed_tasks(
    result_dir: str,
    test_all_meta: dict,
    *,
    action_space: str = "pyautogui",
    observation_type: str = "screenshot",
    model: str = "claude-sonnet-4-5",
) -> List[str]:
    """Check which tasks already have a ``result.txt`` on disk.

    Results are looked up at
    ``<result_dir>/<action_space>/<observation_type>/<model>/<domain>/<example_id>/result.txt``,
    the same layout run_env_tasks uses for each example's output directory.

    The keyword defaults reproduce the directory this helper historically
    hard-coded. That directory does NOT match the runner's default ``--model``
    value, so callers should pass ``args.action_space``,
    ``args.observation_type`` and ``args.model`` to look where results are
    actually written.

    Args:
        result_dir: Root results directory (``--result_dir``).
        test_all_meta: Mapping of domain -> list of example IDs.
        action_space: Action-space path segment.
        observation_type: Observation-type path segment.
        model: Model path segment.

    Returns:
        Completed task IDs formatted as ``"domain/example_id"``, in
        *test_all_meta* iteration order.
    """
    results_root = os.path.join(result_dir, action_space, observation_type, model)
    completed = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            result_path = os.path.join(results_root, domain, example_id, "result.txt")
            if os.path.exists(result_path):
                completed.append(f"{domain}/{example_id}")
                logger.info(f"Task {domain}/{example_id} already completed (result found)")
    return completed
def report_current_results(target_dir: str) -> Optional[List[float]]:
    """Collect per-example scores under *target_dir* and log the success rate.

    Args:
        target_dir: Directory laid out as ``<domain>/<example_id>/result.txt``.

    Returns:
        One float per ``result.txt`` found; a file that cannot be read or
        parsed as a float counts as ``0.0``. Returns ``None`` when there are
        no results yet, including when *target_dir* does not exist.
    """
    if not os.path.isdir(target_dir):
        logger.info("New experiment, no results yet.")
        return None

    all_result: List[float] = []
    # Sorted traversal keeps the returned order deterministic across runs.
    for domain in sorted(os.listdir(target_dir)):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in sorted(os.listdir(domain_path)):
            result_file = os.path.join(domain_path, example_id, "result.txt")
            if not os.path.isfile(result_file):
                continue
            try:
                with open(result_file, "r", encoding="utf-8") as f:
                    all_result.append(float(f.read()))
            except (OSError, ValueError) as e:
                # Unreadable/corrupt result: count it as a failure, not a crash.
                logger.warning(f"Failed to read result for {domain}/{example_id}: {e}")
                all_result.append(0.0)

    if not all_result:
        logger.info("New experiment, no results yet.")
        return None

    success_rate = sum(all_result) / len(all_result) * 100
    logger.info(f"Current Success Rate: {success_rate:.2f}% ({len(all_result)} tasks)")
    return all_result
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    """Flatten ``{domain: [example_id, ...]}`` into ``[(domain, example_id), ...]``.

    Order follows the mapping's iteration order, then each domain's list order.
    """
    return [
        (domain, example_id)
        for domain, example_ids in test_all_meta.items()
        for example_id in example_ids
    ]
def _find_active_environments(frame) -> list:
    """Return the nearest ``active_environments`` local found walking outward from *frame*.

    Checks *frame*, then ``frame.f_back`` and so on. Returns ``[]`` when no
    frame defines the name or *frame* is ``None``.
    """
    while frame is not None:
        found = frame.f_locals.get("active_environments")
        if found is not None:
            return found
        frame = frame.f_back
    return []


def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments.

    The extra *env_idx* (used only in log messages) means the handler must be
    bound, e.g. with a lambda or ``functools.partial``, before it is passed to
    ``signal.signal``. Open environments are located by searching the
    interrupted stack, from *frame* outward, for a local named
    ``active_environments``. A signal usually lands deep inside library code,
    so checking only the innermost frame would almost never find it. Each
    non-``None`` environment is closed (errors are logged), then the process
    exits with status 0.

    NOTE(review): nothing in the visible file installs this handler.
    """
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")

    for env in _find_active_environments(frame):
        if env is None:
            continue
        try:
            logger.info(f"Process {env_idx + 1} closing environment...")
            env.close()
            logger.info(f"Process {env_idx + 1} environment closed successfully")
        except Exception as e:
            # Best-effort: keep closing the remaining environments.
            logger.error(f"Process {env_idx + 1} error closing environment: {e}")

    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)
def _install_worker_signal_handlers() -> None:
    """Give this worker its own SIGINT/SIGTERM handling.

    Under the ``fork`` start method a worker inherits the parent's Python-level
    handlers (``main_signal_handler``), which drive ``Process`` objects that
    only the parent may control. Here both signals are instead turned into
    ``SystemExit``, so the worker unwinds through run_env_tasks' ``finally``
    block and closes its environment.
    """

    def _on_signal(signum, frame):
        # Ignore repeat signals so the cleanup in ``finally`` is not itself
        # interrupted.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, signal.SIG_IGN)
        logger.info(f"{current_process().name} received signal {signum}. Shutting down...")
        raise SystemExit(0)

    try:
        signal.signal(signal.SIGINT, _on_signal)
        signal.signal(signal.SIGTERM, _on_signal)
    except ValueError:
        # signal.signal only works in the interpreter's main thread; keep the
        # existing handlers rather than failing the worker.
        logger.warning("Could not install worker signal handlers (not in main thread)")


def _run_single_task(agent, env, args: argparse.Namespace, shared_scores: list, domain: str, example_id: str) -> None:
    """Load one example config and run it, recording a failure marker if the run raises.

    Errors from ``run_single_example`` are logged, the screen recording is
    stopped (best-effort) and an ``{"Error": ...}`` line is appended to the
    example's ``traj.jsonl``. Errors while loading the config propagate to the
    caller.
    """
    config_file = os.path.join(
        args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
    )
    with open(config_file, "r", encoding="utf-8") as f:
        example = json.load(f)

    logger.info(f"[Domain]: {domain}")
    logger.info(f"[Example ID]: {example_id}")
    logger.info(f"[Instruction]: {example['instruction']}")

    example_result_dir = os.path.join(
        args.result_dir,
        args.action_space,
        args.observation_type,
        args.model,
        domain,
        example_id,
    )
    os.makedirs(example_result_dir, exist_ok=True)

    try:
        lib_run_single.run_single_example(
            agent,
            env,
            example,
            args.max_steps,
            example["instruction"],
            args,
            example_result_dir,
            shared_scores,
        )
    except Exception as e:
        import traceback

        logger.error(f"Exception {domain}/{example_id}: {e}")
        logger.error(traceback.format_exc())
        try:
            env.controller.end_recording(
                os.path.join(example_result_dir, "recording.mp4")
            )
        except Exception as rec_e:
            logger.error(f"Failed to end recording: {rec_e}")
        with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
            f.write(json.dumps({"Error": f"{domain}/{example_id} - {e}"}))
            f.write("\n")


def run_env_tasks(task_queue, args: argparse.Namespace, shared_scores: list):
    """Worker process that runs tasks from the queue using hosted GBOX service.

    Boots one DesktopEnv and one HostedGboxAgent, then pulls ``(domain,
    example_id)`` tuples from *task_queue* until a 5 s ``get`` times out. The
    main block fills the queue before starting workers, so an empty queue
    means there is no work left.

    Args:
        task_queue: Shared queue of ``(domain, example_id)`` tuples.
        args: Parsed CLI arguments (see ``config``).
        shared_scores: Manager-backed list passed through to
            ``lib_run_single.run_single_example``.

    The environment is always closed in ``finally``, including when the
    worker is stopped by SIGINT/SIGTERM.
    """
    _install_worker_signal_handlers()

    # Kept under exactly this name: process_signal_handler finds open
    # environments by looking for an ``active_environments`` local on the stack.
    active_environments = []
    env = None
    try:
        from desktop_env.providers.aws.manager import IMAGE_ID_MAP

        region = args.region
        screen_size = (args.screen_width, args.screen_height)
        region_images = IMAGE_ID_MAP[region]
        # Fall back to the 1920x1080 image only when no AMI matches the
        # requested size (looked up lazily so a missing fallback entry does
        # not break an otherwise valid size).
        if screen_size in region_images:
            ami_id = region_images[screen_size]
        else:
            ami_id = region_images[(1920, 1080)]

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=region,
            snapshot_name=ami_id,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password,
        )
        active_environments.append(env)

        # Get VM IP address - MCP server will handle public IP lookup if needed
        vm_ip = env.vm_ip
        logger.info(f"VM IP: {vm_ip}")

        agent = HostedGboxAgent(
            server_url=args.gbox_service_url,
            api_key=args.gbox_service_api_key,
            vm_ip=vm_ip,
            platform="ubuntu",
            model=args.model,
            max_steps=args.max_steps,
        )

        while True:
            try:
                item = task_queue.get(timeout=5)
            except queue.Empty:
                # Only a timeout ends the loop; other queue/manager failures
                # surface through the outer handler and are logged.
                break
            domain, example_id = item
            try:
                _run_single_task(agent, env, args, shared_scores, domain, example_id)
            except Exception as e:
                logger.error(f"Error processing task: {e}", exc_info=True)
    except KeyboardInterrupt:
        logger.info("Worker received interrupt signal")
    except Exception as e:
        logger.error(f"Worker error: {e}", exc_info=True)
    finally:
        if env is not None:
            try:
                logger.info("Closing environment...")
                env.close()
                logger.info("Environment closed successfully")
            except Exception as e:
                logger.error(f"Error closing environment: {e}")
def main_signal_handler(signum, frame):
    """Signal handler for main process to gracefully shut down all child processes.

    Sends SIGTERM to every live worker in ``processes``, gives them a shared
    30-second budget to exit, SIGKILLs any stragglers, then exits with status
    0. A repeated signal while shutdown is in progress is ignored.

    Processes forked from the main process inherit this handler. The
    multiprocessing docs only allow ``is_alive()``/``terminate()``/``join()``
    from the process that created the ``Process`` object (``is_alive()``
    asserts on it). So in any non-main process this handler does not touch
    ``processes``; it raises ``SystemExit`` and lets that process run its own
    cleanup.
    """
    global is_terminating

    if current_process().name != "MainProcess":
        # Inherited by a child: ignore repeat signals so the child's cleanup
        # is not interrupted, then unwind.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        signal.signal(signal.SIGTERM, signal.SIG_IGN)
        raise SystemExit(0)

    if is_terminating:
        logger.info("Already terminating, please wait...")
        return
    is_terminating = True

    logger.info(f"Main process received signal {signum}. Shutting down all workers...")

    for idx, proc in enumerate(processes):
        if proc.is_alive():
            logger.info(f"Terminating worker process {idx + 1}...")
            proc.terminate()

    # One overall deadline (monotonic, immune to wall-clock jumps) instead of
    # 30 s per worker.
    timeout = 30
    deadline = time.monotonic() + timeout
    for idx, proc in enumerate(processes):
        proc.join(timeout=max(0, deadline - time.monotonic()))
        if proc.is_alive():
            logger.warning(f"Worker {idx + 1} did not terminate gracefully, forcing...")
            proc.kill()
            proc.join()

    logger.info("All workers terminated. Exiting.")
    sys.exit(0)
if __name__ == "__main__":
    args = config()

    # Setup main logger
    logger = setup_logger(env_idx=None, result_dir=args.result_dir, level=args.log_level)

    # Validate hosted service configuration
    if not args.gbox_service_url:
        logger.error("GBOX_SERVICE_URL not set (use --gbox_service_url or GBOX_SERVICE_URL env var)")
        sys.exit(1)
    if not args.gbox_service_api_key:
        logger.error("GBOX_SERVICE_API_KEY not set (use --gbox_service_api_key or GBOX_SERVICE_API_KEY env var)")
        sys.exit(1)
    if args.num_envs < 1:
        logger.error(f"--num_envs must be at least 1 (got {args.num_envs})")
        sys.exit(1)

    logger.info(f"Using hosted GBOX service at: {args.gbox_service_url}")
    logger.info(f"Model: {args.model}")
    logger.info(f"Max steps: {args.max_steps}")
    logger.info(f"Number of parallel environments: {args.num_envs}")

    # Root of all per-example outputs: run_env_tasks creates
    # <result_dir>/<action_space>/<observation_type>/<model>/<domain>/<example_id>/.
    # Used both for resumption and for reporting, so the two always agree.
    results_root = os.path.join(
        args.result_dir,
        args.action_space,
        args.observation_type,
        args.model,
    )

    # Load test configuration
    logger.info(f"Loading test configuration from: {args.test_all_meta_path}")
    with open(args.test_all_meta_path, "r") as f:
        test_all_meta = json.load(f)

    # Filter by domain if specified
    if args.domain != "all":
        if args.domain not in test_all_meta:
            logger.error(f"Domain '{args.domain}' not found in test configuration")
            sys.exit(1)
        test_all_meta = {args.domain: test_all_meta[args.domain]}
        logger.info(f"Filtering to domain: {args.domain}")

    all_tasks = distribute_tasks(test_all_meta)

    # Resume: skip tasks whose result.txt already exists under results_root.
    # Derived from the actual CLI args so the lookup matches where results are
    # written, whatever --model / --action_space / --observation_type are.
    completed_tasks = set()
    for domain, example_id in all_tasks:
        if os.path.exists(os.path.join(results_root, domain, example_id, "result.txt")):
            completed_tasks.add(f"{domain}/{example_id}")
            logger.info(f"Task {domain}/{example_id} already completed (result found)")
    logger.info(f"Found {len(completed_tasks)} completed tasks")
    logger.info(f"Total tasks to run: {len(all_tasks)}")

    all_tasks = [task for task in all_tasks if f"{task[0]}/{task[1]}" not in completed_tasks]
    logger.info(f"Tasks remaining after filtering completed: {len(all_tasks)}")

    if not all_tasks:
        logger.info("No tasks to run. All tasks already completed.")
        if os.path.exists(results_root):
            report_current_results(results_root)
        sys.exit(0)

    # Each worker boots its own VM before pulling from the queue, so never
    # start more workers than there are tasks left.
    num_workers = min(args.num_envs, len(all_tasks))
    if num_workers < args.num_envs:
        logger.info(
            f"Only {len(all_tasks)} tasks remain; starting {num_workers} workers instead of {args.num_envs}"
        )

    # Create and fully populate the shared queue before any worker starts,
    # so workers can treat an empty queue as "no work left".
    manager = Manager()
    task_queue = manager.Queue()
    shared_scores = manager.list()
    for task in all_tasks:
        task_queue.put(task)

    # Install handlers only after the Manager server process exists, so it
    # never inherits them.
    signal.signal(signal.SIGINT, main_signal_handler)
    signal.signal(signal.SIGTERM, main_signal_handler)

    # Start worker processes
    logger.info(f"Starting {num_workers} worker processes...")
    for env_idx in range(num_workers):
        proc = Process(
            target=run_env_tasks,
            args=(task_queue, args, shared_scores),
            name=f"osworld-worker-{env_idx + 1}",
        )
        proc.start()
        processes.append(proc)
        logger.info(f"Started worker process {env_idx + 1} (PID: {proc.pid})")

    # Wait for all processes to complete
    try:
        for idx, proc in enumerate(processes):
            proc.join()
            logger.info(f"Worker process {idx + 1} completed")
    except KeyboardInterrupt:
        logger.info("Received interrupt, shutting down...")
        main_signal_handler(signal.SIGINT, None)

    # Report final results
    logger.info("=" * 50)
    logger.info("EVALUATION COMPLETE")
    logger.info("=" * 50)

    if os.path.exists(results_root):
        final_results = report_current_results(results_root)
        if final_results:
            success_rate = sum(final_results) / len(final_results) * 100
            logger.info(f"Final Success Rate: {success_rate:.2f}% ({len(final_results)} tasks)")

    logger.info("Exiting...")