feat: add run_multienv_o3.py script for multi-environment evaluation

- Introduced a new script `run_multienv_o3.py` to facilitate end-to-end evaluation across multiple environments.
- Implemented command-line argument parsing for various configurations including environment settings, logging levels, and AWS parameters.
- Integrated signal handling for graceful shutdown of environments and processes.
- Enhanced logging capabilities for better traceability during execution.
- Maintained existing logic from previous scripts while introducing new functionalities for improved evaluation processes.
This commit is contained in:
yuanmengqi
2025-07-27 16:47:24 +00:00
parent 1342bfe5ce
commit 0f00788c4d
5 changed files with 1148 additions and 209 deletions

261
mm_agents/o3_agent.py Normal file
View File

@@ -0,0 +1,261 @@
import base64
import logging
import os
import re
from io import BytesIO
from typing import Dict, List
import backoff
import openai
import requests
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.prompts import O3_SYSTEM_PROMPT
# Module-level logger; left as None and initialised lazily by
# O3Agent.reset() so each run can inject its own logger instance.
logger = None
# Maximum number of re-prompts when the model returns no parsable action.
MAX_RETRY_TIMES = 10
# API key read from the environment; None when unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY",None) #"Your OpenAI API Key"
def encode_image(image_content):
    """Base64-encode raw image bytes and return the result as UTF-8 text."""
    encoded_bytes = base64.b64encode(image_content)
    return encoded_bytes.decode("utf-8")
