From 55372c4432d645d1ead672f0364a06722c49164c Mon Sep 17 00:00:00 2001 From: Dunjie Lu <127488745+ludunjie1219@users.noreply.github.com> Date: Tue, 14 Oct 2025 12:57:00 +0800 Subject: [PATCH 1/3] Fix API base URLs for OpenAI and DashScope Updated the base URLs for OpenAI and DashScope API calls. --- mm_agents/qwen3vl_agent.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mm_agents/qwen3vl_agent.py b/mm_agents/qwen3vl_agent.py index e0ee85f..86483b3 100644 --- a/mm_agents/qwen3vl_agent.py +++ b/mm_agents/qwen3vl_agent.py @@ -628,7 +628,7 @@ Previous actions: def _call_llm_openai(self, messages, model): """Call LLM using OpenAI SDK (compatible with OpenAI-compatible endpoints).""" - base_url = "https://poc-dashscope.aliyuncs.com/compatible-mode/v1" + base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" api_key = "sk-123" client = openai.OpenAI(base_url=base_url, api_key=api_key) @@ -653,7 +653,7 @@ Previous actions: def _call_llm_dashscope(self, messages, model): """Call LLM using DashScope SDK.""" - dashscope.base_http_api_url = "https://poc-dashscope.aliyuncs.com/api/v1" + dashscope.base_http_api_url = "https://dashscope.aliyuncs.com/api/v1" dashscope.api_key = "sk-123" # Convert message schema From afd29115da0c583fb73cb1126775184e9c4ee9c5 Mon Sep 17 00:00:00 2001 From: "ludunjie.ldj" Date: Thu, 16 Oct 2025 16:20:54 +0800 Subject: [PATCH 2/3] support aliyun eval of qwen3vl --- mm_agents/qwen3vl_agent.py | 14 +++++++------- run_multienv_qwen3vl.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mm_agents/qwen3vl_agent.py b/mm_agents/qwen3vl_agent.py index 86483b3..900892f 100644 --- a/mm_agents/qwen3vl_agent.py +++ b/mm_agents/qwen3vl_agent.py @@ -61,7 +61,7 @@ class Qwen3VLAgent: self, platform: str = "ubuntu", model: str = "qwen3-vl", - max_tokens: int = 40960, + max_tokens: int = 32768, top_p: float = 0.9, temperature: float = 0.0, action_space: str = "pyautogui", @@ -70,7 +70,7 @@ class Qwen3VLAgent: add_thought_prefix: bool = False, coordinate_type: str = "relative", api_backend: str = "dashscope", # "openai" or "dashscope" - enable_thinking: bool = True, # Enable thinking mode for DashScope + enable_thinking: bool = False, # Enable thinking mode for DashScope thinking_budget: int = 32768, # Token budget for reasoning ): self.platform = platform @@ -628,8 +628,8 @@ Previous actions: def _call_llm_openai(self, messages, model): """Call LLM using OpenAI SDK (compatible with OpenAI-compatible endpoints).""" - base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" - api_key = "sk-123" + base_url = os.environ.get("OPENAI_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1") + api_key = os.environ.get("OPENAI_API_KEY", "sk-123") client = openai.OpenAI(base_url=base_url, api_key=api_key) for attempt in range(1, MAX_RETRY_TIMES + 1): @@ -653,8 +653,8 @@ Previous actions: def _call_llm_dashscope(self, messages, model): """Call LLM using DashScope SDK.""" - dashscope.base_http_api_url = "https://dashscope.aliyuncs.com/api/v1" - dashscope.api_key = "sk-123" + dashscope.base_http_api_url = os.environ.get("DASHSCOPE_BASE_URL", "https://dashscope.aliyuncs.com/api/v1") + dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY", "sk-123") # Convert message schema ds_messages = self._to_dashscope_messages(messages) @@ -669,7 +669,7 @@ Previous actions: call_params = { "model": model, "messages": ds_messages, - "max_tokens": min(self.max_tokens, 2048), + "max_tokens": self.max_tokens, # "temperature": self.temperature, # "top_p": self.top_p, "vl_high_resolution_images": True, diff --git a/run_multienv_qwen3vl.py b/run_multienv_qwen3vl.py index 2d9c0f6..d9e0c26 100644 --- a/run_multienv_qwen3vl.py +++ b/run_multienv_qwen3vl.py @@ -57,7 +57,7 @@ def config() -> argparse.Namespace: parser.add_argument("--model", type=str, default="qwen3-vl") parser.add_argument("--temperature", type=float, default=0) parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--max_tokens", type=int, default=40960) + parser.add_argument("--max_tokens", type=int, default=32768) parser.add_argument("--stop_token", type=str, default=None) parser.add_argument( "--coord", @@ -99,7 +99,7 @@ def config() -> argparse.Namespace: "--provider_name", type=str, default="docker", - choices=["aws", "virtualbox", "vmware", "docker", "azure"], + choices=["aws", "virtualbox", "vmware", "docker", "azure", "aliyun"], help="Provider name", ) parser.add_argument( From 9f97535ef99337da2518b903ec621f7e11e657b3 Mon Sep 17 00:00:00 2001 From: Atharva Gundawar <54273198+Atharva-Gundawar@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:29:15 -0700 Subject: [PATCH 3/3] oswrold agent wrapper for trained v7 (#360) --- lib_run_single.py | 61 +++++ mm_agents/agi_agent.py | 219 +++++++++++++++++ run_multienv_agi.py | 542 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 822 insertions(+) create mode 100644 mm_agents/agi_agent.py create mode 100644 run_multienv_agi.py diff --git a/lib_run_single.py b/lib_run_single.py index 751f676..809ab31 100644 --- a/lib_run_single.py +++ b/lib_run_single.py @@ -97,6 +97,67 @@ def run_single_example_human(env, example, max_steps, instruction, args, example +def run_single_example_agi(agent, env, example, max_steps, instruction, args, example_result_dir, scores): + runtime_logger = setup_logger(example, example_result_dir) + agent.reset(runtime_logger) + env.reset(task_config=example) + time.sleep(60) # Wait for the environment to be ready + obs = env._get_obs() # Get the initial observation + done = False + step_idx = 0 + env.controller.start_recording() + while not done and step_idx < max_steps: + response, actions = agent.predict( + instruction, + obs + ) + + done = not response.get('state_correct', False) + + for action in actions: + # Capture the timestamp before executing the action + action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + logger.info("Step %d: %s", step_idx + 1, action) + obs, reward, done, info, step_info = agent.step(action) + + if not done: + if not response.get('state_correct', False): + done = True + + logger.info("Reward: %.2f", reward) + logger.info("Done: %s", done) + # Save screenshot and trajectory information + with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), + "wb") as _f: + _f.write(obs['screenshot']) + + # Remove pending checks if they exist which will cause issues with json serialization + if action.get('pending_checks', None): + del action['pending_checks'] + + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write(json.dumps({ + "step_num": step_idx + 1, + "action_timestamp": action_timestamp, + "action": action, + "reward": reward, + "done": done, + "info": info, + "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png" + })) + f.write("\n") + if done: + logger.info("The episode is done.") + break + step_idx += 1 + result = env.evaluate() + logger.info("Result: %.2f", result) + scores.append(result) + with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: + f.write(f"{result}\n") + env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) + + def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores): runtime_logger = setup_logger(example, example_result_dir) agent.reset(runtime_logger) diff --git a/mm_agents/agi_agent.py b/mm_agents/agi_agent.py new file mode 100644 index 0000000..3f88cf1 --- /dev/null +++ b/mm_agents/agi_agent.py @@ -0,0 +1,219 @@ +import base64 +import logging +import time +from typing import Dict, List, Tuple, Any, Optional + +import httpx + +logger = logging.getLogger("desktopenv.agent") + + +class Timer: + """Context manager for timing code blocks.""" + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.duration = time.time() - self.start + + +class AGIAgent: + """Agent that communicates with your private AGI server for decision-making.""" + + def __init__( + self, + env, + server_url: str = "https://your-private-agi-endpoint", # Contact the authors for access to a private deployment endpoint. + platform: str = "ubuntu", + action_space: str = "pyautogui", + observation_type: str = "screenshot", + max_trajectory_length: int = 100, + client_password: str = "", + provider_name: str = "aws", + screen_width: int = 1920, + screen_height: int = 1080, + timeout: int = 1800, + ): + """Initialize the AGI client. + + Args: + env: The desktop environment + server_url: URL of your private AGI server + """ + self.env = env + self.server_url = server_url.rstrip("/") + self.platform = platform + self.action_space = action_space + self.observation_type = observation_type + self.max_trajectory_length = max_trajectory_length + self.client_password = client_password + self.provider_name = provider_name + self.screen_width = screen_width + self.screen_height = screen_height + + # Session management + self.session_id: Optional[str] = None + self.instruction: Optional[str] = None + + # HTTP client + self.client = httpx.Client(timeout=timeout) + + # Tracking + self.thoughts = [] + self.actions = [] + self.observations = [] + + logger.info(f"Initialized AGIAgent with server URL: {self.server_url}") + + def reset(self, runtime_logger=None): + """Reset the agent and create a new session on the server. + + Args: + runtime_logger: Optional logger for runtime information + """ + global logger + logger = runtime_logger if runtime_logger is not None else logging.getLogger("desktopenv.agent") + + # Clear local state + self.thoughts = [] + self.actions = [] + self.observations = [] + self.session_id = None + + logger.info("AGIAgent reset complete") + + def _create_session(self, instruction: str) -> str: + """Create a new session on the server. + + Args: + instruction: The task instruction + + Returns: + The session ID + + Equivalent curl request: + curl -X POST {server_url}/sessions \ + -H "Content-Type: application/json" \ + -d '{"task_description": "{instruction}"}' + """ + try: + # print(f"Creating session with instruction: {instruction}") + # print(f"Server URL: {self.server_url}") + response = self.client.post( + f"{self.server_url}/sessions", + json={"task_description": instruction} + ) + response.raise_for_status() + session_id = response.json()["session_id"] + logger.info(f"Created session: {session_id}") + return session_id + except Exception as e: + logger.error(f"Failed to create session: {e}") + raise + + def predict(self, instruction: str, obs: Dict) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]: + """Predict the next action based on the current observation. + + Args: + instruction: The task instruction + obs: Observation dictionary containing 'screenshot' key with image bytes + + Returns: + Tuple of (predict_info dict, list of action dicts) + """ + # Create session on first prediction + if self.session_id is None: + self.instruction = instruction + self.session_id = self._create_session(instruction) + + # input("Session created, press Enter to continue") + + # Encode screenshot to base64 + screenshot_bytes = obs["screenshot"] + screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8") + + # Call the server + with Timer() as model_timer: + try: + response = self.client.post( + f"{self.server_url}/sessions/{self.session_id}/step", + json={ + "screenshot_base64_png": screenshot_b64, + "error": None # Could be populated from previous step errors + } + ) + response.raise_for_status() + result = response.json() + parsed_action = result["parsed_response"] + + logger.info(f"Server returned action: {parsed_action[:100]}...") + + except Exception as e: + logger.error(f"Error calling server: {e}") + raise + + # Format response as expected by lib_run_single + actions = [{ + "action_space": "pyautogui", + "action": parsed_action, + "pending_checks": [], + "call_id": "" + }] + + # Check if task is complete or failed + state_correct = parsed_action not in ["FAIL", "DONE"] + + predict_info = { + "model_usage": { + "model_time": model_timer.duration, + "prompt_tokens": 0, # Server doesn't expose these + "completion_tokens": 0, + }, + "messages": [], # Server manages conversation history + "response": parsed_action, + "state_correct": state_correct, + } + + return predict_info, actions + + def step(self, action: Dict[str, Any]) -> Tuple[Dict, float, bool, Dict, Dict]: + """Execute an action in the environment. + + Args: + action: Action dictionary with 'action' key containing PyAutoGUI command + + Returns: + Tuple of (observation, reward, done, info, step_info) + """ + try: + if not action: + logger.warning("Empty action received, terminating episode") + # Get observation without executing action + obs = self.env._get_obs() + return obs, 0.0, True, {}, {"step_time": 0.0, "action": action} + + action_str = action.get("action", "") + logger.info(f"Executing action: {action_str[:100]}...") + + with Timer() as step_timer: + # Execute the action directly (it's already a PyAutoGUI command string) + obs, reward, terminated, info = self.env.step(action_str) + + logger.debug(f"Action completed in {step_timer.duration:.2f}s") + if terminated: + logger.info("Environment signaled termination") + + return obs, reward, terminated, info, { + "step_time": step_timer.duration, + "action": action + } + + except Exception as e: + logger.exception(f"Environment step failed: {str(e)}") + raise + + def close(self): + """Close the HTTP client.""" + self.client.close() diff --git a/run_multienv_agi.py b/run_multienv_agi.py new file mode 100644 index 0000000..0888982 --- /dev/null +++ b/run_multienv_agi.py @@ -0,0 +1,542 @@ +from __future__ import annotations +import argparse +import datetime +import json +import logging +import os +import sys +import signal +import time +from typing import List, Dict +import math +from tqdm import tqdm +from multiprocessing import Process, Manager +from multiprocessing import current_process +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.agi_agent import AGIAgent + +# Global variables for signal handling +active_environments = [] +processes = [] +is_terminating = False + +# import wandb + +# load the environment variables from .env file +if os.path.exists(".env"): + from dotenv import load_dotenv + load_dotenv() + +# Logger Configs {{{ # +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--sleep_after_execution", type=float, default=0.0) + parser.add_argument("--max_steps", type=int, default=15) + + # agent config + parser.add_argument("--max_trajectory_length", type=int, default=3) + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # example config + parser.add_argument("--domain", type=str, nargs='+', default=["all"], + help="Domain(s) to run. Use 'all' for all domains, or specify one or more domain names") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="osworld-public-evaluation", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) + args = parser.parse_args() + return args + +args = config() # Get command line arguments first + +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") + sys.exit(0) + + +def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) + agent = AGIAgent( + env=env, + # Contact the authors for access to a private deployment endpoint. + server_url="https://your-private-agi-endpoint", + action_space=args.action_space, + observation_type=args.observation_type, + max_trajectory_length=args.max_trajectory_length, + client_password=args.client_password, + provider_name=args.provider_name, + screen_width=args.screen_width, + screen_height=args.screen_height + ) + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + "agi-0", + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + lib_run_single.run_single_example_agi( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes + + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") + with Manager() as manager: + shared_scores = manager.list() + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs + processes = [] + for i in range(num_envs): + p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" + ) + p.daemon = True + p.start() + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise + scores = list(shared_scores) + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + "agi-0", + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + # Handle multiple domains + if "all" not in args.domain: + # Filter test_all_meta to only include specified domains + filtered_meta = {} + for domain in args.domain: + if domain in test_all_meta: + filtered_meta[domain] = test_all_meta[domain] + else: + logger.warning(f"Domain '{domain}' not found in test_all_meta") + test_all_meta = filtered_meta + + test_file_list = get_unfinished( + args.action_space, + "agi-0", + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + "agi-0", + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}")