diff --git a/mm_agents/o3_agent.py b/mm_agents/o3_agent.py
new file mode 100644
index 0000000..f50730e
--- /dev/null
+++ b/mm_agents/o3_agent.py
@@ -0,0 +1,261 @@
+import base64
+import logging
+import os
+import re
+from io import BytesIO
+from typing import Dict, List
+
+import backoff
+import openai
+import requests
+from PIL import Image
+from requests.exceptions import SSLError
+
+from mm_agents.prompts import O3_SYSTEM_PROMPT
+
+# Module-level logger; installed by O3Agent.reset() before first use.
+logger = None
+MAX_RETRY_TIMES = 10
+
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", None)  # "Your OpenAI API Key"
+
+
+def encode_image(image_content):
+    return base64.b64encode(image_content).decode("utf-8")
+
+
+class O3Agent:
+    def __init__(
+        self,
+        platform="ubuntu",
+        model="o3",
+        max_tokens=1500,
+        client_password="password",
+        action_space="pyautogui",
+        observation_type="screenshot",
+        max_steps=15,
+    ):
+        self.platform = platform
+        self.model = model
+        self.max_tokens = max_tokens
+        self.client_password = client_password
+        self.action_space = action_space
+        self.observation_type = observation_type
+        assert action_space in ["pyautogui"], "Invalid action space"
+        assert observation_type in ["screenshot"], "Invalid observation type"
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+        self.observation_captions = []
+        self.max_image_history_length = 5
+        self.current_step = 1
+        self.max_steps = max_steps
+
+    def predict(self, instruction: str, obs: Dict) -> tuple:
+        """
+        Predict the next action(s) based on the current observation.
+        """
+        user_prompt = (
+            "Please generate the next move according to the UI screenshot and instruction. "
+            "You may refer to the previous actions and observations for reflection.\n\n"
+            f"Instruction: {instruction}\n\n"
+        )
+
+        messages = [{
+            "role": "system",
+            "content": [{
+                "type": "text",
+                "text": O3_SYSTEM_PROMPT.format(
+                    current_step=self.current_step,
+                    max_steps=self.max_steps,
+                    CLIENT_PASSWORD=self.client_password
+                )
+            }]
+        }]
+
+        # Determine which observations to include images for (only the most recent ones)
+        obs_start_idx = max(0, len(self.observations) - self.max_image_history_length)
+
+        # Add all thought and action history
+        for i in range(len(self.thoughts)):
+            # For recent steps, include the actual screenshot
+            if i >= obs_start_idx:
+                messages.append({
+                    "role": "user",
+                    "content": [{
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/png;base64,{encode_image(self.observations[i]['screenshot'])}",
+                            "detail": "high"
+                        },
+                    }]
+                })
+            # For older steps, use the observation caption instead of the image
+            else:
+                messages.append({
+                    "role": "user",
+                    "content": [{
+                        "type": "text",
+                        "text": f"Observation: {self.observation_captions[i]}"
+                    }]
+                })
+
+            thought_messages = f"Thought:\n{self.thoughts[i]}"
+            action_messages = "Action:"
+            for action in self.actions[i]:
+                action_messages += f"\n{action}"
+            messages.append({
+                "role": "assistant",
+                "content": [{
+                    "type": "text",
+                    "text": thought_messages + "\n" + action_messages
+                }]
+            })
+
+        messages.append({
+            "role": "user",
+            "content": [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/png;base64,{encode_image(obs['screenshot'])}",
+                        "detail": "high"
+                    },
+                },
+                {
+                    "type": "text",
+                    "text": user_prompt
+                },
+            ],
+        })
+
+        response = self.call_llm(
+            {
+                "model": self.model,
+                "messages": messages,
+                "max_completion_tokens": self.max_tokens,
+            },
+            self.model,
+        )
+
+        logger.info(f"Output: {response}")
+        codes = self.parse_code_from_planner_response(response)
+        # Add retry logic if no codes were parsed
+        retry_count = 0
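+        # Bound the re-prompt loop so a persistently unparsable response
+        # cannot stall the run indefinitely.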
+        max_retries = MAX_RETRY_TIMES
+        while not codes and retry_count < max_retries:
+            logger.info(f"No codes parsed from planner response. Retrying ({retry_count+1}/{max_retries})...")
+            messages.append({
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": "You didn't generate valid actions. Please try again."}
+                ]
+            })
+            response = self.call_llm(
+                {
+                    "model": self.model,
+                    "messages": messages,
+                    "max_completion_tokens": self.max_tokens,
+                },
+                self.model,
+            )
+            logger.info(f"Retry Planner Output: {response}")
+            codes = self.parse_code_from_planner_response(response)
+            retry_count += 1
+
+        thought = self.parse_thought_from_planner_response(response)
+        observation_caption = self.parse_observation_caption_from_planner_response(response)
+        logger.info(f"Thought: {thought}")
+        logger.info(f"Observation Caption: {observation_caption}")
+        logger.info(f"Codes: {codes}")
+        self.actions.append(codes)
+        self.observations.append(obs)
+        self.thoughts.append(thought)
+        self.observation_captions.append(observation_caption)
+        self.current_step += 1
+        return response, codes
+
+    def parse_observation_caption_from_planner_response(self, input_string: str) -> str:
+        pattern = r"Observation:\n(.*?)\n"
+        matches = re.findall(pattern, input_string, re.DOTALL)
+        if matches:
+            return matches[0].strip()
+        return ""
+
+    def parse_thought_from_planner_response(self, input_string: str) -> str:
+        pattern = r"Thought:\n(.*?)\n"
+        matches = re.findall(pattern, input_string, re.DOTALL)
+        if matches:
+            return matches[0].strip()
+        return ""
+
+    def parse_code_from_planner_response(self, input_string: str) -> List[str]:
+        input_string = "\n".join([line.strip() for line in input_string.split(';') if line.strip()])
+        if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
+            return [input_string.strip()]
+
+        pattern = r"```(?:\w+\s+)?(.*?)```"
+        matches = re.findall(pattern, input_string, re.DOTALL)
+        codes = []
+
+        for match in matches:
+            match = match.strip()
+            commands = ['WAIT', 'DONE', 'FAIL']
+
+            if match in commands:
+                codes.append(match.strip())
+            elif match.split('\n')[-1] in commands:
+                if len(match.split('\n')) > 1:
+                    codes.append("\n".join(match.split('\n')[:-1]))
+                codes.append(match.split('\n')[-1])
+            else:
+                codes.append(match)
+
+        return codes
+
+    @backoff.on_exception(
+        backoff.constant,
+        # Add more model exceptions here as needed, but never add the bare
+        # "Exception": that kind of exception must be caught by the outer loop
+        # so that no single example can exceed its time limit.
+        (
+            # General exceptions
+            SSLError,
+            requests.HTTPError,
+            # OpenAI exceptions
+            openai.RateLimitError,
+            openai.BadRequestError,
+            openai.InternalServerError,
+            openai.APIConnectionError,
+            openai.APIError,
+        ),
+        interval=30,
+        max_tries=10,
+    )
+    def call_llm(self, payload, model):
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {OPENAI_API_KEY}"
+        }
+        logger.info("Generating content with GPT model: %s", model)
+        response = requests.post(
+            "https://api.openai.com/v1/chat/completions",
+            headers=headers,
+            json=payload,
+        )
+
+        if response.status_code != 200:
+            logger.error("Failed to call LLM: " + response.text)
+            # Raise HTTPError to trigger the backoff retry mechanism
+            response.raise_for_status()
+        else:
+            return response.json()["choices"][0]["message"]["content"]
+
+    def reset(self, _logger=None):
+        global logger
+        logger = (_logger if _logger is not None else
+                  logging.getLogger("desktopenv.o3_agent"))
+
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+        self.observation_captions = []
+        self.current_step = 1
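A minimal driver sketch (not part of the patch; the screenshot path and task
text are illustrative, and OPENAI_API_KEY must be exported) showing the
expected call order -- reset() must run before predict() so the module-level
logger is installed:

    import logging
    from mm_agents.o3_agent import O3Agent

    agent = O3Agent(model="o3", max_steps=15)
    agent.reset(logging.getLogger("desktopenv.o3_agent"))
    with open("screenshot.png", "rb") as f:
        obs = {"screenshot": f.read()}
    response, codes = agent.predict("Open the text editor", obs)
    print(codes)  # pyautogui snippets, or ["WAIT"] / ["DONE"] / ["FAIL"]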
diff --git a/run.py b/run.py
index 3730902..cce45bc 100644
--- a/run.py
+++ b/run.py
@@ -15,8 +15,7 @@ import lib_run_single
 from desktop_env.desktop_env import DesktopEnv
 from mm_agents.agent import PromptAgent
 
-# import wandb
-
+# Largely deprecated: this runner is single-env only; use run_multienv_*.py instead.
 # Logger Configs {{{ #
 logger = logging.getLogger()
diff --git a/run_multienv.py b/run_multienv.py
index 3b1d005..7ab33fe 100644
--- a/run_multienv.py
+++ b/run_multienv.py
@@ -1,66 +1,32 @@
-"""Script to run end-to-end evaluation on the benchmark.
-Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
-"""
-
+from __future__ import annotations
+
 import argparse
 import datetime
 import json
 import logging
 import os
 import sys
-from typing import List, Dict
-import math
-from tqdm import tqdm
+import signal
+import time
+from typing import List
 from multiprocessing import Process, Manager
+from multiprocessing import Queue, current_process
 
 import lib_run_single
 from desktop_env.desktop_env import DesktopEnv
 from mm_agents.agent import PromptAgent
 
+# Global variables for signal handling
+active_environments = []
+processes = []
+is_terminating = False
+
 # import wandb
 
+# Load the environment variables from a .env file, if present.
+if os.path.exists(".env"):
+    from dotenv import load_dotenv
+    load_dotenv()
 
 # Logger Configs {{{ #
-logger = logging.getLogger()
-logger.setLevel(logging.DEBUG)
-
-datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
-
-file_handler = logging.FileHandler(
-    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-debug_handler = logging.FileHandler(
-    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-stdout_handler = logging.StreamHandler(sys.stdout)
-sdebug_handler = logging.FileHandler(
-    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
-)
-
-file_handler.setLevel(logging.INFO)
-debug_handler.setLevel(logging.DEBUG)
-stdout_handler.setLevel(logging.INFO)
-sdebug_handler.setLevel(logging.DEBUG)
-
-formatter = logging.Formatter(
-    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
-)
-file_handler.setFormatter(formatter)
-debug_handler.setFormatter(formatter)
-stdout_handler.setFormatter(formatter)
-sdebug_handler.setFormatter(formatter)
-
-stdout_handler.addFilter(logging.Filter("desktopenv"))
-sdebug_handler.addFilter(logging.Filter("desktopenv"))
-
-logger.addHandler(file_handler)
-logger.addHandler(debug_handler)
-logger.addHandler(stdout_handler)
-logger.addHandler(sdebug_handler)
-# }}} Logger Configs #
-
-logger = logging.getLogger("desktopenv.experiment")
-
-
 def config() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
         description="Run end-to-end evaluation on the benchmark"
     )
@@ -77,11 +43,9 @@ def config() -> argparse.Namespace:
     parser.add_argument(
         "--observation_type",
         choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
-        default="a11y_tree",
+        default="screenshot",
         help="Observation type",
     )
-    parser.add_argument("--screen_width", type=int, default=1920)
-    parser.add_argument("--screen_height", type=int, default=1080)
     parser.add_argument("--sleep_after_execution", type=float, default=0.0)
     parser.add_argument("--max_steps", type=int, default=15)
@@ -97,7 +61,7 @@ def config() -> argparse.Namespace:
parser.add_argument("--top_p", type=float, default=0.9) parser.add_argument("--max_tokens", type=int, default=1500) parser.add_argument("--stop_token", type=str, default=None) - + # example config parser.add_argument("--domain", type=str, default="all") parser.add_argument( @@ -106,112 +70,117 @@ def config() -> argparse.Namespace: # logging related parser.add_argument("--result_dir", type=str, default="./results") - parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") - + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") # aws config parser.add_argument( "--region", type=str, default="us-east-1", help="AWS region for the VM" ) - + parser.add_argument( + "--provider_name", type=str, default="docker", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) args = parser.parse_args() return args +args = config() # Get command line arguments first -def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: - """Distribute tasks evenly across environments.""" - # Flatten the tasks into a single list +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: all_tasks = [] for domain, examples in test_all_meta.items(): for example_id in examples: all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. 
Shutting down...") - # Calculate tasks per environment - tasks_per_env = math.ceil(len(all_tasks) / num_envs) + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) - # Distribute tasks - distributed_tasks = [] - for i in range(num_envs): - env_tasks = {} - start_idx = i * tasks_per_env - end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) - - for domain, example_id in all_tasks[start_idx:end_idx]: - if domain not in env_tasks: - env_tasks[domain] = [] - env_tasks[domain].append(example_id) - - distributed_tasks.append(env_tasks) - - return distributed_tasks - - - -def run_env_tasks(env_idx: int, env: DesktopEnv, agent: PromptAgent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): - """Run tasks for a single environment.""" - logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") - - for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): - for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): - config_file = os.path.join( - args.test_config_base_dir, f"examples/{domain}/{example_id}.json" - ) - with open(config_file, "r", encoding="utf-8") as f: - example = json.load(f) - - logger.info(f"[Env {env_idx+1}][Domain]: {domain}") - logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") - logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") - - example_result_dir = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - domain, - example_id, - ) - os.makedirs(example_result_dir, exist_ok=True) - + # Close environment in the current process context + for env in active_environments: + if env is not None: try: - lib_run_single.run_single_example( - agent, - env, - example, - args.max_steps, - example["instruction"], - args, - example_result_dir, - shared_scores, - ) + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") except Exception as e: - logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") - ) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write( - json.dumps( - {"Error": f"Time limit exceeded in {domain}/{example_id}"} - ) - ) - f.write("\n") + logger.error(f"Process {env_idx + 1} error closing environment: {e}") - env.close() + logger.info(f"Process {env_idx + 1} shutdown complete. 
Exiting.") + sys.exit(0) -def test(args: argparse.Namespace, test_all_meta: dict) -> None: - logger.info("Args: %s", args) - - distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) - - # First, set up all environments - logger.info("Setting up all environments...") - envs = [] - agents = [] - - for env_idx in range(args.num_envs): - logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") - +def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) agent = PromptAgent( model=args.model, max_tokens=args.max_tokens, @@ -220,48 +189,188 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None: action_space=args.action_space, observation_type=args.observation_type, max_trajectory_length=args.max_trajectory_length, + client_password=args.client_password ) - agents.append(agent) - env = DesktopEnv( - path_to_vm=args.path_to_vm, - action_space=agent.action_space, + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed 
successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") - provider_name="aws", - region="us-east-1", - snapshot_name="ami-05e7d7bd279ea4f14", - screen_size=(args.screen_width, args.screen_height), - headless=args.headless, - os_type="Ubuntu", - require_a11y_tree=args.observation_type - in ["a11y_tree", "screenshot_a11y_tree", "som"], - ) - envs.append(env) +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes - logger.info("All environments are ready. Starting parallel task execution...") + # Avoid duplicate handling + if is_terminating: + return - # Create a shared list for scores across processes + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. 
Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") with Manager() as manager: shared_scores = manager.list() - - # Create and start processes for each environment + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs processes = [] - for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): + for i in range(num_envs): p = Process( target=run_env_tasks, - args=(env_idx, env, agent, env_tasks, args, shared_scores) + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" ) - processes.append(p) + p.daemon = True p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - # Convert shared list to regular list + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise scores = list(shared_scores) - logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") @@ -340,47 +449,89 @@ def get_result(action_space, use_model, observation_type, result_dir, total_file if __name__ == "__main__": ####### The complete version of the list of examples ####### os.environ["TOKENIZERS_PARALLELISM"] = "false" - - args = config() - # save args to json in result_dir/action_space/observation_type/model/args.json - path_to_args = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - "args.json", - ) - os.makedirs(os.path.dirname(path_to_args), exist_ok=True) - with open(path_to_args, "w", encoding="utf-8") as f: - json.dump(vars(args), f, indent=4) + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) - with open(args.test_all_meta_path, "r", encoding="utf-8") as f: - test_all_meta = json.load(f) + with 
diff --git a/run_multienv_o3.py b/run_multienv_o3.py
new file mode 100644
index 0000000..cd14da1
--- /dev/null
+++ b/run_multienv_o3.py
@@ -0,0 +1,529 @@
+from __future__ import annotations
+
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+import signal
+import time
+from typing import List
+from multiprocessing import Process, Manager, Queue
+from multiprocessing import current_process
+
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
+from mm_agents.o3_agent import O3Agent
+
+# Global variables for signal handling
+active_environments = []
+processes = []
+is_terminating = False
+
+# import wandb
+
+# Load the environment variables from a .env file, if present.
+if os.path.exists(".env"):
+    from dotenv import load_dotenv
+    load_dotenv()
+
+# Logger Configs {{{ #
+def config() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run end-to-end evaluation on the benchmark"
+    )
+
+    # environment config
parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--sleep_after_execution", type=float, default=0.0) + parser.add_argument("--max_steps", type=int, default=15) + + # agent config + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config + parser.add_argument("--model", type=str, default="o3") + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) + args = parser.parse_args() + return args + +args = config() # Get command line arguments first + +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. 
Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") + sys.exit(0) + + +def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) + agent = O3Agent( + max_steps=args.max_steps, + client_password=args.client_password, + action_space=args.action_space, + observation_type=args.observation_type, + ) + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + 
logger.error(f"{current_process().name} error during environment cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes + + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") + with Manager() as manager: + shared_scores = manager.list() + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs + processes = [] + for i in range(num_envs): + p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" + ) + p.daemon = True + p.start() + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. 
Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise + scores = list(shared_scores) + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + 
+if __name__ == "__main__":
+    ####### The complete version of the list of examples #######
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    # Register signal handlers for graceful termination
+    signal.signal(signal.SIGINT, signal_handler)   # Handle Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal
+
+    try:
+        args = config()
+
+        # save args to json in result_dir/action_space/observation_type/model/args.json
+        path_to_args = os.path.join(
+            args.result_dir,
+            args.action_space,
+            args.observation_type,
+            args.model,
+            "args.json",
+        )
+        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+        with open(path_to_args, "w", encoding="utf-8") as f:
+            json.dump(vars(args), f, indent=4)
+
+        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+            test_all_meta = json.load(f)
+
+        if args.domain != "all":
+            test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+        test_file_list = get_unfinished(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        left_info = ""
+        for domain in test_file_list:
+            left_info += f"{domain}: {len(test_file_list[domain])}\n"
+        logger.info(f"Left tasks:\n{left_info}")
+
+        get_result(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        test(args, test_file_list)
+    except KeyboardInterrupt:
+        logger.info("Main process received KeyboardInterrupt.")
+        # Signal handler will take care of cleanup
+    except Exception as e:
+        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
+        # Also trigger cleanup for unhandled exceptions
+        signal_handler(signal.SIGTERM, None)
+    finally:
+        # Final cleanup in case any environments or processes remain
+        logger.info("Main process final cleanup...")
+        for env in active_environments:
+            if env is not None:
+                try:
+                    logger.info("Closing environment in final cleanup...")
+                    env.close()
+                    logger.info("Environment closed successfully in final cleanup")
+                except Exception as e:
+                    logger.error(f"Error during final environment cleanup: {e}")
+
+        # First try gentle termination
+        for p in processes:
+            if p is not None and p.is_alive():
+                try:
+                    logger.info(f"Terminating process {p.name}...")
+                    p.terminate()
+                except Exception as e:
+                    logger.error(f"Error terminating process: {e}")
+
+        # Wait a moment for processes to terminate
+        time.sleep(1)
+
+        # Then force kill if needed
+        for p in processes:
+            if p is not None and p.is_alive():
+                try:
+                    logger.info(f"Force killing process {p.name}...")
+                    os.kill(p.pid, signal.SIGKILL)
+                    logger.info(f"Process {p.name} force killed")
+                except Exception as e:
+                    logger.error(f"Error force killing process: {e}")
diff --git a/run_multienv_uitars15.py b/run_multienv_uitars15.py
index c041b3b..5d25fd8 100644
--- a/run_multienv_uitars15.py
+++ b/run_multienv_uitars15.py
@@ -207,7 +207,6 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li
             max_tokens=args.max_tokens,
             top_p=args.top_p,
             temperature=args.temperature,
-            max_trajectory_length=args.max_trajectory_length,
             max_image_history_length=args.max_image_history_length,
             use_thinking=args.use_thinking,