class O3Agent:
    """Screenshot-only GUI agent backed by the OpenAI ``o3`` model.

    The agent accumulates per-step history (thoughts, actions, raw
    observations and textual observation captions).  When building the next
    prompt, only the ``max_image_history_length`` most recent observations
    are replayed as images; older steps are replayed via their text
    captions to keep the prompt size bounded.

    NOTE: ``reset()`` must be called before ``predict()`` — it initialises
    the module-level ``logger`` used throughout this class.
    """

    def __init__(
        self,
        platform="ubuntu",
        model="o3",
        max_tokens=1500,
        client_password="password",
        action_space="pyautogui",
        observation_type="screenshot",
        max_steps=15
    ):
        self.platform = platform
        self.model = model
        self.max_tokens = max_tokens
        # Password forwarded into the system prompt so generated actions can
        # authenticate (e.g. sudo) inside the VM.
        self.client_password = client_password
        self.action_space = action_space
        self.observation_type = observation_type
        # Only pyautogui actions over raw screenshots are supported.
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        # Per-episode history, one entry per completed step.
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.observation_captions = []
        # How many most-recent observations are replayed as actual images.
        self.max_image_history_length = 5
        self.current_step = 1
        self.max_steps = max_steps

    def predict(self, instruction: str, obs: Dict) -> List:
        """Predict the next action(s) based on the current observation.

        Args:
            instruction: Natural-language task instruction.
            obs: Observation dict; must contain a ``"screenshot"`` entry
                holding raw image bytes.

        Returns:
            A ``(response, codes)`` tuple: the raw model response text and
            the list of parsed action code strings.
        """
        user_prompt = (
            f"""Please generate the next move according to the UI screenshot and instruction. And you can refer to the previous actions and observations for reflection.\n\nInstruction: {instruction}\n\n""")
        messages = [{
            "role": "system",
            "content": [{
                "type": "text",
                "text": O3_SYSTEM_PROMPT.format(
                    current_step=self.current_step,
                    max_steps=self.max_steps,
                    CLIENT_PASSWORD=self.client_password
                )
            }]
        }]
        # Determine which observations to include images for (only most recent ones)
        obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
        # Replay the whole thought/action history.
        for i in range(len(self.thoughts)):
            # For recent steps, include the actual screenshot
            if i >= obs_start_idx:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
                            "detail": "high"
                        },
                    }]
                })
            # For older steps, use the observation caption instead of the image
            else:
                messages.append({
                    "role": "user",
                    "content": [{
                        "type": "text",
                        "text": f"Observation: {self.observation_captions[i]}"
                    }]
                })
            thought_messages = f"Thought:\n{self.thoughts[i]}"
            action_messages = "Action:"
            for action in self.actions[i]:
                action_messages += f"\n{action}"
            messages.append({
                "role": "assistant",
                "content": [{
                    "type": "text",
                    "text": thought_messages + "\n" + action_messages
                }]
            })
        # Current screenshot plus the instruction prompt.
        messages.append({
            "role":"user",
            "content": [
                {
                    "type":"image_url",
                    "image_url":{
                        "url":f"data:image/png;base64,{encode_image(obs['screenshot'])}",
                        "detail": "high"
                    },
                },
                {
                    "type": "text",
                    "text": user_prompt
                },
            ],
        })
        response = self.call_llm(
            {
                "model": self.model,
                "messages": messages,
                "max_completion_tokens": self.max_tokens,
            },
            self.model,
        )
        logger.info(f"Output: {response}")
        codes = self.parse_code_from_planner_response(response)
        # Re-prompt the model when no valid action block could be parsed.
        retry_count = 0
        max_retries = MAX_RETRY_TIMES
        while not codes and retry_count < max_retries:
            logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
            messages.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": "You didn't generate valid actions. Please try again."}
                ]
            })
            response = self.call_llm(
                {
                    "model": self.model,
                    "messages": messages,
                    "max_completion_tokens": self.max_tokens,
                },
                self.model,
            )
            logger.info(f"Retry Planner Output: {response}")
            codes = self.parse_code_from_planner_response(response)
            retry_count += 1
        thought = self.parse_thought_from_planner_response(response)
        observation_caption = self.parse_observation_caption_from_planner_response(response)
        logger.info(f"Thought: {thought}")
        logger.info(f"Observation Caption: {observation_caption}")
        logger.info(f"Codes: {codes}")
        # BUG FIX: previously this appended [codes] (a nested list), so the
        # history replay above rendered the Python list repr instead of the
        # individual action strings.  `codes` is already a list.
        self.actions.append(codes)
        self.observations.append(obs)
        self.thoughts.append(thought)
        self.observation_captions.append(observation_caption)
        self.current_step += 1
        return response, codes

    def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
        """Extract the first line following an ``Observation:`` header, or ``""``.

        NOTE(review): the non-greedy ``(.*?)\\n`` stops at the first newline,
        so multi-line captions are truncated to their first line — presumably
        intentional; confirm against the prompt format.
        """
        pattern = r"Observation:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_thought_from_planner_response(self, input_string: str) -> str:
        """Extract the first line following a ``Thought:`` header, or ``""``."""
        pattern = r"Thought:\n(.*?)\n"
        matches = re.findall(pattern, input_string, re.DOTALL)
        if matches:
            return matches[0].strip()
        return ""

    def parse_code_from_planner_response(self, input_string: str) -> List[str]:
        """Parse executable action codes from the planner's response.

        Returns a list of code strings extracted from fenced ``` blocks.
        The sentinel commands WAIT/DONE/FAIL are kept as standalone entries,
        including when one terminates a multi-line code block.
        """
        # Normalise ';'-separated fragments onto their own lines first.
        input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
        if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
            return [input_string.strip()]
        # Fenced code blocks, with an optional language tag after the fence.
        pattern = r"```(?:\w+\s+)?(.*?)```"
        matches = re.findall(pattern, input_string, re.DOTALL)
        codes = []
        for match in matches:
            match = match.strip()
            commands = ['WAIT', 'DONE', 'FAIL']
            if match in commands:
                codes.append(match.strip())
            elif match.split('\n')[-1] in commands:
                # Split a trailing sentinel off the preceding code lines.
                if len(match.split('\n')) > 1:
                    codes.append("\n".join(match.split('\n')[:-1]))
                codes.append(match.split('\n')[-1])
            else:
                codes.append(match)
        return codes

    @backoff.on_exception(
        backoff.constant,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure
        # each example won't exceed the time limit
        (
            # General exceptions
            SSLError,
            requests.HTTPError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            openai.APIConnectionError,
            openai.APIError
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        """POST ``payload`` to the OpenAI chat-completions endpoint.

        Retries (via backoff) on transient network/API errors; raises
        ``requests.HTTPError`` on non-200 responses to trigger the retry.

        Returns:
            The message content string from the first choice.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {OPENAI_API_KEY}"
        }
        logger.info("Generating content with GPT model: %s", model)
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
        )
        if response.status_code != 200:
            logger.error("Failed to call LLM: " + response.text)
            # Raise HTTPError to trigger backoff retry mechanism
            response.raise_for_status()
        else:
            return response.json()["choices"][0]["message"]["content"]

    def reset(self, _logger=None):
        """Clear all per-episode state and (re)bind the module logger.

        Args:
            _logger: Optional logger to use; falls back to the
                ``desktopenv.o3_agent`` logger.
        """
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.o3_agent"))
        self.thoughts = []
        # NOTE(review): `action_descriptions` is never read elsewhere in this
        # class; kept for backward compatibility with external callers.
        self.action_descriptions = []
        self.actions = []
        self.observations = []
        self.observation_captions = []

3
run.py
View File

@@ -15,8 +15,7 @@ import lib_run_single
from desktop_env.desktop_env import DesktopEnv from desktop_env.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent from mm_agents.agent import PromptAgent
# import wandb # Almost deprecated since it's not multi-env, use run_multienv_*.py instead
# Logger Configs {{{ # # Logger Configs {{{ #
logger = logging.getLogger() logger = logging.getLogger()

View File

@@ -1,66 +1,32 @@
"""Script to run end-to-end evaluation on the benchmark. from __future__ import annotations
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse import argparse
import datetime import datetime
import json import json
import logging import logging
import os import os
import sys import sys
from typing import List, Dict import signal
import math import time
from tqdm import tqdm from typing import List
from multiprocessing import Process, Manager from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single import lib_run_single
from desktop_env.desktop_env import DesktopEnv from desktop_env.desktop_env import DesktopEnv
from mm_agents.agent import PromptAgent from mm_agents.agent import PromptAgent
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# import wandb # import wandb
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ # # Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark" description="Run end-to-end evaluation on the benchmark"
@@ -77,11 +43,9 @@ def config() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--observation_type", "--observation_type",
choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
default="a11y_tree", default="screenshot",
help="Observation type", help="Observation type",
) )
parser.add_argument("--screen_width", type=int, default=1920)
parser.add_argument("--screen_height", type=int, default=1080)
parser.add_argument("--sleep_after_execution", type=float, default=0.0) parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=15) parser.add_argument("--max_steps", type=int, default=15)
@@ -107,111 +71,116 @@ def config() -> argparse.Namespace:
# logging related # logging related
parser.add_argument("--result_dir", type=str, default="./results") parser.add_argument("--result_dir", type=str, default="./results")
parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
default='INFO', help="Set the logging level")
# aws config # aws config
parser.add_argument( parser.add_argument(
"--region", type=str, default="us-east-1", help="AWS region for the VM" "--region", type=str, default="us-east-1", help="AWS region for the VM"
) )
parser.add_argument(
"--provider_name", type=str, default="docker", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
)
parser.add_argument(
"--client_password", type=str, default="", help="Client password"
)
parser.add_argument(
"--screen_width", type=int, default=1920, help="Screen width"
)
parser.add_argument(
"--screen_height", type=int, default=1080, help="Screen height"
)
args = parser.parse_args() args = parser.parse_args()
return args return args
args = config() # Get command line arguments first
def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: logger = logging.getLogger()
"""Distribute tasks evenly across environments.""" log_level = getattr(logging, args.log_level.upper())
# Flatten the tasks into a single list logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
file_handler = logging.FileHandler(
os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
all_tasks = [] all_tasks = []
for domain, examples in test_all_meta.items(): for domain, examples in test_all_meta.items():
for example_id in examples: for example_id in examples:
all_tasks.append((domain, example_id)) all_tasks.append((domain, example_id))
return all_tasks
# Calculate tasks per environment
tasks_per_env = math.ceil(len(all_tasks) / num_envs)
# Distribute tasks
distributed_tasks = []
for i in range(num_envs):
env_tasks = {}
start_idx = i * tasks_per_env
end_idx = min((i + 1) * tasks_per_env, len(all_tasks))
for domain, example_id in all_tasks[start_idx:end_idx]:
if domain not in env_tasks:
env_tasks[domain] = []
env_tasks[domain].append(example_id)
distributed_tasks.append(env_tasks)
return distributed_tasks
def process_signal_handler(signum, frame, env_idx):
"""Signal handler for child processes to gracefully shut down their environments."""
logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): # Get the active_environments from the caller's frame
"""Run tasks for a single environment.""" local_vars = frame.f_locals
logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") active_environments = local_vars.get('active_environments', [])
for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
# Close environment in the current process context
for env in active_environments:
if env is not None:
try: try:
lib_run_single.run_single_example( logger.info(f"Process {env_idx + 1} closing environment...")
agent, env.close()
env, logger.info(f"Process {env_idx + 1} environment closed successfully")
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e: except Exception as e:
logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") logger.error(f"Process {env_idx + 1} error closing environment: {e}")
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"Time limit exceeded in {domain}/{example_id}"}
)
)
f.write("\n")
env.close() logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None: def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
logger.info("Args: %s", args) active_environments = []
env = None
distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) try:
from desktop_env.providers.aws.manager import IMAGE_ID_MAP
# First, set up all environments REGION = args.region
logger.info("Setting up all environments...") screen_size = (args.screen_width, args.screen_height)
envs = [] ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
agents = [] env = DesktopEnv(
path_to_vm=args.path_to_vm,
for env_idx in range(args.num_envs): action_space=args.action_space,
logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") provider_name=args.provider_name,
region=REGION,
snapshot_name=ami_id,
screen_size=screen_size,
headless=args.headless,
os_type="Ubuntu",
require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
enable_proxy=True,
client_password=args.client_password
)
active_environments.append(env)
agent = PromptAgent( agent = PromptAgent(
model=args.model, model=args.model,
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
@@ -220,48 +189,188 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
action_space=args.action_space, action_space=args.action_space,
observation_type=args.observation_type, observation_type=args.observation_type,
max_trajectory_length=args.max_trajectory_length, max_trajectory_length=args.max_trajectory_length,
client_password=args.client_password
) )
agents.append(agent)
env = DesktopEnv( logger.info(f"Process {current_process().name} started.")
path_to_vm=args.path_to_vm, while True:
action_space=agent.action_space, try:
item = task_queue.get(timeout=5)
except Exception:
break
domain, example_id = item
try:
config_file = os.path.join(
args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
)
with open(config_file, "r", encoding="utf-8") as f:
example = json.load(f)
logger.info(f"[{current_process().name}][Domain]: {domain}")
logger.info(f"[{current_process().name}][Example ID]: {example_id}")
logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
example_result_dir = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
domain,
example_id,
)
os.makedirs(example_result_dir, exist_ok=True)
try:
lib_run_single.run_single_example(
agent,
env,
example,
args.max_steps,
example["instruction"],
args,
example_result_dir,
shared_scores,
)
except Exception as e:
import traceback
logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
logger.error(traceback.format_exc())
try:
env.controller.end_recording(
os.path.join(example_result_dir, "recording.mp4")
)
except Exception as rec_e:
logger.error(f"Failed to end recording: {rec_e}")
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
f.write(
json.dumps(
{"Error": f"{domain}/{example_id} - {e}"}
)
)
f.write("\n")
except Exception as e:
logger.error(f"Task-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
except Exception as e:
logger.error(f"Process-level error in {current_process().name}: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
logger.info(f"{current_process().name} cleaning up environment...")
try:
if env:
env.close()
logger.info(f"{current_process().name} environment closed successfully")
except Exception as e:
logger.error(f"{current_process().name} error during environment cleanup: {e}")
provider_name="aws",
region="us-east-1",
snapshot_name="ami-05e7d7bd279ea4f14",
screen_size=(args.screen_width, args.screen_height), def signal_handler(signum, frame):
headless=args.headless, """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
os_type="Ubuntu", global is_terminating, active_environments, processes
require_a11y_tree=args.observation_type
in ["a11y_tree", "screenshot_a11y_tree", "som"],
)
envs.append(env)
logger.info("All environments are ready. Starting parallel task execution...") # Avoid duplicate handling
if is_terminating:
return
# Create a shared list for scores across processes is_terminating = True
logger.info(f"Received signal {signum}. Gracefully shutting down...")
# Close all registered environments in the main process
for env in active_environments:
try:
logger.info(f"Closing environment...")
env.close()
logger.info(f"Environment closed successfully")
except Exception as e:
logger.error(f"Error closing environment: {e}")
# Send termination signal to all child processes first
for p in processes:
if p.is_alive():
try:
logger.info(f"Sending termination signal to process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error sending termination signal to process: {e}")
# Allow a short time for processes to handle their own cleanup
time.sleep(1)
# Forcefully terminate any processes that didn't exit
for p in processes:
if p.is_alive():
try:
logger.info(f"Forcefully terminating process {p.name}...")
import signal as sig
os.kill(p.pid, sig.SIGKILL)
except Exception as e:
logger.error(f"Error forcefully terminating process: {e}")
logger.info("Shutdown complete. Exiting.")
sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
global processes
logger.info("Args: %s", args)
all_tasks = distribute_tasks(test_all_meta)
logger.info(f"Total tasks: {len(all_tasks)}")
with Manager() as manager: with Manager() as manager:
shared_scores = manager.list() shared_scores = manager.list()
task_queue = manager.Queue()
# Create and start processes for each environment for item in all_tasks:
task_queue.put(item)
num_envs = args.num_envs
processes = [] processes = []
for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): for i in range(num_envs):
p = Process( p = Process(
target=run_env_tasks, target=run_env_tasks,
args=(env_idx, env, agent, env_tasks, args, shared_scores) args=(task_queue, args, shared_scores),
name=f"EnvProcess-{i+1}"
) )
processes.append(p) p.daemon = True
p.start() p.start()
processes.append(p)
# Wait for all processes to complete logger.info(f"Started process {p.name} with PID {p.pid}")
for p in processes: try:
p.join() while True:
alive_count = 0
# Convert shared list to regular list for idx, p in enumerate(processes):
if not p.is_alive():
logger.warning(f"Process {p.name} died, restarting...")
new_p = Process(
target=run_env_tasks,
args=(task_queue, args, shared_scores),
name=f"EnvProcess-Restart-{idx+1}"
)
new_p.daemon = True
new_p.start()
processes[idx] = new_p
logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
else:
alive_count += 1
if task_queue.empty():
logger.info("All tasks finished.")
break
if alive_count == 0:
logger.error("All processes died, exiting.")
break
time.sleep(5)
for p in processes:
p.join()
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
raise
except Exception as e:
logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
for p in processes:
if p.is_alive():
try:
logger.info(f"Terminating process {p.name} due to error...")
p.terminate()
except Exception as term_e:
logger.error(f"Error terminating process {p.name}: {term_e}")
raise
scores = list(shared_scores) scores = list(shared_scores)
logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
@@ -341,46 +450,88 @@ if __name__ == "__main__":
####### The complete version of the list of examples ####### ####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() # Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
# save args to json in result_dir/action_space/observation_type/model/args.json try:
path_to_args = os.path.join( args = config()
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f: # save args to json in result_dir/action_space/observation_type/model/args.json
test_all_meta = json.load(f) path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
if args.domain != "all": with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = {args.domain: test_all_meta[args.domain]} test_all_meta = json.load(f)
test_file_list = get_unfinished( if args.domain != "all":
args.action_space, test_all_meta = {args.domain: test_all_meta[args.domain]}
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result( test_file_list = get_unfinished(
args.action_space, args.action_space,
args.model, args.model,
args.observation_type, args.observation_type,
args.result_dir, args.result_dir,
test_all_meta, test_all_meta,
) )
test(args, test_file_list) left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# path_to_vm can be a list["xxx","xxx"] # First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

529
run_multienv_o3.py Normal file
View File

@@ -0,0 +1,529 @@
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
from multiprocessing import current_process
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.o3_agent import O3Agent
# Global variables for signal handling
# Environments opened in the main process; closed by the signal handler on
# shutdown.
active_environments = []
# Child worker processes; terminated by the signal handler on shutdown.
processes = []
# Guard flag to avoid re-entrant handling when multiple signals arrive.
is_terminating = False
# import wandb
# load the environment variables from .env file
if os.path.exists(".env"):
    from dotenv import load_dotenv
    load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
    """Build the CLI parser and return the parsed arguments for this run."""
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )
    add = parser.add_argument
    # environment config
    add("--path_to_vm", type=str, default=None)
    add("--headless", action="store_true", help="Run in headless machine")
    add("--action_space", type=str, default="pyautogui", help="Action type")
    add(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    add("--sleep_after_execution", type=float, default=0.0)
    add("--max_steps", type=int, default=15)
    # agent config
    add("--test_config_base_dir", type=str, default="evaluation_examples")
    # lm config
    add("--model", type=str, default="o3")
    # example config
    add("--domain", type=str, default="all")
    add("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
    # logging related
    add("--result_dir", type=str, default="./results")
    add("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    add("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
        default='INFO', help="Set the logging level")
    # aws config
    add("--region", type=str, default="us-east-1", help="AWS region for the VM")
    add("--provider_name", type=str, default="aws",
        choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name")
    add("--client_password", type=str, default="", help="Client password")
    add("--screen_width", type=int, default=1920, help="Screen width")
    add("--screen_height", type=int, default=1080, help="Screen height")
    return parser.parse_args()
args = config() # Get command line arguments first
# Root logger collects everything; per-handler levels filter each stream.
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
# Ensure the log directory exists before attaching file handlers —
# logging.FileHandler raises FileNotFoundError on a fresh checkout otherwise.
os.makedirs("logs", exist_ok=True)
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
# Console output restricted to the "desktopenv" logger hierarchy.
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    """Flatten {domain: [example_id, ...]} into an ordered list of
    (domain, example_id) task tuples."""
    return [
        (domain, example_id)
        for domain, examples in test_all_meta.items()
        for example_id in examples
    ]
def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments."""
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
    # Recover the worker's `active_environments` list from the interrupted
    # frame's locals; fall back to an empty list if it isn't there.
    envs = frame.f_locals.get('active_environments', [])
    for env in envs:
        if env is None:
            continue
        try:
            logger.info(f"Process {env_idx + 1} closing environment...")
            env.close()
            logger.info(f"Process {env_idx + 1} environment closed successfully")
        except Exception as e:
            logger.error(f"Process {env_idx + 1} error closing environment: {e}")
    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    """Worker-process entry point: own one DesktopEnv and drain the task queue.

    Creates a single AWS-backed DesktopEnv and an O3Agent, then repeatedly
    pulls (domain, example_id) pairs from ``task_queue`` and runs them via
    ``lib_run_single.run_single_example``. Per-example scores are collected
    through the shared ``shared_scores`` list. The environment is closed in
    the ``finally`` block no matter how the loop exits.
    """
    # Kept in a list so signal handlers can locate and close the env via frame locals.
    active_environments = []
    env = None
    try:
        # Imported lazily so the module can be loaded on hosts without AWS deps.
        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
        REGION = args.region
        screen_size = (args.screen_width, args.screen_height)
        # Fall back to the 1920x1080 AMI if the requested resolution has no image.
        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=args.action_space,
            provider_name=args.provider_name,
            region=REGION,
            snapshot_name=ami_id,
            screen_size=screen_size,
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
            enable_proxy=True,
            client_password=args.client_password
        )
        active_environments.append(env)
        agent = O3Agent(
            max_steps=args.max_steps,
            client_password=args.client_password,
            action_space=args.action_space,
            observation_type=args.observation_type,
        )
        logger.info(f"Process {current_process().name} started.")
        while True:
            try:
                # The 5s timeout doubles as the "queue is drained, we're done" signal.
                item = task_queue.get(timeout=5)
            except Exception:
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                try:
                    lib_run_single.run_single_example(
                        agent,
                        env,
                        example,
                        args.max_steps,
                        example["instruction"],
                        args,
                        example_result_dir,
                        shared_scores,
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    # Best effort: stop any in-progress screen recording.
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    # Append the failure to the trajectory log so the run stays auditable.
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                # Task-level failures are logged and swallowed so the worker
                # keeps pulling remaining tasks from the queue.
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
    global is_terminating, active_environments, processes

    # Avoid duplicate handling
    if is_terminating:
        return
    is_terminating = True

    logger.info(f"Received signal {signum}. Gracefully shutting down...")

    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info(f"Closing environment...")
            env.close()
            logger.info(f"Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")

    # Send termination signal to all child processes first
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            logger.info(f"Sending termination signal to process {proc.name}...")
            proc.terminate()
        except Exception as e:
            logger.error(f"Error sending termination signal to process: {e}")

    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)

    # Forcefully terminate any processes that didn't exit
    for proc in processes:
        if not proc.is_alive():
            continue
        try:
            logger.info(f"Forcefully terminating process {proc.name}...")
            import signal as sig
            os.kill(proc.pid, sig.SIGKILL)
        except Exception as e:
            logger.error(f"Error forcefully terminating process: {e}")

    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Run all pending tasks across ``args.num_envs`` worker processes.

    Spawns one ``run_env_tasks`` worker per environment, monitors them in a
    polling loop, restarts any worker that dies while tasks remain, and
    finally logs the average of the scores collected in the shared list.

    Side effect: rebinds the module-level ``processes`` list so signal
    handlers can terminate the workers.
    """
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        # Pre-load every (domain, example_id) pair; workers drain the queue.
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"EnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started process {p.name} with PID {p.pid}")
        try:
            # Supervision loop: poll worker liveness every 5 seconds.
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        # NOTE(review): a worker that crashes deterministically
                        # will be restarted repeatedly while tasks remain.
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"EnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                # Empty queue means all tasks have been picked up (workers may
                # still be finishing their last example; join() below waits).
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        # Snapshot the manager-backed list before the Manager shuts down.
        scores = list(shared_scores)
        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Return the subset of total_file_json whose examples have no result.txt yet.

    Side effect: deletes the partial output files of any example directory
    that was started but never produced a result.txt, so it can be rerun
    from a clean state. Mutates and returns ``total_file_json``.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            if example_id == "onboard":
                continue
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            contents = os.listdir(example_path)
            if "result.txt" in contents:
                finished[domain].append(example_id)
            else:
                # Incomplete run: wipe its partial artifacts for a fresh retry.
                for leftover in contents:
                    os.remove(os.path.join(example_path, leftover))

    if not finished:
        return total_file_json

    for domain, done in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in done
            ]
    return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect the float score from every finished example's result.txt.

    Prints the current success rate and returns the list of scores, or
    returns None (printing a notice) when no results exist yet.
    ``total_file_json`` is accepted for signature compatibility but unused.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            result_file = os.path.join(example_path, "result.txt")
            if not os.path.exists(result_file):
                continue
            try:
                # `with` closes the handle even on parse failure; the original
                # `open(...).read()` leaked a file object per example and its
                # bare `except:` could swallow KeyboardInterrupt/SystemExit.
                with open(result_file, "r") as f:
                    all_result.append(float(f.read()))
            except (OSError, ValueError):
                # Unreadable or non-numeric result counts as a failed example.
                all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    return all_result
if __name__ == "__main__":
####### The complete version of the list of examples #######
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Register signal handlers for graceful termination
signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C
signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal
try:
args = config()
# save args to json in result_dir/action_space/observation_type/model/args.json
path_to_args = os.path.join(
args.result_dir,
args.action_space,
args.observation_type,
args.model,
"args.json",
)
os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
with open(path_to_args, "w", encoding="utf-8") as f:
json.dump(vars(args), f, indent=4)
with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
test_all_meta = json.load(f)
if args.domain != "all":
test_all_meta = {args.domain: test_all_meta[args.domain]}
test_file_list = get_unfinished(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
left_info = ""
for domain in test_file_list:
left_info += f"{domain}: {len(test_file_list[domain])}\n"
logger.info(f"Left tasks:\n{left_info}")
get_result(
args.action_space,
args.model,
args.observation_type,
args.result_dir,
test_all_meta,
)
test(args, test_file_list)
except KeyboardInterrupt:
logger.info("Main process received KeyboardInterrupt.")
# Signal handler will take care of cleanup
except Exception as e:
logger.error(f"Unexpected error in main process: {e}", exc_info=True)
# Also trigger cleanup for unhandled exceptions
signal_handler(signal.SIGTERM, None)
finally:
# Final cleanup in case any environments or processes remain
logger.info("Main process final cleanup...")
for env in active_environments:
if env is not None:
try:
logger.info(f"Closing environment in final cleanup...")
env.close()
logger.info(f"Environment closed successfully in final cleanup")
except Exception as e:
logger.error(f"Error during final environment cleanup: {e}")
# First try gentle termination
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Terminating process {p.name}...")
p.terminate()
except Exception as e:
logger.error(f"Error terminating process: {e}")
# Wait a moment for processes to terminate
time.sleep(1)
# Then force kill if needed
for p in processes:
if p is not None and p.is_alive():
try:
logger.info(f"Force killing process {p.name}...")
os.kill(p.pid, signal.SIGKILL)
logger.info(f"Process {p.name} force killed")
except Exception as e:
logger.error(f"Error force killing process: {e}")

View File

@@ -207,7 +207,6 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
max_tokens=args.max_tokens, max_tokens=args.max_tokens,
top_p=args.top_p, top_p=args.top_p,
temperature=args.temperature, temperature=args.temperature,
max_trajectory_length=args.max_trajectory_length, max_trajectory_length=args.max_trajectory_length,
max_image_history_length=args.max_image_history_length, max_image_history_length=args.max_image_history_length,
use_thinking=args.use_thinking, use_thinking=args.use_thinking,