diff --git a/lib_run_single.py b/lib_run_single.py
index edc1900..f66945a 100644
--- a/lib_run_single.py
+++ b/lib_run_single.py
@@ -62,3 +62,58 @@ def setup_logger(example, example_result_dir):
     runtime_logger.setLevel(logging.DEBUG)
     runtime_logger.addHandler(logging.FileHandler(os.path.join(example_result_dir, "runtime.log")))
     return runtime_logger
+
+def run_single_example_openaicua(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
+    runtime_logger = setup_logger(example, example_result_dir)
+    agent.reset(runtime_logger)
+    env.reset(task_config=example)
+    time.sleep(60)  # Wait for the environment to be ready
+    obs = env._get_obs()  # Get the initial observation
+    done = False
+    step_idx = 0
+    env.controller.start_recording()
+    while not done and step_idx < max_steps:
+        response, actions = agent.predict(
+            instruction,
+            obs
+        )
+
+        done = not response.get('state_correct', False)
+
+        for action in actions:
+            # Capture the timestamp before executing the action
+            action_timestamp = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+            logger.info("Step %d: %s", step_idx + 1, action)
+            obs, reward, done, info, step_info = agent.step(action)
+
+            if not done:
+                if not response.get('state_correct', False):
+                    done = True
+
+            logger.info("Reward: %.2f", reward)
+            logger.info("Done: %s", done)
+            # Save screenshot and trajectory information
+            with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
+                      "wb") as _f:
+                _f.write(obs['screenshot'])
+            with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                f.write(json.dumps({
+                    "step_num": step_idx + 1,
+                    "action_timestamp": action_timestamp,
+                    "action": action,
+                    "reward": reward,
+                    "done": done,
+                    "info": info,
+                    "screenshot_file": f"step_{step_idx + 1}_{action_timestamp}.png"
+                }))
+                f.write("\n")
+            if done:
+                logger.info("The episode is done.")
+                break
+        step_idx += 1
+    result = env.evaluate()
+    logger.info("Result: %.2f", result)
+    scores.append(result)
+    with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
+        f.write(f"{result}\n")
+    env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
\ No newline at end of file
diff --git a/mm_agents/openai_cua_agent.py b/mm_agents/openai_cua_agent.py
new file mode 100644
index 0000000..24670bc
--- /dev/null
+++ b/mm_agents/openai_cua_agent.py
@@ -0,0 +1,759 @@
+import base64
+import json
+import logging
+import os
+import re
+import tempfile
+import time
+import xml.etree.ElementTree as ET
+from http import HTTPStatus
+from io import BytesIO
+from typing import Dict, List
+
+import backoff
+import dashscope
+import google.generativeai as genai
+import openai
+import requests
+import tiktoken
+from PIL import Image
+from google.api_core.exceptions import InvalidArgument, ResourceExhausted, InternalServerError, BadRequest
+from groq import Groq
+from requests.exceptions import SSLError
+from typing import Any, Optional, Union, Tuple
+
+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
+from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
+    SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
+    SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
+    SYS_PROMPT_IN_SOM_OUT_TAG
+
+logger = logging.getLogger("desktopenv.agent")
+
+pure_text_settings = ['a11y_tree']
+
+attributes_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/attributes"
+attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
= "https://accessibility.windows.example.org/ns/attributes" +state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state" +state_ns_windows = "https://accessibility.windows.example.org/ns/state" +component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component" +component_ns_windows = "https://accessibility.windows.example.org/ns/component" +value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value" +value_ns_windows = "https://accessibility.windows.example.org/ns/value" +class_ns_windows = "https://accessibility.windows.example.org/ns/class" +# More namespaces defined in OSWorld, please check desktop_env/server/main.py +import ast +from typing import Dict, Any, Optional, Union + +class Action: + """Action class for the agent.""" + def __init__(self, raw_action: Union[Dict, str], action_space: str): + """Initialize the Action class. + + Args: + raw_action: The raw action + action_space: The action space + """ + self._action_space = None + self._action = None + self.action_space = action_space + self.action = raw_action + + @property + def action(self) -> str: + return self._action + + @property + def action_space(self) -> str: + return self._action_space + + @action_space.setter + def action_space(self, value: str): + """ + Set the action space for the agent. + Currently only supports 'pyautogui' as a valid action space. + + Args: + value (str): The action space to set + + Raises: + ValueError: If action_space is empty or invalid + """ + if not value: + raise ValueError("action_space is required") + if value not in ["pyautogui", "claude_computer_use"]: + raise ValueError( + "Invalid action space. Allowed spaces are: pyautogui") + self._action_space = value + + + + @action.setter + def action(self, value: Optional[str]): + """ + Set the action for the agent. + For pyautogui action space, accepts special commands (WAIT, FAIL, DONE) or valid Python code. + For claude_computer_use action space, accepts a dict with keys "name", "input" and "id". + + Args: + value (str | dict): The action to set + + Raises: + ValueError: If action is empty or invalid + """ + if not value: + raise ValueError("action cannot be empty") + + if self._action_space == "pyautogui": + self._action = value + # if value in ["WAIT", "FAIL", "DONE"]: + # self._action = value + # elif self._is_valid_python_code(value): + # self._action = value + # else: + # raise ValueError("Invalid action format for pyautogui") + elif self._action_space == "claude_computer_use": + self._action = value + # if self._is_valid_claude_computer_use_action(value): + # self._action = value + else: + raise ValueError( + f"Invalid action space: {self._action_space}, allowed spaces are: pyautogui, claude_computer_use") + + def __str__(self) -> str: + """Return a string representation of the Action instance. + + Returns: + str: A string showing the action space and action value + """ + return f"Action(action_space='{self._action_space}', action='{self._action}')" + + def get_action(self) -> Optional[str]: + """Get the action. + + Returns: + str: The action + """ + return self._action + + def to_dict(self) -> Dict[str, Any]: + """Convert the action to a dictionary. + + Returns: + dict: The action as a dictionary + """ + return {"action_space": self._action_space, "action": self._action} + + def _is_valid_python_code(self, code: str) -> bool: + """ + Validate if the given string is valid Python code syntax. 
+
+        Args:
+            code (str): The code string to validate
+
+        Returns:
+            bool: True if code is valid Python syntax, False otherwise
+        """
+        try:
+            ast.parse(code)
+            return True
+        except SyntaxError:
+            return False
+
+    def _is_valid_claude_computer_use_action(self, action: Dict[str, Any]) -> bool:
+        """Validate if the given action is valid for the claude_computer_use action space.
+
+        Args:
+            action: The action to validate
+
+        Returns:
+            bool: True if action is a dict with 'name', 'input' and 'id', False otherwise
+        """
+        if not isinstance(action, dict):
+            return False
+        if not (action.get("name") and action.get("input") and action.get("id")):
+            return False
+        return True
+
+class Timer:
+    """Context manager for timing code blocks."""
+
+    def __enter__(self):
+        self.start = time.time()
+        return self
+
+    def __exit__(self, *args):
+        self.duration = time.time() - self.start
+
+# Function to encode the image
+def encode_image(image_content):
+    return base64.b64encode(image_content).decode('utf-8')
+
+
+def encoded_img_to_pil_img(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    return image
+
+
+def save_to_tmp_img_file(data_str):
+    base64_str = data_str.replace("data:image/png;base64,", "")
+    image_data = base64.b64decode(base64_str)
+    image = Image.open(BytesIO(image_data))
+
+    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
+    image.save(tmp_img_path)
+
+    return tmp_img_path
+
+
+class OpenAICUAAgent:
+    def __init__(
+            self,
+            env,
+            platform="ubuntu",
+            model="computer-use-preview",
+            max_tokens=1500,
+            top_p=0.9,
+            temperature=0.5,
+            action_space="pyautogui",
+            observation_type="screenshot_a11y_tree",
+            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
+            max_trajectory_length=100,
+            a11y_tree_max_tokens=10000
+    ):
+        self.env = env
+        self.platform = platform
+        self.model = model
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.temperature = temperature
+        self.action_space = action_space
+        self.observation_type = observation_type
+        self.max_trajectory_length = max_trajectory_length
+        self.a11y_tree_max_tokens = a11y_tree_max_tokens
+        self.cua_messages: List[Dict] = []
+
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+
+        self.tools = [{
+            "type": "computer_use_preview",
+            "display_width": 1920,
+            "display_height": 1080,
+            "environment": "linux" if platform == "ubuntu" else "windows"
+        }]
+
+        if observation_type == "screenshot":
+            if action_space == "computer_13":
+                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
+            elif action_space == "pyautogui":
+                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
+            else:
+                raise ValueError("Invalid action space: " + action_space)
+        elif observation_type == "a11y_tree":
+            if action_space == "computer_13":
+                self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
+            elif action_space == "pyautogui":
+                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
+            else:
+                raise ValueError("Invalid action space: " + action_space)
+        elif observation_type == "screenshot_a11y_tree":
+            if action_space == "computer_13":
+                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
+            elif action_space == "pyautogui":
+                self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
+            else:
+                raise ValueError("Invalid action space: " + action_space)
+        elif observation_type == "som":
+            if action_space == "computer_13":
+                raise ValueError("Invalid action space: " + action_space)
+            elif action_space == "pyautogui":
+                self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
+            else:
+                raise ValueError("Invalid action space: " + action_space)
+        else:
+            raise ValueError("Invalid experiment type: " + observation_type)
+
+    def _create_response(self, **kwargs: Any) -> Dict[str, Any]:
+        """Create a response from the OpenAI API.
+
+        Args:
+            **kwargs: Additional arguments to pass to the API
+
+        Returns:
+            The API response as a dictionary
+
+        Raises:
+            Exception: If the API call still fails after 3 retries
+        """
+        retry_count = 0
+        while retry_count < 3:
+            try:
+                from openai import OpenAI
+                client = OpenAI(api_key=os.getenv("OPENAI_API_KEY_CUA"))
+                response = client.responses.create(
+                    model=self.model,
+                    input=self.cua_messages,
+                    tools=self.tools,
+                    reasoning={
+                        "generate_summary": "concise",
+                    },
+                    truncation="auto",
+                )
+                logger.debug("Received successful response from OpenAI API")
+                logger.info(f"Response: {response}")
+                return response
+            except Exception as e:
+                logger.error(f"OpenAI API error: {str(e)}")
+                # Refresh the screenshot on the last computer_call_output before retrying,
+                # so the retried request carries an up-to-date observation.
+                new_screenshot = self.env._get_obs()
+                new_screenshot_base64 = base64.b64encode(new_screenshot["screenshot"]).decode('utf-8')
+                last_message = self.cua_messages[-1]
+                if isinstance(last_message, dict) and last_message.get("type") == "computer_call_output":
+                    last_message["output"]["image_url"] = f"data:image/png;base64,{new_screenshot_base64}"
+                retry_count += 1
+                time.sleep(1)
+        raise Exception("Failed to make OpenAI API call after 3 retries")
+
+    def _handle_item(self, item: Dict[str, Any]) -> Optional[Union[str, Dict[str, Any]]]:
+        """Parse a response item from the OpenAI API.
+
+        Args:
+            item: The response item to parse
+
+        Returns:
+            The parsed item as either a string message or a dictionary containing action information,
+            or None if the item couldn't be parsed
+        """
+        if item.type == "message":
+            if item.content is not None:
+                response = item.content[0] if isinstance(item.content, list) else item.content
+                response_type = response.type
+                response_text = response.text
+                logger.info(f"Received response text: {response_type} - {response_text}")
+                if response_type == "output_text":
+                    return response_text
+                return None
+            return None
+
+        if item.type == "function_call":
+            return None
+
+        if item.type == "reasoning":
+            reasoning = item.summary
+            if isinstance(reasoning, list):
+                reasoning_item = reasoning[0]
+                reasoning_text = reasoning_item.text
+                reasoning_type = reasoning_item.type
+                if reasoning_type == "summary_text":
+                    return reasoning_text
+                return None
+            return None
+
+        if item.type == "computer_call":
+            action = item.action
+            action_type = action.type
+            # Convert object attributes to dictionary
+            action_args = {}
+            for attr in dir(action):
+                if attr.startswith('_') or attr == 'type':
+                    continue
+                try:
+                    action_args[attr] = getattr(action, attr)
+                except AttributeError:
+                    pass
+            logger.warning(f"Original Action: {action}")
+            result_code = self._convert_cua_action_to_pyautogui_action(action_type, action_args)
+            if result_code:
+                return {
+                    "action_space": "pyautogui",
+                    "action": result_code,
+                    "pending_checks": item.pending_safety_checks,
+                    "call_id": item.call_id
+                }
+            return None
+
+    def _convert_cua_action_to_pyautogui_action(self, action_type, args):
+        """Convert a CUA action to a pyautogui action format
+
+        This function converts OpenAI CUA actions to pyautogui commands
+        for the Computer Agent Arena
+
+        Args:
+            action_type: Type of the CUA action
+            args: Arguments for the action
+
+        Returns:
+            String with pyautogui command code or None if the action can't be converted
+        """
+        if not action_type:
+            logger.warning("Empty CUA action received")
+            return None
+
+        key_mapping = {
+            "/": "/",
+            "\\": "\\",
+            "alt": "alt",
+            "arrowdown": "down",
+            "arrowleft": "left",
+            "arrowright": "right",
+            "arrowup": "up",
+            "backspace": "backspace",
+            "capslock": "capslock",
+            "cmd": "command",
+            "ctrl": "ctrl",
+            "delete": "delete",
+            "end": "end",
+            "enter": "enter",
+            "esc": "esc",
+            "home": "home",
+            "insert": "insert",
+            "option": "option",
+            "pagedown": "pagedown",
+            "pageup": "pageup",
+            "shift": "shift",
+            "space": "space",
+            "super": "super",
+            "tab": "tab",
+            "win": "win",
+        }
+        try:
+            if action_type == "click":
+                x = args.get("x")
+                y = args.get("y")
+                button = args.get("button", "left")
+
+                # Validate coordinates
+                if x is None or y is None:
+                    logger.warning(f"Invalid click coordinates: x={x}, y={y}")
+                    return None
+
+                # Validate button
+                if button not in ["left", "middle", "right"]:
+                    logger.warning(f"Invalid click button: {button}, defaulting to 'left'")
+                    button = "left"
+
+                return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.click(button='{button}')"
+
+            elif action_type == "double_click":
+                x = args.get("x")
+                y = args.get("y")
+
+                # Validate coordinates
+                if x is None or y is None:
+                    logger.warning(f"Invalid double_click coordinates: x={x}, y={y}")
+                    return None
+
+                return f"import pyautogui\npyautogui.moveTo({x}, {y})\npyautogui.doubleClick()"
+
+            elif action_type == "type":
+                text = args.get("text", "")
+
+                if not text:
+                    logger.warning("Empty text for type action")
+                    return "import pyautogui\n# Empty text, no action taken"
+
+                pattern = r"(? List:
+        """
+        Predict the next action(s) based on the current observation.
+        """
+
+        base64_image = encode_image(obs["screenshot"])
+        if self.cua_messages == []:
+            self.cua_messages.append({
+                "role": "user",
+                "content": [
+                    {
+                        "type": "input_image",
+                        "image_url": f"data:image/png;base64,{base64_image}",
+                    },
+                    {
+                        "type": "input_text",
+                        "text": instruction
+                    }
+                ]
+            })
+
+        with Timer() as model_timer:
+            response = self._create_response()
+        self.cua_messages += response.output
+
+        actions = []
+        responses = []
+        action_exit = False
+        thought_exit = False
+        message_exit = False
+        for item in response.output:
+            parsed_item = self._handle_item(item)
+            if isinstance(parsed_item, dict) and parsed_item.get("action_space", None) == "pyautogui":
+                actions.append(parsed_item)
+            else:
+                responses.append(parsed_item)
+            if item.type == "computer_call":
+                action_exit = True
+            if item.type == "reasoning" and item.summary and item.summary[0].type == "summary_text":
+                thought_exit = True
+            if item.type == "message" and item.content and item.content[0].type == "output_text":
+                message_exit = True
+        responses = [item for item in responses if item is not None]
+
+        logger.info(f"Actions: {actions}")
+        logger.info(f"Responses: {responses}")
+
+        state_correct = False
+        # if action_exit and thought_exit:
+        #     state_correct = True
+        if action_exit and not message_exit:
+            state_correct = True
+        if not state_correct:
+            logger.warning("The state of the agent is not correct, action_exit: %s, thought_exit: %s, message_exit: %s", action_exit, thought_exit, message_exit)
+
+        predict_info = {
+            "model_usage": {
+                "model_time": model_timer.duration,
+                "prompt_tokens": response.usage.input_tokens,
+                "completion_tokens": response.usage.output_tokens,
+            },
+            "messages": self.cua_messages,
+            "response": "\n".join(responses) if isinstance(responses, list) and all(isinstance(item, str) for item in responses) else "",
else "", + "state_correct": state_correct, + } + + return predict_info, actions + + + def reset(self, _logger=None): + global logger + logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent") + + self.thoughts = [] + self.actions = [] + self.observations = [] + self.cua_messages = [] + + def step(self, action: Dict[str, Any]) -> Tuple[bool, Dict[str, Any]]: + """Execute an action in the environment. + + Args: + action: The action to execute + + Returns: + Tuple containing: + - terminated: Whether the episode has terminated + - info: Information about the step + + Raises: + StepError: If the step execution fails + """ + try: + if not action: + logger.warning("Empty action received, terminating episode") + return True, {} + + logger.info(f"Executing action: {action.get('action_space', 'unknown')} - {action.get('action', '')[:50]}...") + + with Timer() as step_timer: + # Convert the action to an Action object + step_action = Action(action.get("action", ""), self.action_space) + # Execute the action in the environment + obs, reward, terminated, info = self.env.step(step_action.get_action()) + + screenshot_base64 = encode_image(obs["screenshot"]) + + self.cua_messages.append({ + "type": "computer_call_output", + "call_id": action["call_id"], + "acknowledged_safety_checks": action["pending_checks"], + "output": { + "type": "input_image", + "image_url": f"data:image/png;base64,{screenshot_base64}", + }, + }) + + logger.debug(f"Action completed in {step_timer.duration:.2f}s") + if terminated: + logger.info("Environment signaled termination") + + return obs, reward, terminated, info, { + "step_time": step_timer.duration, + "action": action + } + + except Exception as e: + logger.exception(f"Environment step failed: {str(e)}") + raise StepError(f"Failed to execute step: {str(e)}") + +class StepError(Exception): + """Exception raised when a step in the agent fails.""" + pass diff --git a/run_multienv_openaicua.py b/run_multienv_openaicua.py new file mode 100644 index 0000000..342ce91 --- /dev/null +++ b/run_multienv_openaicua.py @@ -0,0 +1,357 @@ +"""Script to run end-to-end evaluation on the benchmark. +Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. 
+""" + +import argparse +import datetime +import json +import logging +import os +import sys +from typing import List, Dict +import math +from tqdm import tqdm +from multiprocessing import Process, Manager +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.openai_cua_agent import OpenAICUAAgent + +# import wandb + + +# Logger Configs {{{ # +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) +sdebug_handler = logging.FileHandler( + os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" +) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(logging.INFO) +sdebug_handler.setLevel(logging.DEBUG) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) +sdebug_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) +sdebug_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +logger.addHandler(sdebug_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--screen_width", type=int, default=1920) + parser.add_argument("--screen_height", type=int, default=1080) + parser.add_argument("--sleep_after_execution", type=float, default=0.0) + parser.add_argument("--max_steps", type=int, default=15) + + # agent config + parser.add_argument("--max_trajectory_length", type=int, default=3) + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config + parser.add_argument("--model", type=str, default="gpt-4o") + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--max_tokens", type=int, default=1500) + parser.add_argument("--stop_token", type=str, default=None) + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + # aws config + parser.add_argument( + "--region", type=str, 
default="us-east-1", help="AWS region for the VM" + ) + args = parser.parse_args() + return args + + +def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: + """Distribute tasks evenly across environments.""" + # Flatten the tasks into a single list + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + + # Calculate tasks per environment + tasks_per_env = math.ceil(len(all_tasks) / num_envs) + + # Distribute tasks + distributed_tasks = [] + for i in range(num_envs): + env_tasks = {} + start_idx = i * tasks_per_env + end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) + + for domain, example_id in all_tasks[start_idx:end_idx]: + if domain not in env_tasks: + env_tasks[domain] = [] + env_tasks[domain].append(example_id) + + distributed_tasks.append(env_tasks) + + return distributed_tasks + + +def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list): + """Run tasks for a single environment.""" + + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + + provider_name="aws", + region="us-east-1", + snapshot_name="ami-05e7d7bd279ea4f14", + + screen_size=(args.screen_width, args.screen_height), + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + ) + agent = OpenAICUAAgent( + env=env, + model=args.model, + max_tokens=args.max_tokens, + top_p=args.top_p, + temperature=args.temperature, + action_space=args.action_space, + observation_type=args.observation_type, + max_trajectory_length=args.max_trajectory_length, + ) + logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") + + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + + try: + lib_run_single.run_single_example_openaicua( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"Time limit exceeded in {domain}/{example_id}"} + ) + ) + f.write("\n") + + env.close() + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + logger.info("Args: %s", args) + + distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) + + logger.info("All environments are ready. 
+
+    # Create a shared list for scores across processes
+    with Manager() as manager:
+        shared_scores = manager.list()
+
+        # Create and start processes for each environment
+        processes = []
+        for env_idx, env_tasks in enumerate(distributed_tasks):
+            p = Process(
+                target=run_env_tasks,
+                args=(env_idx, env_tasks, args, shared_scores)
+            )
+            processes.append(p)
+            p.start()
+
+        # Wait for all processes to complete
+        for p in processes:
+            p.join()
+
+        # Convert shared list to regular list
+        scores = list(shared_scores)
+
+    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+
+
+def get_unfinished(
+    action_space, use_model, observation_type, result_dir, total_file_json
+):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+
+    if not os.path.exists(target_dir):
+        return total_file_json
+
+    finished = {}
+    for domain in os.listdir(target_dir):
+        finished[domain] = []
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                if example_id == "onboard":
+                    continue
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" not in os.listdir(example_path):
+                        # empty all files under example_id
+                        for file in os.listdir(example_path):
+                            os.remove(os.path.join(example_path, file))
+                    else:
+                        finished[domain].append(example_id)
+
+    if not finished:
+        return total_file_json
+
+    for domain, examples in finished.items():
+        if domain in total_file_json:
+            total_file_json[domain] = [
+                x for x in total_file_json[domain] if x not in examples
+            ]
+
+    return total_file_json
+
+
+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
+
+    all_result = []
+
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" in os.listdir(example_path):
+                        # read the stored score; fall back to 0.0 if it can't be parsed
+                        try:
+                            all_result.append(
+                                float(
+                                    open(
+                                        os.path.join(example_path, "result.txt"), "r"
+                                    ).read()
+                                )
+                            )
+                        except:
+                            all_result.append(0.0)
+
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
+
+
+if __name__ == "__main__":
+    ####### The complete version of the list of examples #######
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    args = config()
+
+    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+        test_all_meta = json.load(f)
+
+    if args.domain != "all":
+        test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+    test_file_list = get_unfinished(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    left_info = ""
+    for domain in test_file_list:
+        left_info += f"{domain}: {len(test_file_list[domain])}\n"
+    logger.info(f"Left tasks:\n{left_info}")
+
+    get_result(
+        args.action_space,
+        args.model,
+        args.observation_type,
+        args.result_dir,
+        test_all_meta,
+    )
+    test(args, test_file_list)
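Example invocation (illustrative; the flag values below are placeholders to adapt to your own setup and AWS account):

    python run_multienv_openaicua.py --headless --observation_type screenshot --model computer-use-preview --max_steps 50 --num_envs 4 --region us-east-1 --result_dir ./results

Results are written under result_dir/action_space/observation_type/model/domain/example_id, so re-running the script skips examples that already have a result.txt (see get_unfinished).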