diff --git a/mm_agents/seed16.py b/mm_agents/seed16.py
new file mode 100644
index 0000000..f32a08a
--- /dev/null
+++ b/mm_agents/seed16.py
@@ -0,0 +1,696 @@
+
+import os
+import re
+import base64
+import requests
+import logging
+from typing import Optional, Dict, List, Tuple, Union
+from loguru import logger
+from ui_tars.action_parser import parse_xml_action, parsing_response_to_pyautogui_code, parse_xml_action_v3
+import ast
+import base64
+import json
+import math
+import io
+import re
+from PIL import Image
+
+# Action-type names that Seed16Agent.predict treats as terminal or pausing signals
+# rather than translating them into pyautogui code.
+FINISH_WORD = "finished"  # predict returns ["DONE"]
+WAIT_WORD = "wait"  # predict returns ["WAIT"]
+ENV_FAIL_WORD = "error_env"  # predict returns ["FAIL"]
+CALL_USER = "call_user"  # predict returns ["FAIL"]
+INFEASIBLE = "infeasible"  # predict returns ["FAIL"]
+
+def _param(description, type_="string", enum=None):
+    """Build one property entry of a tool schema (keys: type, description, enum)."""
+    spec = {"type": type_, "description": description}
+    if enum is not None:
+        spec["enum"] = list(enum)
+    return spec
+
+
+def _tool(name, description, properties, required=()):
+    """Build one function-tool schema entry with the key order used by this module."""
+    return {
+        "type": "function",
+        "function": {
+            "name": name,
+            "parameters": {
+                "type": "object",
+                "properties": properties,
+                "required": list(required),
+            },
+            "description": description,
+        },
+    }
+
+
+# Tool schemas passed to parse_xml_action_v3 when decoding model output.
+GUI_TOOL_SCHEMAS = [
+    _tool(
+        "click",
+        "Mouse left single click action.",
+        {"point": _param("Click coordinates. The format is: x y")},
+        required=["point"],
+    ),
+    _tool(
+        "left_double",
+        "Mouse left double click action.",
+        {"point": _param("Click coordinates. The format is: x y")},
+        required=["point"],
+    ),
+    _tool(
+        "right_single",
+        "Mouse right single click action.",
+        {"point": _param("Click coordinates. The format is: x y")},
+        required=["point"],
+    ),
+    _tool(
+        "drag",
+        "Mouse left button drag action.",
+        {
+            "start_point": _param("Drag start point. The format is: x y"),
+            "end_point": _param("Drag end point. The format is: x y"),
+        },
+        required=["start_point", "end_point"],
+    ),
+    _tool(
+        "scroll",
+        "Scroll action.",
+        {
+            "point": _param("Scroll start position. If not specified, default to execute on the current mouse position. The format is: x y"),
+            "direction": _param("Scroll direction.", enum=["up", "down", "left", "right"]),
+        },
+        required=["direction"],
+    ),
+    _tool(
+        "move_to",
+        "Mouse move action.",
+        {"point": _param("Target coordinates. The format is: x y")},
+        required=["point"],
+    ),
+    _tool(
+        "mouse_down",
+        "Mouse down action.",
+        {
+            "point": _param("Mouse down position. If not specified, default to execute on the current mouse position. The format is: x y"),
+            "button": _param("Down button. Default to left.", enum=["left", "right"]),
+        },
+    ),
+    _tool(
+        "mouse_up",
+        "Mouse up action.",
+        {
+            "point": _param("Mouse up position. If not specified, default to execute on the current mouse position. The format is: x y"),
+            "button": _param("Up button. Default to left.", enum=["left", "right"]),
+        },
+    ),
+    _tool(
+        "type",
+        "Type content.",
+        {"content": _param("Type content. If you want to submit your input, use \n at the end of content.")},
+        required=["content"],
+    ),
+    _tool(
+        "hotkey",
+        "Press hotkey.",
+        {"key": _param("Hotkeys you want to press. Split keys with a space and use lowercase.")},
+        required=["key"],
+    ),
+    _tool(
+        "press",
+        "Press key.",
+        {"key": _param("Key you want to press. Only one key can be pressed at one time.")},
+        required=["key"],
+    ),
+    _tool(
+        "release",
+        "Release key.",
+        {"key": _param("Key you want to release. Only one key can be released at one time.")},
+        required=["key"],
+    ),
+    _tool(
+        "finished",
+        "This function is used to indicate the completion of a task by providing the final answer or response.",
+        {"content": _param("Provide the final answer or response to complete the task.")},
+    ),
+    _tool(
+        "call_user",
+        "This function is used to interact with the user by displaying a message and requesting their input, feedback, or guidance.",
+        {"content": _param("Message or information displayed to the user to request their input, feedback, or guidance.")},
+    ),
+    _tool(
+        "wait",
+        "Wait for a while.",
+        {"time": _param("Wait time in seconds.", type_="integer")},
+    ),
+    _tool(
+        "infeasible",
+        "This function is used to indicate that the current task is infeasible thus agent ends the task.",
+        {"content": _param("Message or information displayed to the user to explain why the current task is infeasible.")},
+        required=["content"],
+    ),
+]
+
+def modify_conversations(conversations):
+    """Mark every image part in a chat history as high detail.
+
+    Each message whose ``content`` is a list is scanned, and every
+    ``{"type": "image_url", "image_url": {...}}`` part gets
+    ``image_url["detail"] = "high"``. Messages are modified in place.
+
+    Messages with string content, no ``content`` key, an empty content list,
+    or non-dict parts are passed through unchanged instead of raising.
+
+    Args:
+        conversations: Iterable of chat message dicts.
+
+    Returns:
+        A new list holding the same (mutated) message objects in order.
+    """
+    new_conversations = []
+    for conversation in conversations:
+        content = conversation.get("content")
+        if isinstance(content, list):
+            for part in content:
+                if not isinstance(part, dict) or part.get("type") != "image_url":
+                    continue
+                image_url = part.get("image_url")
+                if isinstance(image_url, dict):
+                    image_url["detail"] = "high"
+        new_conversations.append(conversation)
+    return new_conversations
+
+class Seed16Agent:
+    """
+    UI-TARS Agent based on Seed1.5-VL model implementation.
+    Integrates the GUI folder UI-TARS-1.5 implementation with the mm_agents architecture.
+
+    The agent keeps per-episode history (screenshots, raw model responses, thoughts,
+    actions) and turns model tool calls into pyautogui code in ``predict``.
+    """
+
+    def __init__(
+        self,
+        # Model settings
+        model: str,
+        model_type: str,
+        # Generation settings
+        max_tokens: int,
+        top_p: Optional[float],
+        temperature: float,
+
+        # History settings
+        max_trajectory_length: Optional[int],
+        history_n: Optional[int],
+
+        # Outside infos
+        max_steps: int = 100,
+
+        # UI-TARS specific settings
+        use_thinking: bool = True,
+        resize_image: bool = False,
+        resized_image_width: int = 1920,
+        resized_image_height: int = 1080,
+    ):
+        """
+        Initialize Seed16 Agent.
+
+        Args:
+            model: Model name sent in the request payload.
+            model_type: Model family label; stored on the instance.
+            max_tokens: Maximum tokens to generate
+            top_p: Top-p sampling parameter
+            temperature: Temperature for sampling
+            max_trajectory_length: Maximum trajectory history length; stored on the instance.
+            history_n: Maximum number of screenshots kept in history and sent to the model.
+            max_steps: Maximum steps for the agent. Accepted but not stored by this constructor.
+            use_thinking: Whether to use thinking mode. Stored only; see the note on
+                ``inference_func`` below.
+            resize_image: When true, ``predict`` resizes each screenshot before sending it.
+            resized_image_width: Target width used when ``resize_image`` is true.
+            resized_image_height: Target height used when ``resize_image`` is true.
+        """
+
+        self.model = model
+        self.max_trajectory_length = max_trajectory_length
+        self.logger = logger
+        # Per-episode history; cleared again by reset().
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+        self.history_images = []
+        self.history_responses = []
+
+        self.system_prompt = "You are provided with a task description, a history of previous actions, and corresponding screenshots. Your goal is to perform the next action to complete the task. Please note that if performing the same action multiple times results in a static screen with no changes, you should attempt a modified or alternative action."
+
+        self.action_parse_res_factor = 1000
+        self.model_type = model_type
+        self.history_n = history_n
+        self.top_p = top_p
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.platform = "ubuntu"
+        self.use_thinking = use_thinking
+
+        # NOTE(review): the thinking endpoint is always selected, regardless of
+        # use_thinking. inference_without_thinking returns a plain string, which
+        # predict() cannot index with ["content"], so the two are not interchangeable.
+        self.inference_func = self.inference_with_thinking
+        self.resize_image = resize_image
+        self.resized_image_width = resized_image_width
+        self.resized_image_height = resized_image_height
+        self.input_swap = False
+
+    def reset(self, _logger=None, vm_ip=None, **kwargs):
+        """Start a fresh episode: rebind the module logger and clear all history.
+
+        Args:
+            _logger: Logger to install as the module-level ``logger``; when omitted,
+                the ``desktopenv.agent`` stdlib logger is used.
+            vm_ip: Address of the VM for this episode; stored on the instance.
+            **kwargs: Accepted and ignored.
+        """
+        global logger
+        if _logger is None:
+            logger = logging.getLogger("desktopenv.agent")
+        else:
+            logger = _logger
+
+        self.vm_ip = vm_ip
+
+        # Give every history attribute its own fresh list.
+        for history_attr in (
+            "thoughts",
+            "actions",
+            "observations",
+            "history_images",
+            "history_responses",
+        ):
+            setattr(self, history_attr, [])
+
+    def pretty_print_messages(self, messages):
+        """Return a loggable copy of message(s) with inline images replaced by a placeholder.
+
+        Accepts a single message or a list of messages. Non-dict messages are
+        converted with ``str``. The input is not modified.
+        """
+        def redact_part(part):
+            # Only well-formed image parts are swapped out; everything else passes through.
+            if isinstance(part, dict) and part.get("type") == "image_url" and "image_url" in part:
+                return {
+                    "type": "image_url",
+                    "image_url": {"url": "[BASE64_IMAGE_DATA]"}
+                }
+            return part
+
+        def redact_message(msg):
+            if not isinstance(msg, dict):
+                return str(msg)
+            cleaned = {}
+            for field, value in msg.items():
+                if field == "content" and isinstance(value, list):
+                    cleaned[field] = [redact_part(part) for part in value]
+                else:
+                    cleaned[field] = value
+            return cleaned
+
+        if not isinstance(messages, list):
+            return redact_message(messages)
+        return [redact_message(msg) for msg in messages]
+
+
+    def inference_with_thinking(self, messages):
+        """Send ``messages`` to the chat endpoint with high reasoning effort.
+
+        The endpoint URL and key are read from the ``DOUBAO_API_URL`` and
+        ``DOUBAO_API_KEY`` environment variables.
+
+        Args:
+            messages: Chat messages to send as the request's ``messages`` field.
+
+        Returns:
+            On HTTP 200, the ``message`` dict of the first choice. Otherwise a dict
+            with ``error`` and ``details`` keys describing the failed request.
+
+        Raises:
+            KeyError: If either environment variable is unset.
+        """
+        api_key = os.environ['DOUBAO_API_KEY']
+        api_url = os.environ['DOUBAO_API_URL']
+        headers = {
+            'Authorization': f'Bearer {api_key}',
+            'Content-Type': 'application/json'
+        }
+        data = {
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+            "reasoning_effort": "high"
+        }
+
+        response = requests.post(api_url, headers=headers, json=data)
+        # Check the status before touching the body: an error response may not be
+        # JSON or may lack "choices", and must still yield the error dict below.
+        if response.status_code != 200:
+            return {
+                "error": f"Request failed with status code {response.status_code}",
+                "details": response.text
+            }
+
+        first_choice = response.json()["choices"][0]
+        print(first_choice)
+        return first_choice["message"]
+
+    def inference_without_thinking(self, messages):
+        """Send ``messages`` to the chat endpoint with thinking disabled.
+
+        The endpoint URL and key are read from the ``DOUBAO_API_URL`` and
+        ``DOUBAO_API_KEY`` environment variables.
+
+        Args:
+            messages: Chat messages to send as the request's ``messages`` field.
+
+        Returns:
+            On HTTP 200, the ``content`` string of the first choice's message.
+            Otherwise a dict with ``error`` and ``details`` keys.
+            NOTE(review): the success value is a string, not the message dict
+            that ``inference_with_thinking`` returns.
+
+        Raises:
+            KeyError: If either environment variable is unset.
+        """
+        api_key = os.environ['DOUBAO_API_KEY']
+        api_url = os.environ['DOUBAO_API_URL']
+        headers = {
+            'Authorization': f'Bearer {api_key}',
+            'Content-Type': 'application/json'
+        }
+        data = {
+            "model": self.model,
+            "messages": messages,
+            "thinking": {"type": "disabled"},
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+        }
+
+        response = requests.post(api_url, headers=headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()["choices"][0]["message"]["content"]
+
+        print(f"Request failed with status code {response.status_code}")
+        # Print the raw body: an error response is not guaranteed to be JSON, and
+        # decoding it here would raise instead of returning the error dict.
+        print(response.text)
+        return {
+            "error": f"Request failed with status code {response.status_code}",
+            "details": response.text
+        }
+
+    def predict(self, task_instruction: str, obs: dict) -> Tuple[Union[str, Dict, None], List]:
+        """Predict the next action(s) from the task text and the current screenshot.
+
+        Args:
+            task_instruction: Natural-language task; a sudo-password hint is appended
+                before it is sent to the model.
+            obs: Observation dict; only ``obs["screenshot"]`` is read. It may be raw
+                image bytes or a base64-encoded string.
+
+        Returns:
+            ``(prediction, actions)``. ``prediction`` is the model text (reasoning
+            followed by content), or ``None`` when every request attempt failed.
+            ``actions`` is a list of pyautogui code strings, or one of ``["DONE"]``,
+            ``["WAIT"]`` or ``["FAIL"]``.
+        """
+
+        self.task_instruction = task_instruction + "\nThe sudo password is osworld-public-evaluation"
+
+        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
+            self.thoughts
+        ), "The number of observations and actions should be the same."
+
+        # Accept the screenshot as raw bytes or as an already base64-encoded string;
+        # PIL needs the raw bytes, the model needs the base64 text.
+        screenshot = obs["screenshot"]
+        if isinstance(screenshot, bytes):
+            raw_screenshot = screenshot
+            screenshot = base64.b64encode(raw_screenshot).decode('utf-8')
+        else:
+            raw_screenshot = base64.b64decode(screenshot)
+
+        # Original screen size, passed to the pyautogui code generator below.
+        image = Image.open(io.BytesIO(raw_screenshot))
+        width, height = image.size
+        if self.resize_image:
+            resized_image = image.resize(
+                (
+                    self.resized_image_width,
+                    self.resized_image_height,
+                )
+            )
+            image_bytes_io = io.BytesIO()
+            resized_image.save(image_bytes_io, format="PNG")
+            image_bytes = image_bytes_io.getvalue()
+            screenshot = base64.b64encode(image_bytes).decode('utf-8')
+
+        self.history_images.append(screenshot)
+
+        self.observations.append(
+            {"screenshot": screenshot, "accessibility_tree": None}
+        )
+
+        # history_n=None means "no cap". Values below 1 are raised to 1 so the
+        # current screenshot is always the one sent last.
+        history_n = self.history_n
+        if history_n is None:
+            history_n = len(self.history_responses) + 1
+        history_n = max(1, history_n)
+
+        if len(self.history_images) > history_n:
+            self.history_images = self.history_images[-history_n:]
+
+        images = self.history_images
+
+        messages = [
+            {
+                "role": "system",
+                "content": self.system_prompt
+            },
+            {
+                "role": "system",
+                "content": '''## Function Definition\n\n- You have access to the following functions:\n{"type": "function", "name": "call_user", "parameters": {"type": "object", "properties": {"content": {"type": "string", "description": "Message or information displayed to the user to request their input, feedback, or guidance."}}, "required": []}, "description": "This function is used to interact with the user by displaying a message and requesting their input, feedback, or guidance."}\n{"type": "function", "name": "click", "parameters": {"type": "object", "properties": {"point": {"type": "string", "description": "Click coordinates. The format is: x y"}}, "required": ["point"]}, "description": "Mouse left single click action."}\n{"type": "function", "name": "drag", "parameters": {"type": "object", "properties": {"start_point": {"type": "string", "description": "Drag start point. The format is: x y"}, "end_point": {"type": "string", "description": "Drag end point. The format is: x y"}}, "required": ["start_point", "end_point"]}, "description": "Mouse left button drag action."}\n{"type": "function", "name": "finished", "parameters": {"type": "object", "properties": {"content": {"type": "string", "description": "Provide the final answer or response to complete the task."}}, "required": []}, "description": "This function is used to indicate the completion of a task by providing the final answer or response."}\n{"type": "function", "name": "hotkey", "parameters": {"type": "object", "properties": {"key": {"type": "string", "description": "Hotkeys you want to press. 
Split keys with a space and use lowercase."}}, "required": ["key"]}, "description": "Press hotkey."}\n{"type": "function", "function": {"name": "infeasible", "parameters": {"type": "object", "properties": {"content": {"type": "string", "description": "Message or information displayed to the user to explain why the current task is infeasible."}}, "required": ["content"]}, "description": "This function is used to indicate that the current task is infeasible thus agent ends the task."}\n{"type": "function", "name": "left_double", "parameters": {"type": "object", "properties": {"point": {"type": "string", "description": "Click coordinates. The format is: x y"}}, "required": ["point"]}, "description": "Mouse left double click action."}\n{"type": "function", "name": "right_single", "parameters": {"type": "object", "properties": {"point": {"type": "string", "description": "Click coordinates. The format is: x y"}}, "required": ["point"]}, "description": "Mouse right single click action."}\n{"type": "function", "name": "scroll", "parameters": {"type": "object", "properties": {"point": {"type": "string", "description": "Scroll start position. If not specified, default to execute on the current mouse position. The format is: x y"}, "direction": {"type": "string", "description": "Scroll direction.", "enum": ["up", "down", "left", "right"]}}, "required": ["direction", "point"]}, "description": "Scroll action."}\n{"type": "function", "name": "type", "parameters": {"type": "object", "properties": {"content": {"type": "string", "description": "Type content. 
If you want to submit your input, use \\n at the end of content."}}, "required": ["content"]}, "description": "Type content."}\n{"type": "function", "name": "wait", "parameters": {"type": "object", "properties": {"time": {"type": "integer", "description": "Wait time in seconds."}}, "required": []}, "description": "Wait for a while."}\n\n- To call a function, use the following structure without any suffix:\n\n reasoning process \nvalue_1\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n## Important Notes\n- Function calls must begin with .\n- All required parameters must be explicitly provided.\n\n## Additional Notes\n- You can execute multiple actions within a single tool call. For example:\nvalue_1\nThis is the value for the second parameter\nthat can span\nmultiple lines\nvalue_4\n- 当你判断任务请求是无法执行的时候,你应该调用Infeasible工具结束任务并解释原因。\n 判断标准:当一个请求符合以下任何一条标准时,应被归类为“无法执行”。\n 1. 技术/物理层面的矛盾: 指令本身包含逻辑上或物理上无法实现的要求。\n 2. 工具/功能错配: 指令要求在一个软件中执行另一个软件的功能,或者执行该软件根本不具备的功能。\n 3. 超出操作边界/范围: 指令要求执行的操作超出了当前用户会话、权限或应用程序的逻辑边界,涉及未告知的隐私信息或者未授权的操作。\n 4. 依赖隐性知识或外部条件: 任务的完成依赖于Agent无法获取的外部硬件、物理环境、未声明的插件/扩展、或特定的文件/数据。\n\n 输出指令:\n 如果请求被判断为“无法执行”,你应该向用户解释为什么这个任务超出了你的能力范围(例如,指出它需要直接操作某个硬件),并尽可能提供一个指导性的替代方案,让用户可以自己完成该任务。\n 你应该非常非常谨慎地使用Infeasible工具,因为它会直接结束任务并降低用户体验。所以非必要的时候,你不应该调用Infeasible工具,尽量以finish工具结束任务并向用户提示原因就好。'''
+            },
+            {
+                "role": "user",
+                "content": self.task_instruction
+            }
+        ]
+
+        image_num = 0
+        if len(self.history_responses) > 0:
+            for history_idx, history_response in enumerate(self.history_responses):
+                # send at most history_n images to the model
+                if history_idx + history_n > len(self.history_responses):
+                    messages.append({
+                        "role": "tool",
+                        "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}],
+                        "tool_call_id": "1"
+                    })
+                    image_num += 1
+
+                messages.append({
+                    "role": "assistant",
+                    "content": history_response
+                })
+            # The current screenshot always follows the last assistant turn.
+            messages.append({
+                "role": "tool",
+                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}],
+                "tool_call_id": "1"
+            })
+            image_num += 1
+        else:
+            messages.append({
+                "role": "tool",
+                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}],
+                "tool_call_id": "1"
+            })
+            image_num += 1
+
+        messages = modify_conversations(messages)
+        try_times = 3
+        prediction = None
+        while True:
+            if try_times <= 0:
+                print("Reach max retry times to fetch response from client, as error flag.")
+                return prediction, ["FAIL"]
+            try:
+                logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
+                # json.dump(messages, open("debug_seed16.json", "w"), indent=4, ensure_ascii=False)
+                response = self.inference_func(messages)
+                # An error dict from the inference function has no "content" key,
+                # so it lands in the except branch and counts as a failed attempt.
+                content = response["content"]
+                if "reasoning_content" in response and response["reasoning_content"]:
+                    reasoning_content = response["reasoning_content"]
+                    prediction = f"{reasoning_content}{content}"
+                else:
+                    prediction = content
+                break
+
+            except Exception as e:
+                print(f"Error when fetching response from client, with error:\n{e}")
+                prediction = None
+                try_times -= 1
+
+        self.history_responses.append(prediction)
+
+        try:
+            parsed_responses = parse_xml_action_v3(prediction, GUI_TOOL_SCHEMAS)
+            # No tool-call marker and nothing parsed: treat the reply as a final answer.
+            if "seed:tool_call" not in prediction and len(parsed_responses) == 0:
+                return prediction, ["DONE"]
+            if len(parsed_responses) == 0:
+                raise ValueError("Parsing action error")
+
+        except Exception as e:
+            print(f"Parsing action error: {prediction}, with error:\n{e}")
+            return prediction, ["FAIL"]
+
+        # Keep the text before the tool call as the "thought". Splitting on an empty
+        # separator raises ValueError, so cut at the tool-call marker instead.
+        # NOTE(review): assumes the marker starts with "<seed:tool_call" - confirm
+        # against the model's actual output format.
+        thoughts = prediction.partition("<seed:tool_call")[0]
+        self.thoughts.append(thoughts)
+
+        # Action types that end or pause the step instead of producing code.
+        terminal_status = {
+            FINISH_WORD: "DONE",
+            WAIT_WORD: "WAIT",
+            ENV_FAIL_WORD: "FAIL",
+            CALL_USER: "FAIL",
+            INFEASIBLE: "FAIL",
+        }
+
+        actions = []
+        for parsed_xml_action in parsed_responses:
+            parsed_response = {
+                "action_type": parsed_xml_action["function"],
+                "action_inputs": parsed_xml_action["parameters"]
+            }
+
+            status = terminal_status.get(parsed_response["action_type"])
+            if status is not None:
+                # Record what was collected so far, but return only the status.
+                self.actions.append(actions)
+                return prediction, [status]
+
+            pyautogui_code = parsing_response_to_pyautogui_code(
+                parsed_response,
+                height,
+                width,
+                self.input_swap
+            )
+            actions.append(pyautogui_code)
+
+        self.actions.append(actions)
+
+        return prediction, actions
+
\ No newline at end of file
diff --git a/run_multienv_seed16.py b/run_multienv_seed16.py
new file mode 100644
index 0000000..2c23425
--- /dev/null
+++ b/run_multienv_seed16.py
@@ -0,0 +1,540 @@
+from __future__ import annotations
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+import signal
+import time
+from typing import List, Dict
+from multiprocessing import Process, Manager
+from multiprocessing import current_process
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
+from mm_agents.seed16 import Seed16Agent
+import os
+
+
+# Global variables for signal handling
+active_environments = []  # environments that signal_handler closes in the main process
+processes = []  # worker processes; rebound by test() and terminated by signal_handler
+is_terminating = False  # set by signal_handler so a second signal is ignored
+
+# load the environment variables from .env file
+if os.path.exists(".env"):
+    # Imported lazily so python-dotenv is only needed when a .env file is present.
+    from dotenv import load_dotenv
+    load_dotenv()
+
+# Logger Configs {{{ #
+def config() -> argparse.Namespace:
+    """Define and parse the command-line options for a benchmark run."""
+    cli = argparse.ArgumentParser(description="Run end-to-end evaluation on the benchmark")
+    option = cli.add_argument
+
+    # environment config
+    option("--path_to_vm", type=str, default=None)
+    option("--headless", action="store_true", help="Run in headless machine")
+    option("--action_space", type=str, default="pyautogui", help="Action type")
+    option(
+        "--observation_type",
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+        default="screenshot",
+        help="Observation type",
+    )
+    option("--sleep_after_execution", type=float, default=5.0)
+    option("--max_steps", type=int, default=100)
+
+    # evaluation config
+    option("--test_config_base_dir", type=str, default="evaluation_examples")
+
+    # lm config
+    option("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428")
+    option("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"])
+    option("--temperature", type=float, default=1.0)
+    option("--top_p", type=float, default=0.7)
+    option("--max_tokens", type=int, default=4096)
+    option("--use_thinking", action="store_true", default=False)
+
+    option("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
+    option("--history_n", type=int, default=5, help="The max number of images in the history.")
+    option("--language", type=str, default="Chinese", help="Language for the agent.")
+
+    # example config
+    option("--domain", type=str, default="all")
+    option("--test_all_meta_path", type=str, default="evaluation_examples/test_all.json")
+
+    # logging related
+    option("--result_dir", type=str, default="./results")
+    option("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
+    option(
+        "--log_level",
+        type=str,
+        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+        default='INFO',
+        help="Set the logging level",
+    )
+
+    # aws config
+    option("--region", type=str, default="us-east-1", help="AWS region for the VM")
+    option(
+        "--provider_name",
+        type=str,
+        default="aws",
+        choices=["aws", "virtualbox", "vmware", "docker", "azure"],
+        help="Provider name",
+    )
+    option("--client_password", type=str, default="", help="Client password")
+    option("--screen_width", type=int, default=1920, help="Screen width")
+    option("--screen_height", type=int, default=1080, help="Screen height")
+
+    return cli.parse_args()
+
+args = config()  # Get command line arguments first
+
+# Root logger: everything below attaches to it.
+logger = logging.getLogger()
+log_level = getattr(logging, args.log_level.upper())
+logger.setLevel(log_level)
+
+datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
+
+# FileHandler opens its file on construction and does not create missing
+# directories, so make sure the log directory exists first.
+os.makedirs("logs", exist_ok=True)
+
+file_handler = logging.FileHandler(
+    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+debug_handler = logging.FileHandler(
+    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
+)
+stdout_handler = logging.StreamHandler(sys.stdout)
+
+file_handler.setLevel(logging.INFO)
+debug_handler.setLevel(logging.DEBUG)
+stdout_handler.setLevel(log_level)
+
+formatter = logging.Formatter(
+    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
+)
+file_handler.setFormatter(formatter)
+debug_handler.setFormatter(formatter)
+stdout_handler.setFormatter(formatter)
+
+# Only records from the "desktopenv" logger hierarchy reach stdout.
+stdout_handler.addFilter(logging.Filter("desktopenv"))
+
+logger.addHandler(file_handler)
+logger.addHandler(debug_handler)
+logger.addHandler(stdout_handler)
+# }}} Logger Configs #
+
+# From here on, this module logs through the "desktopenv.experiment" logger.
+logger = logging.getLogger("desktopenv.experiment")
+
+
+def distribute_tasks(test_all_meta: dict) -> List[tuple]:
+    """Flatten a ``{domain: [example_id, ...]}`` mapping into ``(domain, example_id)`` pairs."""
+    return [
+        (domain, example_id)
+        for domain, examples in test_all_meta.items()
+        for example_id in examples
+    ]
+
+
+def process_signal_handler(signum, frame, env_idx):
+    """Close a worker's environments on a termination signal, then exit with status 0.
+
+    The environments are looked up as the ``active_environments`` local of the
+    interrupted frame; if that frame has no such local, nothing is closed.
+    """
+    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
+
+    # Whatever list the interrupted function was tracking, if any.
+    tracked_envs = frame.f_locals.get('active_environments', [])
+
+    for env in tracked_envs:
+        if env is None:
+            continue
+        try:
+            logger.info(f"Process {env_idx + 1} closing environment...")
+            env.close()
+            logger.info(f"Process {env_idx + 1} environment closed successfully")
+        except Exception as e:
+            logger.error(f"Process {env_idx + 1} error closing environment: {e}")
+
+    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
+    sys.exit(0)
+
+def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
+    """Worker entry point: create one environment and agent, then run queued tasks.
+
+    Tasks are ``(domain, example_id)`` pairs pulled from ``task_queue`` until a
+    ``get`` with a 5-second timeout fails. Errors in one task are logged and the
+    worker moves on; the environment is always closed on exit.
+
+    Args:
+        task_queue: Shared queue of ``(domain, example_id)`` tuples.
+            NOTE(review): ``Queue`` is not imported in this module; the annotation
+            only works because annotations are evaluated lazily here.
+        args: Parsed command-line options from ``config()``.
+        shared_scores: Shared list passed through to ``run_single_example``.
+    """
+    active_environments = []
+    env = None
+    try:
+        # NOTE(review): the AMI is looked up in the AWS image map regardless of
+        # args.provider_name - confirm this is intended for other providers.
+        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
+        REGION = args.region
+        screen_size = (args.screen_width, args.screen_height)
+        # Fall back to the 1920x1080 image when the requested size has no entry.
+        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
+        env = DesktopEnv(
+            path_to_vm=args.path_to_vm,
+            action_space=args.action_space,
+            provider_name=args.provider_name,
+            region=REGION,
+            snapshot_name=ami_id,
+            screen_size=screen_size,
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+            enable_proxy=True,
+            client_password=args.client_password
+        )
+        active_environments.append(env)
+        agent = Seed16Agent(
+            model=args.model,
+            model_type=args.model_type,
+            max_tokens=args.max_tokens,
+            top_p=args.top_p,
+            temperature=args.temperature,
+            max_trajectory_length=args.max_trajectory_length,
+            history_n=args.history_n,
+            use_thinking=args.use_thinking,
+        )
+        logger.info(f"Process {current_process().name} started.")
+        while True:
+            try:
+                item = task_queue.get(timeout=5)
+            except Exception:
+                # Any failure to get an item (including the timeout) ends this worker.
+                break
+            domain, example_id = item
+            try:
+                config_file = os.path.join(
+                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+                )
+                with open(config_file, "r", encoding="utf-8") as f:
+                    example = json.load(f)
+                logger.info(f"[{current_process().name}][Domain]: {domain}")
+                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
+                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
+                example_result_dir = os.path.join(
+                    args.result_dir,
+                    args.action_space,
+                    args.observation_type,
+                    args.model,
+                    domain,
+                    example_id,
+                )
+                os.makedirs(example_result_dir, exist_ok=True)
+                try:
+                    lib_run_single.run_single_example(
+                        agent,
+                        env,
+                        example,
+                        args.max_steps,
+                        example["instruction"],
+                        args,
+                        example_result_dir,
+                        shared_scores,
+                    )
+                except Exception as e:
+                    import traceback
+                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
+                    logger.error(traceback.format_exc())
+                    # Best effort: stop the recording so a video of the failure is kept.
+                    try:
+                        env.controller.end_recording(
+                            os.path.join(example_result_dir, "recording.mp4")
+                        )
+                    except Exception as rec_e:
+                        logger.error(f"Failed to end recording: {rec_e}")
+                    # Append an error record to the trajectory file for this example.
+                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                        f.write(
+                            json.dumps(
+                                {"Error": f"{domain}/{example_id} - {e}"}
+                            )
+                        )
+                        f.write("\n")
+            except Exception as e:
+                # Failures outside run_single_example (e.g. loading the example file).
+                logger.error(f"Task-level error in {current_process().name}: {e}")
+                import traceback
+                logger.error(traceback.format_exc())
+    except Exception as e:
+        # Failures while creating the environment or agent.
+        logger.error(f"Process-level error in {current_process().name}: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+    finally:
+        logger.info(f"{current_process().name} cleaning up environment...")
+        try:
+            if env:
+                env.close()
+                logger.info(f"{current_process().name} environment closed successfully")
+        except Exception as e:
+            logger.error(f"{current_process().name} error during environment cleanup: {e}")
+
+
+
+def signal_handler(signum, frame):
+    """Shut down on SIGINT/SIGTERM: close environments, stop workers, exit with status 0.
+
+    Workers are first asked to terminate, given one second, and then sent SIGKILL
+    if still alive. A second signal arriving during shutdown is ignored.
+    """
+    global is_terminating, active_environments, processes
+
+    # Avoid duplicate handling
+    if is_terminating:
+        return
+    is_terminating = True
+
+    logger.info(f"Received signal {signum}. Gracefully shutting down...")
+
+    # Close all registered environments in the main process
+    for env in active_environments:
+        try:
+            logger.info("Closing environment...")
+            env.close()
+            logger.info("Environment closed successfully")
+        except Exception as e:
+            logger.error(f"Error closing environment: {e}")
+
+    # Send termination signal to all child processes first
+    for worker in processes:
+        if not worker.is_alive():
+            continue
+        try:
+            logger.info(f"Sending termination signal to process {worker.name}...")
+            worker.terminate()
+        except Exception as e:
+            logger.error(f"Error sending termination signal to process: {e}")
+
+    # Allow a short time for processes to handle their own cleanup
+    time.sleep(1)
+
+    # Forcefully terminate any processes that didn't exit
+    for worker in processes:
+        if not worker.is_alive():
+            continue
+        try:
+            logger.info(f"Forcefully terminating process {worker.name}...")
+            os.kill(worker.pid, signal.SIGKILL)
+        except Exception as e:
+            logger.error(f"Error forcefully terminating process: {e}")
+
+    logger.info("Shutdown complete. Exiting.")
+    sys.exit(0)
+
+
+def test(args: argparse.Namespace, test_all_meta: dict) -> None:
+    """Run every task in ``test_all_meta`` across ``args.num_envs`` worker processes.
+
+    Tasks are placed on a shared queue and consumed by ``run_env_tasks`` workers.
+    While tasks remain queued, dead workers are replaced. Once the queue is empty
+    the remaining workers are joined and the average score is logged.
+
+    Args:
+        args: Parsed command-line options from ``config()``.
+        test_all_meta: Mapping of domain to the example ids to run.
+
+    Raises:
+        KeyboardInterrupt: Re-raised after logging.
+        Exception: Any other error while supervising is re-raised after the live
+            workers have been asked to terminate.
+    """
+    global processes
+    logger.info("Args: %s", args)
+    all_tasks = distribute_tasks(test_all_meta)
+    logger.info(f"Total tasks: {len(all_tasks)}")
+    with Manager() as manager:
+        shared_scores = manager.list()
+        task_queue = manager.Queue()
+        for item in all_tasks:
+            task_queue.put(item)
+
+        def spawn_worker(name):
+            # Start one daemon worker that consumes the shared queue.
+            worker = Process(
+                target=run_env_tasks,
+                args=(task_queue, args, shared_scores),
+                name=name
+            )
+            worker.daemon = True
+            worker.start()
+            return worker
+
+        num_envs = args.num_envs
+        processes = []
+        for i in range(num_envs):
+            p = spawn_worker(f"EnvProcess-{i+1}")
+            processes.append(p)
+            logger.info(f"Started process {p.name} with PID {p.pid}")
+        try:
+            while True:
+                # Check the queue before restarting anything: a worker that exited
+                # because the queue drained must not be replaced, since every
+                # replacement would start a new environment with nothing to do.
+                if task_queue.empty():
+                    logger.info("All tasks finished.")
+                    break
+                alive_count = 0
+                for idx, p in enumerate(processes):
+                    if not p.is_alive():
+                        logger.warning(f"Process {p.name} died, restarting...")
+                        new_p = spawn_worker(f"EnvProcess-Restart-{idx+1}")
+                        processes[idx] = new_p
+                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
+                    else:
+                        alive_count += 1
+                # Counts only workers that were alive before this round's restarts.
+                if alive_count == 0:
+                    logger.error("All processes died, exiting.")
+                    break
+                time.sleep(5)
+            # Wait for in-flight tasks to finish.
+            for p in processes:
+                p.join()
+        except KeyboardInterrupt:
+            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
+            for p in processes:
+                if p.is_alive():
+                    try:
+                        logger.info(f"Terminating process {p.name} due to error...")
+                        p.terminate()
+                    except Exception as term_e:
+                        logger.error(f"Error terminating process {p.name}: {term_e}")
+            raise
+        # Copy the scores while the manager (and its proxy objects) is still alive.
+        scores = list(shared_scores)
+    logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
+
+
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Drop already-finished examples from ``total_file_json`` and reset partial runs.

    Results are looked up under
    ``result_dir/action_space/observation_type/use_model/<domain>/<example_id>``.
    An example counts as finished when its directory contains ``result.txt``.
    Example directories without ``result.txt`` are emptied (files and
    subdirectories) so the example starts from a clean directory. Directories
    named ``onboard`` are ignored.

    Args:
        action_space: Path component naming the action space.
        use_model: Path component naming the model.
        observation_type: Path component naming the observation type.
        result_dir: Root directory of all results.
        total_file_json: Mapping ``domain -> list of example ids``. It is
            modified in place and also returned.

    Returns:
        ``total_file_json`` with finished example ids removed from each domain.
    """
    import shutil

    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    # Nothing has been run yet for this configuration.
    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        # Skip plain files living next to the domain directories.
        if not os.path.isdir(domain_path):
            continue

        done = set()
        for example_id in os.listdir(domain_path):
            if example_id == "onboard":
                continue
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue

            entries = os.listdir(example_path)
            if "result.txt" in entries:
                done.add(example_id)
                continue

            # Incomplete run: wipe its leftovers. os.remove() alone raises on
            # subdirectories, so real directories go through shutil.rmtree().
            for entry in entries:
                entry_path = os.path.join(example_path, entry)
                if os.path.isdir(entry_path) and not os.path.islink(entry_path):
                    shutil.rmtree(entry_path)
                else:
                    os.remove(entry_path)
        finished[domain] = done

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json
+
+
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect the scores recorded so far and print the current success rate.

    Every ``result.txt`` found under
    ``result_dir/action_space/observation_type/use_model/<domain>/<example_id>``
    is read and parsed as a float. A file that cannot be read or parsed counts
    as ``0.0``.

    Args:
        action_space: Path component naming the action space.
        use_model: Path component naming the model.
        observation_type: Path component naming the observation type.
        result_dir: Root directory of all results.
        total_file_json: Unused; kept so existing callers keep working.

    Returns:
        The list of scores, or ``None`` when the result directory does not
        exist or holds no ``result.txt`` yet.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if not os.path.isdir(domain_path):
            continue
        for example_id in os.listdir(domain_path):
            example_path = os.path.join(domain_path, example_id)
            if not os.path.isdir(example_path):
                continue
            if "result.txt" not in os.listdir(example_path):
                continue
            # Read the recorded score; the context manager closes the file
            # even when parsing fails.
            try:
                with open(os.path.join(example_path, "result.txt"), "r") as f:
                    all_result.append(float(f.read()))
            except (OSError, ValueError):
                # Unreadable or non-numeric result: count it as a failure.
                all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None

    print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
    return all_result
+
+
if __name__ == "__main__":
    # Script entry point: parse the configuration, persist it next to the
    # results, work out which examples still have to run, print the success
    # rate so far, run the remaining examples, and always clean up
    # environments and worker processes on the way out.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal

    try:
        args = config()

        # Save args to result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)

        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)

        # Restrict the run to a single domain when one was requested.
        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}

        # Drop examples that already have a result.txt.
        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = "".join(
            f"{domain}: {len(test_file_list[domain])}\n" for domain in test_file_list
        )
        logger.info(f"Left tasks:\n{left_info}")

        # Print the success rate of the examples finished so far.
        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions. signal_handler ends
        # with sys.exit(0); swallow that so a crashed run is not reported as
        # a success, then exit with a failure status instead.
        try:
            signal_handler(signal.SIGTERM, None)
        except SystemExit:
            pass
        sys.exit(1)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info("Closing environment in final cleanup...")
                    env.close()
                    logger.info("Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")

        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")

        # Wait a moment for processes to terminate
        time.sleep(1)

        # Then force kill if needed
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")