feat/dart_gui (#371)

This commit is contained in:
Pengxiang-Li
2025-11-07 21:50:01 +08:00
committed by GitHub
parent 6d43dbc532
commit 00b6468eb7
8 changed files with 2499 additions and 4 deletions

1
.gitignore vendored
View File

@@ -205,3 +205,4 @@ draft/
manual_examine.py
run_human_examine.sh
quick_start.py
result_multi_apps_pengxiang_transformers12

View File

@@ -0,0 +1,161 @@
# ---------------------------------------------------------------------------
# Prompt templates and failure-phrase constants for the Dart GUI agent.
#
# Every *_PROMPT template is str.format()-ed with:
#   {instruction} - the user task text
#   {language}    - language the model must use in its `Thought` section
# UITARS_USR_PROMPT_THOUGHT additionally takes {action_space}.
# ---------------------------------------------------------------------------

# Default prompt: full action space, no escalation to the user.
COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""

# Same as COMPUTER_USE_PROMPT plus the call_user() escape hatch.
COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""

# Standalone action-space fragments, spliced into UITARS_USR_PROMPT_THOUGHT
# via its {action_space} placeholder.
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
"""

# Action space variant that allows escalating to the user.
UITARS_CALL_USR_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
"""

# Action space variant where finished() carries a content payload.
UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""

# Prompt for the no-`Thought` output format (Action line only).
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## User Instruction
{instruction}
"""

# Prompt with a pluggable {action_space} and Thought section.
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""

# Chinese phrases that signal the model is declaring the task unsolvable.
# Used by parsing_response_to_pyautogui_code: a `finished` action whose text
# contains any of these phrases is mapped to the "FAIL" sentinel instead of
# "DONE".  Substring matching, so short entries match aggressively.
FAILURE_INDICATORS = [
    # Direct inability expressions
    "无法", "不能", "不可以", "做不到", "实现不了", "完成不了","没法",
    # Regret/apology expressions
    "遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起",
    # Not supported/available
    "不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面","不符合",#"不存在",
    # Cannot access/handle
    "无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也",
    # Not possible/feasible
    "不可能", "无法实现", "实现不了", "办不到", "做不了","找不到","存在技术限制","没有找到","没有内置",
    # System limitations
    "超出范围", "不在我的能力范围", "能力有限", "功能限制","没有成功","没成功","硬件的问题",
    # Refusal indicators
    "拒绝", "不允许", "禁止", "不合适", "不恰当",
    # Trying Restart
    "从头开始", "藏在", "浪费时间","一个更合理的思路","正确的方向","没有意义",#, "重新","重启",
]

View File

@@ -0,0 +1,202 @@
import asyncio
from typing import List, Optional, Union, Dict, Any
import json
import os
import hashlib
from pathlib import Path
from omegaconf import DictConfig
from dataclasses import dataclass, asdict
import copy
import logging
import random
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
from log_config import setup_logging
# Configure the shared logging system before creating this module's logger.
setup_logging()
logger = logging.getLogger(__name__)
class TaskLoader:
    """Loads OSWorld task definitions from a JSON task file.

    The task file maps ``task_type -> [task_id, ...]``.  When ``resume`` is
    enabled, tasks that already have a ``reward.txt`` under the storage root
    are filtered out so an interrupted run can continue where it stopped.
    """

    def __init__(self, task_cfg: DictConfig, storage_root):
        # JSON file listing task ids per task type.
        self.task_file = Path(task_cfg.task_file)
        #self.task_root = Path(task_cfg.task_root)
        # Root directory holding the per-task OSWorld config JSONs.
        self.osworld_root = Path(task_cfg.osworld_root)
        # SHA-1 of the last task file read (only used by the legacy refresh).
        self._latest_sha: Optional[str] = None
        self.storage_root = storage_root
        self.resume = task_cfg.resume
        # NOTE(review): _find_latest_json reads self.task_root, which is
        # commented out above -- calling it would raise AttributeError.

    def poll_for_tasks(self) -> List[Dict]:
        """Refresh the task list from the task file and return it shuffled.

        Returns a list of TaskInfo dicts, or [] when no tasks remain.
        """
        self._maybe_refresh_dataset()
        tasks_list = [task.to_dict() for task in self._tasks]
        # Shuffle so parallel workers do not all start on the same tasks.
        random.shuffle(tasks_list)
        return tasks_list

    def _maybe_refresh_dataset_bak(self):
        """Legacy refresh: reload only when the newest JSON file changed."""
        # Check for a new JSON file.
        latest_json = self._find_latest_json()
        if latest_json is None:
            return False  # no json file
        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False  # no change
        with open(latest_json) as f:
            data = json.load(f)
        # Flatten {task_type: [task_id, ...]} into one entry per task.
        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]
        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha
        # (Logs: "current task file" / "total number of tasks".)
        logger.info(f"当前任务文件: {str(latest_json)}")
        logger.info(f"任务总数: {len(raw_tasks)}")
        return True

    def _maybe_refresh_dataset(self):
        """Reload tasks from self.task_file, optionally skipping finished ones."""
        latest_json = self.task_file
        print("Current tasks file: ", str(latest_json))
        with open(latest_json) as f:
            data = json.load(f)
        # Flatten {task_type: [task_id, ...]} into one entry per task.
        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]
        if self.resume:
            # Filter out tasks that are already finished, or whose stored
            # config belongs to a different task type.
            filtered_tasks = []
            storage_root = Path(self.storage_root)
            for raw in raw_tasks:
                task_id = str(raw["task_id"])
                task_type_expected = raw["task_type"]
                # Result subdirectories whose name starts with this task id
                # (multiple versions/attempts are allowed).
                candidate_dirs = [
                    d for d in storage_root.iterdir()
                    if d.is_dir() and d.name.startswith(task_id)
                ]
                # Assume unfinished until a matching reward.txt is found.
                task_finished = False
                for d in candidate_dirs:
                    cfg_path = d / "task_config.json"
                    if not cfg_path.exists():
                        print("找不到config文件")
                        continue
                    try:
                        with cfg_path.open("r", encoding="utf-8") as cf:
                            cfg = json.load(cf)
                    except Exception:
                        print("配置损坏,忽略此目录")
                        continue
                    # Different task_type => not the same task; skip this dir.
                    if cfg.get("raw", {}).get("task_type") != task_type_expected:
                        continue
                    # Same task_type: a reward.txt marks the task finished.
                    if (d / "reward.txt").exists():
                        task_finished = True
                        break  # finished record found, stop scanning dirs
                if not task_finished:
                    filtered_tasks.append(raw)
            self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}, Remained:{len(filtered_tasks)}")
        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}")
        return True

    def _find_latest_json(self) -> Optional[Path]:
        # Most recently modified *.json under task_root (None when empty).
        # NOTE(review): self.task_root is never set in __init__ -- see note there.
        files = list(self.task_root.glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None

    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2<<20) -> str:
        """Return the SHA-1 hex digest of a file, read in 2 MiB chunks."""
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()
@dataclass
class TaskInfo:
    """A fully prepared task: prompt messages plus the raw OSWorld config."""

    # Chat messages seeded with the system prompt and formatted instruction.
    messages: List
    # The (possibly plan-augmented) natural-language instruction.
    instruction: str
    # The task's OSWorld config JSON, with a "raw" id section added.
    task_config: Dict

    def to_dict(self):
        """Return a plain-dict representation (dataclasses.asdict)."""
        return asdict(self)
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
    """Load one task definition from disk and wrap it as a TaskInfo.

    The task JSON lives at ``<osworld_root>/<task_type>/<task_id>.json``.
    When the task ships a human-authored "single-action" plan, the plan
    lines are appended to the instruction as a hint for the model.

    Args:
        raw: dict with "task_type" and "task_id" keys.
        osworld_root: root directory of the OSWorld task configs.
        use_call_user: when True, use the prompt variant that permits the
            call_user() escalation action.
    """
    task_type = raw["task_type"]
    task_id = raw["task_id"]
    config_path = Path(osworld_root) / task_type / (task_id + ".json")
    with open(config_path) as fh:
        task_data = json.load(fh)
    # Keep the raw identifiers around so downstream code (e.g. resume
    # filtering) can match this task back to its source entry.
    task_data["raw"] = {"task_type": task_type, "task_id": task_id}

    instruction = task_data["instruction"]
    ground_truth = task_data.get("human-ground-truth", {})
    if "single-action" in ground_truth:
        hint = "\n".join(ground_truth["single-action"])
        instruction = (
            instruction.strip()
            + "\nHere is an instruction to help you complete the task: \n"
            + hint
        )

    template = COMPUTER_USE_PROMPT_WITH_CALL_USER if use_call_user else COMPUTER_USE_PROMPT
    user_text = template.format(instruction=instruction, language="English")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": [{"type": "text", "text": user_text}]},
    ]
    return TaskInfo(messages=messages, instruction=instruction, task_config=task_data)

511
mm_agents/dart_gui/utils.py Normal file
View File

@@ -0,0 +1,511 @@
import ast
import base64
import logging
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import numpy as np
import openai
from openai import OpenAI
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.dart_gui.prompts import FAILURE_INDICATORS
# Module-level logger (configured by the application's logging setup).
logger = logging.getLogger(__name__)

# Sentinel action names recognized by the agent loop.
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

# Image-resizing constraints (Qwen-VL style): dimensions are snapped to
# multiples of IMAGE_FACTOR and the pixel count kept in [MIN_PIXELS, MAX_PIXELS].
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200  # smart_resize rejects aspect ratios beyond this

# Observation settings that carry only text (no screenshot).
pure_text_settings = ["a11y_tree"]

# Accessibility-tree XML namespaces per platform.
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
# Parse a single action call expression such as "click(start_box='(1,2)')".
def parse_action(action_str):
    """Parse one action string into ``{'function': name, 'args': kwargs}``.

    The string must be a single Python-style call expression.  Keyword
    argument values that are not plain constants are recorded as None.
    Returns None (after logging the error) when parsing fails.
    """
    try:
        tree = ast.parse(action_str, mode='eval')
        if not isinstance(tree, ast.Expression):
            raise ValueError("Not an expression")
        call = tree.body
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # The call target may be a bare name ("click") or an attribute
        # ("obj.click"); anything else yields None.
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        kwargs = {}
        for keyword in call.keywords:
            node = keyword.value
            # Only constant values are supported here.
            if isinstance(node, ast.Constant):
                kwargs[keyword.arg] = node.value
            elif isinstance(node, ast.Str):  # pre-3.8 Python compatibility
                kwargs[keyword.arg] = node.s
            else:
                kwargs[keyword.arg] = None

        return {
            'function': func_name,
            'args': kwargs
        }
    except Exception as exc:
        logger.error(f"Failed to parse action '{action_str}': {exc}")
        return None
def escape_single_quotes(text):
    """Backslash-escape every single quote in *text* not already escaped."""
    # Negative lookbehind: skip quotes that are already preceded by "\".
    unescaped_quote = r"(?<!\\)'"
    return re.compile(unescaped_quote).sub(r"\\'", text)
def round_by_factor(number: int, factor: int) -> int:
    """Round *number* to the nearest multiple of *factor*."""
    multiples = round(number / factor)
    return multiples * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Round *number* up to the nearest multiple of *factor*."""
    multiples = math.ceil(number / factor)
    return multiples * factor
def floor_by_factor(number: int, factor: int) -> int:
    """Round *number* down to the nearest multiple of *factor*."""
    multiples = math.floor(number / factor)
    return multiples * factor
def linear_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Linearly rescale (height, width) into the [min_pixels, max_pixels] budget.

    The scale factor is the square root of the pixel ratio, so the aspect
    ratio is preserved and relative coordinates stay valid without any
    conversion.  Unlike smart_resize, the result is NOT snapped to a
    multiple of `factor` (the `factor` parameter is unused here).
    """
    if width * height > max_pixels:
        # Too large: shrink so the pixel count drops to at most max_pixels.
        shrink = math.sqrt(max_pixels / (width * height))
        width, height = int(width * shrink), int(height * shrink)
    if width * height < min_pixels:
        # Too small (possibly after truncation above): grow to min_pixels.
        grow = math.sqrt(min_pixels / (width * height))
        width, height = math.ceil(width * grow), math.ceil(height * grow)
    return height, width
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Rescale (height, width) so that:

    1. both dimensions are divisible by `factor`;
    2. the total pixel count lies within [min_pixels, max_pixels];
    3. the aspect ratio is kept as close to the original as possible.

    Raises:
        ValueError: when the aspect ratio exceeds MAX_RATIO.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    # Snap each side to the nearest multiple of `factor` (at least one factor).
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    pixel_count = h_bar * w_bar
    if pixel_count > max_pixels:
        # Too large: shrink, rounding down so the budget is respected.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif pixel_count < min_pixels:
        # Too small: grow, rounding up so the minimum budget is reached.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
    """Parse a raw model response into a list of structured action dicts.

    Each returned dict has keys: "reflection", "thought", "action_type",
    "action_inputs" (box params normalized to relative [0, 1] coordinates,
    stored as stringified 4-element lists), and "text" (the full response).

    Args:
        text: full model response ("Thought: ... Action: ...").
        factor: coordinate divisor for non-qwen25vl models.
        origin_resized_height / origin_resized_width: screenshot size the
            model saw; used to normalize qwen25vl absolute coordinates.
        model_type: "qwen25vl" emits absolute pixel coordinates; other
            models emit coordinates pre-scaled by `factor`.
    """
    text = text.strip()
    if model_type == "qwen25vl":
        # qwen25vl reports absolute coordinates in the smart-resized frame,
        # so recover that frame's dimensions to normalize below.
        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
    # Pick the regex matching whichever reasoning prefix is present.
    if text.startswith("Thought:"):
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "  # NOTE(review): thought_hint is assigned but never used
    elif text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Reflection: "
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Action_Summary: "
    else:
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            # Reflection pattern: group 1 = reflection, group 2 = summary.
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text  # a response without an Action section is invalid
    action_str = text.split("Action:")[-1]
    # Multiple actions are separated by blank lines.
    tmp_all_action = action_str.split("\n\n")
    all_action = []
    for action_str in tmp_all_action:
        if "type(content" in action_str:
            # Pull the raw content out of type(content='...') and re-escape
            # any unescaped single quotes so it parses as a Python literal.
            def escape_quotes(match):
                content = match.group(1)  # the content value
                return content
            # Apply the substitution.
            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
            content = re.sub(pattern, escape_quotes, action_str)
            # Re-escape and rebuild the call string.
            action_str = escape_single_quotes(content)
            action_str = "type(content='" + action_str + "')"
        if "finished(content" in action_str:
            # Same re-escaping treatment for finished(content='...').
            def escape_quotes(match):
                content = match.group(1)  # the content value
                return content
            # Apply the substitution.
            pattern = r"finished\(content='(.*?)'\)"  # matches finished(content='...')
            content = re.sub(pattern, escape_quotes, action_str)
            # Re-escape and rebuild the call string.
            action_str = escape_single_quotes(content)
            action_str = "finished(content='" + action_str + "')"
        all_action.append(action_str)
    # Newlines inside an action must become literal "\n" for ast parsing.
    parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance == None:
            # Unparseable actions are logged and skipped rather than fatal.
            logger.error(f"Action can't parse: {raw_str}")
            # raise ValueError(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]
        action_inputs = {}
        for param_name, param in params.items():
            if param == "": continue
            param = param.lstrip()  # drop leading whitespace
            action_inputs[param_name.strip()] = param
            if "start_box" in param_name or "end_box" in param_name:
                # Boxes arrive as "(x1,y1)" or "(x1,y1,x2,y2)" strings;
                # convert them to relative [0, 1] coordinates.
                ori_box = param
                # Remove parentheses and split the string by commas
                numbers = ori_box.replace("(", "").replace(")", "").split(",")
                # Convert to float and scale by 1000
                # Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
                if model_type == "qwen25vl":
                    float_numbers = []
                    # Even indices are x (divide by width), odd are y (height).
                    for num_idx, num in enumerate(numbers):
                        num = float(num)
                        if (num_idx + 1) % 2 == 0:
                            float_numbers.append(float(num/smart_resize_height))
                        else:
                            float_numbers.append(float(num/smart_resize_width))
                else:
                    float_numbers = [float(num) / factor for num in numbers]
                if len(float_numbers) == 2:
                    # A point: duplicate into a degenerate box.
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)
        actions.append(
            {
                "reflection": reflection,
                "thought": thought,
                "action_type": action_type,
                "action_inputs": action_inputs,
                "text": text
            })
    return actions
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> str:
    """Translate parsed model action(s) into an executable pyautogui script.

    Args:
        responses: one parsed action dict, or a list of them.  Each dict
            looks like::

                {
                    "action_type": "hotkey",
                    "action_inputs": {
                        "hotkey": "v ctrl",
                        "start_box": None,
                        "end_box": None
                    }
                }

        image_height / image_width: screen size used to denormalize the
            relative [0, 1] box coordinates into pixels.
        input_swap: when True, `type` actions paste via the clipboard
            (pyperclip + ctrl-v) instead of pyautogui.write.

    Returns:
        A pyautogui code string, or one of the sentinel strings
        "DONE" / "FAIL" / "WAIT" recognized by the environment loop.
    """
    # Maps model key aliases to pyautogui key names.
    _KEY_ALIASES = {
        "arrowleft": "left",
        "arrowright": "right",
        "arrowup": "up",
        "arrowdown": "down",
    }
    pyautogui_code = "import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        observation = response.get("observation", "")
        thought = response.get("thought", "")
        if response_id == 0:
            # Prepend the reasoning as a comment block before the first action.
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            # Small pause between consecutive actions of one response.
            pyautogui_code += "\ntime.sleep(1)\n"

        action_dict = response
        response_text = action_dict.get("text", "")
        action_type = action_dict.get("action_type")
        action_inputs = action_dict.get("action_inputs", {})

        if action_type == "hotkey":
            # The key combo may arrive under "key" or "hotkey".
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")
            hotkey = _KEY_ALIASES.get(hotkey, hotkey)
            if hotkey:
                keys = hotkey.split()  # combos are space-separated
                convert_keys = []
                for key in keys:
                    if key == "space":
                        key = ' '
                    convert_keys.append(key)
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"

        elif action_type == "press":
            # The key may arrive under "key" or "press".
            if "key" in action_inputs:
                key_to_press = action_inputs.get("key", "")
            else:
                key_to_press = action_inputs.get("press", "")
            # BUG FIX: this branch previously normalized an unrelated `hotkey`
            # variable (NameError when `press` was the first action, and arrow
            # keys were never translated); normalize the pressed key itself.
            key_to_press = _KEY_ALIASES.get(key_to_press, key_to_press)
            if key_to_press == "space":
                key_to_press = " "
            if key_to_press:
                pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"

        elif action_type == "keyup":
            key_to_up = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"

        elif action_type == "keydown":
            key_to_down = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"

        elif action_type == "type":
            # Typing via clipboard paste (input_swap) or direct write.
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            stripped_content = content
            if content.endswith("\n") or content.endswith("\\n"):
                # A trailing newline means "submit"; emitted separately as Enter.
                stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
            if content:
                if input_swap:
                    # Clipboard paste is more robust for long/unicode text.
                    pyautogui_code += "\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
                    pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"

        elif action_type in ["drag", "select"]:
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                # Boxes are the stringified [x1, y1, x2, y2] relative coords
                # produced by parse_action_to_structure_output;
                # ast.literal_eval is safer than eval for these literals.
                x1, y1, x2, y2 = ast.literal_eval(start_box)
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = ast.literal_eval(end_box)
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )

        elif action_type == "scroll":
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = ast.literal_eval(start_box)
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
            else:
                # No target given: scroll at the current cursor position.
                x = None
                y = None
            direction = action_inputs.get("direction", "")
            # Note: left/right directions are currently ignored.
            if x is None:
                if "up" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = ast.literal_eval(start_box)
                if start_box is None:
                    # NOTE(review): a None box still falls through to len()
                    # below and raises; preserved from the original flow.
                    logger.warning(f"[Warning] start_box is None and wired condition:\n{action_inputs}")
                if len(start_box) == 4:
                    x1, y1, x2, y2 = start_box  # box is [x1, y1, x2, y2]
                elif len(start_box) == 2:
                    # A point: treat it as a degenerate box.
                    x1, y1 = start_box
                    x2 = x1
                    y2 = y1
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

        elif action_type in ["finished"]:
            pyautogui_code = "DONE"
            print(f"FINISHED:response_text: {response_text}")
            print(f"FINISHED:response: {str(response)}")
            # A "finished" message containing a known failure phrase is the
            # model declaring the task unsolvable.
            for failure_indicator in FAILURE_INDICATORS:
                if failure_indicator in response_text:
                    pyautogui_code = "FAIL"
                    break

        elif action_type in ["wait"]:
            pyautogui_code = "WAIT"

        elif action_type in ["call_user"]:
            pyautogui_code = "FAIL"

        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code
def add_box_token(input_string):
    """Wrap bare coordinate tuples in <|box_start|>/<|box_end|> tokens.

    Only strings containing both an "Action: " section and a start_box=
    argument are rewritten; everything else passes through unchanged.  The
    result is returned in message-content form.
    """
    if "Action: " in input_string and "start_box=" in input_string:
        # Everything before (and including) the first "Action: " marker.
        prefix = input_string.split("Action: ")[0] + "Action: "
        raw_actions = input_string.split("Action: ")[1:]
        rewritten = []
        for raw in raw_actions:
            action = raw.strip()
            # Every start_box/end_box coordinate pair in this action.
            # Note: the replacement below only matches "(x,y)" without a
            # space after the comma, mirroring the findall capture.
            for coord_type, x, y in re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action):
                action = action.replace(
                    f"{coord_type}='({x},{y})'",
                    f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'",
                )
            rewritten.append(action)
        final_string = prefix + "\n\n".join(rewritten)
    else:
        final_string = input_string
    return [{"type": "text", "text": final_string}]
def pil_to_base64(image):
    """Encode a PIL image (or raw image bytes) as a base64 string."""
    if not isinstance(image, bytes):
        # Serialize the PIL image to PNG bytes first.
        buffer = BytesIO()
        image.save(buffer, format="PNG")
        image = buffer.getvalue()
    return base64.b64encode(image).decode("utf-8")

686
mm_agents/dart_gui_agent.py Normal file
View File

@@ -0,0 +1,686 @@
"""
Dart Agent - Custom agent for GUI automation using Dart models
Based on UITARSAgent structure but using Dart-specific utilities and prompts
"""
import ast
import base64
import logging
import math
import os
import re
import time
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
from openai import OpenAI
import backoff
import openai
import requests
from requests.exceptions import SSLError
from google.api_core.exceptions import (
BadRequest,
InternalServerError,
InvalidArgument,
ResourceExhausted,
)
# Import Dart-specific utilities and prompts
from mm_agents.dart_gui.utils import (
pil_to_base64,
parse_action_to_structure_output,
parsing_response_to_pyautogui_code,
parse_action,
escape_single_quotes,
round_by_factor,
ceil_by_factor,
floor_by_factor,
linear_resize,
smart_resize,
add_box_token,
IMAGE_FACTOR,
MIN_PIXELS,
MAX_PIXELS,
MAX_RATIO,
FINISH_WORD,
WAIT_WORD,
ENV_FAIL_WORD,
CALL_USER
)
from mm_agents.dart_gui.prompts import (
COMPUTER_USE_PROMPT,
COMPUTER_USE_PROMPT_WITH_CALL_USER,
UITARS_ACTION_SPACE,
UITARS_CALL_USR_ACTION_SPACE,
UITARS_USR_PROMPT_THOUGHT,
UITARS_USR_PROMPT_NOTHOUGHT
)
logger = logging.getLogger("desktopenv.agent")
class DartAgent:
def __init__(
    self,
    model: str,
    runtime_conf: Dict,
    platform="ubuntu",
    max_tokens=1000,
    top_p=0.9,
    top_k=1.0,
    temperature=0.0,
    action_space="pyautogui",
    observation_type="screenshot",
    max_trajectory_length=50,
    model_type="qwen25vl",
    **kwargs
):
    """Configure the agent, its sampling parameters, and the model client.

    Args:
        model: model name sent to the inference endpoint.
        runtime_conf: runtime overrides (sampling params, API endpoint/key,
            history sizes, infer/prompt modes); values here take precedence
            over the keyword defaults.
        platform: target OS of the controlled desktop.
        max_tokens / top_p / top_k / temperature: default sampling params.
        action_space: action formalism used by the environment.
        observation_type: kind of observation consumed ("screenshot").
        max_trajectory_length: cap on trajectory steps.
        model_type: coordinate convention of the model ("qwen25vl" means
            absolute coordinates needing smart-resize normalization).
    """
    self.model = model
    self.platform = platform
    self.action_space = action_space
    self.observation_type = observation_type
    self.max_trajectory_length = max_trajectory_length
    self.model_type = model_type
    self.runtime_conf = runtime_conf
    # Extract runtime configuration parameters (runtime_conf wins over args).
    self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
    self.top_p = self.runtime_conf.get("top_p", top_p)
    self.top_k = self.runtime_conf.get("top_k", top_k)
    self.temperature = self.runtime_conf.get("temperature", temperature)
    self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
    self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
    self.input_swap = self.runtime_conf.get("input_swap", False)
    self.language = self.runtime_conf.get("language", "English")
    self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
    self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
    self.history_n = self.runtime_conf.get("history_n", 5)
    # Dart specific configurations: caps on how much history is kept.
    self.max_images = self.runtime_conf.get("max_images", 5)
    self.max_texts = self.runtime_conf.get("max_texts", 35)
    # Initialize the model client.  Exactly one of self.vlm (OpenAI-style
    # client) and self.dart_direct_url (raw generation endpoint) ends up
    # non-None; both are None when no endpoint is configured at all.
    dart_api_key = self.runtime_conf.get("dart_api_key", "")
    dart_base_url = self.runtime_conf.get("dart_base_url", "")
    if dart_base_url:
        # A URL containing /generate is treated as a direct generation
        # endpoint rather than an OpenAI-compatible server.
        if '/generate' in dart_base_url:
            # Use the provided URL as-is; do not append /v1.
            logger.info(f"使用直接生成端点: {dart_base_url}")
            self.dart_direct_url = dart_base_url
            self.vlm = None  # no OpenAI client in direct mode
        else:
            # Conventional OpenAI-compatible endpoint: ensure it ends with /v1.
            if not dart_base_url.endswith('/v1'):
                dart_base_url = dart_base_url.rstrip('/') + '/v1'
            self.vlm = OpenAI(
                base_url=dart_base_url,
                api_key=dart_api_key,
            )
            self.dart_direct_url = None
    else:
        # Fallback to environment variables (DART_* first, then DOUBAO_*).
        base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
        if base_url:
            if '/generate' in base_url:
                # Direct generation endpoint.
                self.dart_direct_url = base_url
                self.vlm = None
            else:
                if not base_url.endswith('/v1'):
                    base_url = base_url.rstrip('/') + '/v1'
                self.vlm = OpenAI(
                    base_url=base_url,
                    api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
                )
                self.dart_direct_url = None
        else:
            self.vlm = None
            self.dart_direct_url = None
    # Initialize trajectory storage - similar to trajectory_runner.py
    self.thoughts = []
    self.actions = []
    self.observations = []
    self.history_images = []
    self.history_responses = []
    # Message handling similar to trajectory_runner.py
    self.base_messages = []  # for model client (with base64 images)
    self.base_messages_for_save = []  # for storage (with file paths)
    self.prompt_dialogue = []  # for model client
    self.save_dialogue = []  # for storage
    self.save_dialogue_full = []  # for full storage (keeps every image path)
    self.image_refs = []  # record image position
    # All image paths storage - to keep track of all images even when trimmed
    self.all_image_paths = []
    # Current screenshot file path for proper saving
    self.current_screenshot_path = None
    # Configure prompt and action space based on mode.
    if self.infer_mode == "dart_mode":
        self.prompt_action_space = UITARS_ACTION_SPACE
        self.prompt_template = COMPUTER_USE_PROMPT
    else:
        # For qwen2vl_user mode
        self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
        if self.prompt_style == "qwen2vl_user":
            self.prompt_template = UITARS_USR_PROMPT_THOUGHT
        elif self.prompt_style == "qwen2vl_no_thought":
            self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
        else:
            self.prompt_template = UITARS_USR_PROMPT_THOUGHT
    # Divisor converting model box coordinates to relative [0, 1] values
    # for non-qwen25vl models.
    self.action_parse_res_factor = 1000
    logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
def reset(self, runtime_logger=None):
    """Reset all per-episode state so the agent can start a fresh task.

    The runtime_logger parameter is accepted for interface compatibility
    but is not used here.
    """
    # Every list-valued piece of episode state starts empty again.
    list_attrs = (
        "thoughts", "actions", "observations",
        "history_images", "history_responses",
        "base_messages", "base_messages_for_save",
        "prompt_dialogue", "save_dialogue", "save_dialogue_full",
        "image_refs", "all_image_paths",
    )
    for attr in list_attrs:
        setattr(self, attr, [])
    self.current_screenshot_path = None
    logger.info("DartAgent reset")
def set_base_messages(self, instruction: str):
    """Build the initial system + user messages for a new task.

    NOTE(review): this always formats COMPUTER_USE_PROMPT rather than the
    self.prompt_template selected during construction — confirm intended.
    """
    from copy import deepcopy

    task_prompt = COMPUTER_USE_PROMPT.format(
        instruction=instruction,
        language=self.language,
    )
    self.base_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": task_prompt}],
        },
    ]
    # Independent copy: the saved variant will later reference image file
    # paths instead of base64 payloads.
    self.base_messages_for_save = deepcopy(self.base_messages)
def set_current_screenshot_path(self, screenshot_path: str) -> None:
    """Record the on-disk path of the latest screenshot.

    The stored path is later consumed by _set_first_frame/_add_image so
    the saved (non-base64) dialogue references a real file.
    """
    self.current_screenshot_path = screenshot_path
def predict(
    self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
) -> tuple:
    """
    Predict the next action(s) based on the current observation.

    Args:
        instruction: Natural-language task description (only used the first
            time, to build the base messages).
        obs: Observation dict; must contain a "screenshot" entry.
        last_action_after_obs: Unused here; kept for interface parity.

    Returns: (response_text, actions_list)
    """
    # Initialize base messages if not set
    if not self.base_messages:
        self.set_base_messages(instruction)
    # Store current observation
    self._add_observation(obs)
    # For first step, set the first frame
    if len(self.observations) == 1:
        self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
    else:
        # For subsequent steps, add the new image to dialogue
        # This represents the result of the previous action
        self._add_image(obs["screenshot"], self.current_screenshot_path)
    # Build prompt messages (base_messages + prompt_dialogue)
    messages = self._build_messages()
    # Call model to get response
    prediction = self._call_model(messages)
    if prediction is None:
        # Client failed after all retries; tell the runner to stop.
        return "client error", ["DONE"]
    # Store response and parse actions
    self._add_text(prediction)
    # Parse response to actions
    try:
        image_size = self._get_current_image_size()
        actions = self._parse_and_convert_actions(prediction, image_size)
        # Check for terminal actions
        # NOTE(review): `actions` holds pyautogui code strings, while
        # _check_terminal_actions only matches dict entries — confirm the
        # terminal check can ever trigger here.
        terminal_action = self._check_terminal_actions(actions)
        if terminal_action:
            self.actions.append(actions)
            return prediction, [terminal_action]
    except Exception as e:
        logger.error(f"Parsing action error: {prediction}, error: {e}")
        return f"Parsing action error: {prediction}, error: {e}", ["DONE"]
    self.actions.append(actions)
    # Check max steps
    if len(self.history_responses) >= self.max_trajectory_length:
        # Trajectory budget exhausted: report failure instead of actions.
        actions = ["FAIL"]
    return prediction, actions
@backoff.on_exception(
    backoff.constant,
    (
        # General exceptions
        SSLError,
        # OpenAI exceptions
        openai.RateLimitError,
        openai.BadRequestError,
        openai.InternalServerError,
        # Google exceptions
        # NOTE(review): InternalServerError appears both namespaced
        # (openai.) and bare (Google import) — confirm both names resolve.
        InvalidArgument,
        ResourceExhausted,
        InternalServerError,
        BadRequest,
    ),
    # Fixed 30 s interval between retries, at most 10 attempts in total.
    interval=30,
    max_tries=10,
)
def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
    """Predict with backoff for rate limiting and temporary errors.

    Thin retrying wrapper around predict(); same signature and return.
    """
    return self.predict(instruction, obs, last_action_after_obs)
def get_trajectory(self) -> List[Dict]:
    """Assemble one record per observation (observation/thought/action).

    Thoughts and actions fall back to "" / [] when shorter than the
    observation list.
    """
    steps = []
    for idx, observation in enumerate(self.observations):
        steps.append({
            "observation": observation,
            "thought": self.thoughts[idx] if idx < len(self.thoughts) else "",
            "action": self.actions[idx] if idx < len(self.actions) else [],
        })
    return steps
def get_full_messages(self) -> List[Dict]:
    """Return the complete, untrimmed conversation for saving.

    base_messages_for_save holds the system/initial-user messages with
    file-path image references; save_dialogue_full is never trimmed.
    """
    return [*self.base_messages_for_save, *self.save_dialogue_full]
def get_all_image_paths(self) -> List[str]:
    """Return a defensive copy of every image path used in the conversation."""
    return list(self.all_image_paths)
# ========== Private Methods ==========
def _validate_trajectory(self):
"""Validate trajectory consistency"""
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
def _add_observation(self, obs: Dict):
"""Process observation and add to history"""
# Store observation
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = obs["screenshot"]
try:
# Handle accessibility tree if needed
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
# For now, we'll skip accessibility tree processing in Dart mode
linearized_accessibility_tree = None
except:
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree,
})
else:
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": None
})
else:
raise ValueError("Invalid observation_type type: " + self.observation_type)
def _build_messages(self) -> List[Dict]:
"""Build messages for model API call - similar to trajectory_runner._build_messages"""
return self.base_messages + self.prompt_dialogue
def _call_model(self, messages: List[Dict]) -> str:
    """Query the model, retrying up to 3 times; returns None on total failure."""
    for _attempt in range(3):
        try:
            if hasattr(self, 'dart_direct_url') and self.dart_direct_url:
                # A raw /generate endpoint was configured.
                prediction = self._call_direct_generate_endpoint(messages)
            else:
                # Standard OpenAI-compatible chat completions client.
                response = self.vlm.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    frequency_penalty=1,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
                prediction = response.choices[0].message.content
            logger.info(f"Model response: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Error when fetching response from client: {e}")
    logger.error("Reach max retry times to fetch response from client")
    return None
def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
    """Call the raw generation endpoint directly (non-OpenAI-client path).

    Posts the chat payload to self.dart_direct_url with up to 3 retries,
    then extracts the text from whichever response shape the server uses.

    Raises: whatever requests/json raised, after logging.
    """
    try:
        # Build the request payload
        payload = {
            "messages": messages,
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": 1
        }
        # Attach the API key to the request headers
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
        }
        # Retry mechanism: at most 3 attempts, 60 s timeout per request
        max_retries = 3
        response = None
        for attempt in range(max_retries):
            try:
                logger.info(f"尝试第 {attempt + 1} 次请求...")
                response = requests.post(
                    self.dart_direct_url,
                    json=payload,
                    headers=headers,
                    timeout=60
                )
                response.raise_for_status()
                break  # success: leave the retry loop
            except Exception as e:
                logger.warning(f"第 {attempt + 1} 次请求失败: {e}")
                if attempt == max_retries - 1:  # last retry also failed
                    logger.error(f"所有 {max_retries} 次重试都失败了")
                    raise e
                else:
                    logger.info(f"等待后重试...")
                    import time
                    time.sleep(2)  # wait 2 s before retrying
        # Parse the response, trying several possible formats
        result = response.json()
        if 'choices' in result and len(result['choices']) > 0:
            # OpenAI-compatible format
            return result['choices'][0]['message']['content']
        elif 'response' in result:
            # Plain "response" field
            return result['response']
        elif 'text' in result:
            # "text" field
            return result['text']
        elif 'content' in result:
            # "content" field
            return result['content']
        else:
            # No recognised field: fall back to the stringified payload
            logger.warning(f"未知的响应格式: {result}")
            return str(result)
    except Exception as e:
        logger.error(f"直接端点调用失败: {e}")
        raise e
def _add_text(self, assistant_txt: str):
    """Record an assistant response in history and every dialogue variant."""
    self.history_responses.append(assistant_txt)
    self.thoughts.append(assistant_txt)
    message = {
        "role": "assistant",
        "content": add_box_token(assistant_txt)
    }
    # The same message object goes to all three dialogues; only
    # prompt_dialogue/save_dialogue are subject to trimming later.
    for dialogue in (self.prompt_dialogue, self.save_dialogue, self.save_dialogue_full):
        dialogue.append(message)
    self._trim()
def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
    """Attach the first screenshot to the initial user message.

    The model-facing copy gets the base64 payload; the save copy gets a
    file path (explicit frame_path, else the tracked current path, else a
    placeholder name).
    """
    content = self.base_messages[1]["content"]
    content.append({
        "type": "image_url",
        "image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
    })
    first_frame_path = frame_path or self.current_screenshot_path or "first_frame.png"
    # Keep the complete record of image paths.
    self.all_image_paths.append(first_frame_path)
    self.base_messages_for_save[1]["content"].append({
        "type": "image_url",
        "image_url": first_frame_path
    })
    # Remember where this image lives so _trim can evict it later.
    self.image_refs.append({
        "source": "base",
        "msg_idx": 1,
        "content_idx": len(content) - 1
    })
def _add_image(self, img_bytes: bytes, frame_path: str = None):
    """Append a new screenshot to the model dialogue and both save dialogues."""
    self.prompt_dialogue.append({
        "role": "user",
        "content": [{
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64," + pil_to_base64(img_bytes)}
        }]
    })
    # Resolve the on-disk reference used by the saved dialogues.
    if frame_path:
        image_url = frame_path
    elif self.current_screenshot_path:
        image_url = self.current_screenshot_path
    else:
        # Placeholder fallback; should rarely happen in practice.
        image_url = f"frame_{len(self.save_dialogue)}.png"
    # Keep the complete record of image paths.
    self.all_image_paths.append(image_url)

    def _entry():
        # Fresh dict per dialogue so trimming one never mutates the other.
        return {
            "role": "user",
            "content": [{"type": "image_url", "image_url": image_url}]
        }

    # save_dialogue may be trimmed later; save_dialogue_full never is.
    self.save_dialogue.append(_entry())
    self.save_dialogue_full.append(_entry())
    self.image_refs.append({
        "source": "dialogue",
        "msg_idx": len(self.prompt_dialogue) - 1,
        "content_idx": None
    })
    self._trim()
def _trim(self):
    """Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim

    Evicts oldest images first, then oldest assistant texts, until both
    budgets are satisfied. save_dialogue_full is never touched.
    """
    img_cnt = len(self.image_refs)
    txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)
    while img_cnt > self.max_images or txt_cnt > self.max_texts:
        # Too many images: drop the earliest one first.
        if img_cnt > self.max_images:
            ref = self.image_refs.pop(0)
            if ref["source"] == "base":
                # Image embedded in the initial user message.
                # NOTE(review): later base refs' content_idx is not
                # re-indexed after this pop; safe today because only one
                # base image is ever added (_set_first_frame) — confirm.
                self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
            else:  # image carried by a dialogue message
                self._remove_dialogue_msg(ref["msg_idx"])
            img_cnt -= 1
            continue
        # Too many assistant texts: drop the earliest assistant message.
        if txt_cnt > self.max_texts:
            for i, m in enumerate(self.prompt_dialogue):
                if m["role"] == "assistant":
                    self._remove_dialogue_msg(i)
                    txt_cnt -= 1
                    break
def _remove_dialogue_msg(self, idx: int):
"""Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg"""
self.prompt_dialogue.pop(idx)
self.save_dialogue.pop(idx)
# Note: save_dialogue_full is never trimmed, so we don't remove from it
# 更新 image_refs
self.image_refs = [
r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
else None # 同一条被删掉的图引用直接丢弃
for r in self.image_refs
]
self.image_refs = [
(
{**r, "msg_idx": r["msg_idx"] - 1}
if r and r["source"] == "dialogue" and r["msg_idx"] > idx # idx后的图片索引均-1
else r
)
for r in self.image_refs
if r # 剔除 None
]
def _get_current_image_size(self) -> tuple:
"""Get current image size for coordinate conversion"""
if len(self.observations) > 0:
try:
current_image_bytes = self.observations[-1]["screenshot"]
if isinstance(current_image_bytes, bytes):
current_image = Image.open(BytesIO(current_image_bytes))
return (current_image.height, current_image.width)
except Exception as e:
logger.warning(f"Error getting image size: {e}")
# Fallback to default screen size
return (1080, 1920)
def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
    """Parse a raw model response into executable pyautogui code strings.

    image_size is (height, width). A conversion failure for an individual
    action yields "FAIL" in its slot so the list stays aligned.
    """
    height, width = image_size
    # First pass: model text -> structured action dicts.
    structured = parse_action_to_structure_output(
        prediction,
        factor=self.action_parse_res_factor,
        origin_resized_height=height,
        origin_resized_width=width,
        model_type=self.model_type,
        max_pixels=self.max_pixels,
        min_pixels=self.min_pixels
    )
    # Second pass: structured actions -> pyautogui snippets.
    pyautogui_actions = []
    for item in structured:
        try:
            pyautogui_actions.append(
                parsing_response_to_pyautogui_code(
                    item,
                    image_height=height,
                    image_width=width,
                    input_swap=self.input_swap
                )
            )
        except Exception as e:
            logger.error(f"Error generating pyautogui code: {e}")
            pyautogui_actions.append("FAIL")
    return pyautogui_actions
def _check_terminal_actions(self, actions: List[str]) -> str:
    """Map a terminal parsed action to the runner codes DONE/WAIT/FAIL.

    Returns None when no terminal action is present.

    NOTE(review): predict() passes in pyautogui code *strings* produced by
    _parse_and_convert_actions, but this only inspects dict entries with
    an "action_type" key — confirm whether terminal detection should run
    on the structured parse output instead.
    """
    for action in actions:
        if isinstance(action, dict) and "action_type" in action:
            action_type = action["action_type"]
            if action_type == FINISH_WORD:
                return "DONE"
            elif action_type == WAIT_WORD:
                return "WAIT"
            elif action_type == ENV_FAIL_WORD:
                return "FAIL"
            elif action_type == CALL_USER:
                # A hand-off to the user is treated as failure here.
                return "FAIL"
    return None

View File

@@ -2,13 +2,13 @@
# Do not write any secret keys or sensitive information here.
# Monitor configuration
TASK_CONFIG_PATH=../evaluation_examples/test_all.json
TASK_CONFIG_PATH=../evaluation_examples/test_nogdrive.json
EXAMPLES_BASE_PATH=../evaluation_examples/examples
RESULTS_BASE_PATH=../results
RESULTS_BASE_PATH=../result_multi_apps_pengxiang_transformers12
# ACTION_SPACE=pyautogui
# OBSERVATION_TYPE=screenshot
# MODEL_NAME=computer-use-preview
# MAX_STEPS=150
FLASK_PORT=80
FLASK_PORT=9001
FLASK_HOST=0.0.0.0
FLASK_DEBUG=false

18
run_dart_gui.sh Normal file
View File

@@ -0,0 +1,18 @@
# export HF_ENDPOINT=https://hf-mirror.com
python run_multienv_dart_gui.py \
--dart_base_url http://0.0.0.0:6006/v1 \
--provider_name docker \
--test_all_meta_path evaluation_examples/test_nogdrive.json \
--path_to_vm docker_vm_data/Ubuntu.qcow2 \
--headless \
--max_steps 30 \
--domain all \
--num_envs 2 \
--log_level INFO \
--temperature 1.0 \
--save_complete_trajectory \
--use_enhanced_runner \
--model dart-gui \
--model_type qwen25vl \
--infer_mode dart_mode \
--result_dir ./result_multi_apps_pengxiang_transformers12 | tee run_20251103_multi_apps_pengxiang_transformers12.log

916
run_multienv_dart_gui.py Normal file
View File

@@ -0,0 +1,916 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import argparse
import datetime
import json
import logging
import os
import sys
import signal
import time
from typing import List
from multiprocessing import Process, Manager, Queue
from multiprocessing import current_process
from numpy import True_
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.dart_gui_agent import DartAgent
import os
# Global variables for signal handling
active_environments = []
processes = []
is_terminating = False
# load the environment variables from .env file
if os.path.exists(".env"):
from dotenv import load_dotenv
load_dotenv()
# Logger Configs {{{ #
def config() -> argparse.Namespace:
    """Parse and return the CLI arguments for the Dart benchmark runner."""
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark - Dart Version"
    )
    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--sleep_after_execution", type=float, default=5.0)
    parser.add_argument("--max_steps", type=int, default=15)
    # evaluation config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )
    # lm config - Dart specific configurations
    parser.add_argument("--model", type=str, default="dart-uitars", help="Model name for Dart")
    parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
    parser.add_argument("--infer_mode", type=str, default="dart_mode", choices=["dart_mode", "qwen2vl_user"])
    parser.add_argument("--prompt_style", type=str, default="dart_style")
    parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
    parser.add_argument("--language", type=str, default="English")
    parser.add_argument("--max_pixels", type=float, default=16384*28*28)
    parser.add_argument("--min_pixels", type=float, default=100*28*28)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=-1)
    parser.add_argument("--history_n", type=int, default=5)
    parser.add_argument("--max_tokens", type=int, default=500)
    parser.add_argument("--stop_token", type=str, default=None)
    parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.")
    parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.")
    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )
    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        default='INFO', help="Set the logging level")
    # aws config
    parser.add_argument(
        "--region", type=str, default="us-east-1", help="AWS region for the VM"
    )
    parser.add_argument(
        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
    )
    parser.add_argument(
        "--client_password", type=str, default="password", help="Client password"
    )
    parser.add_argument(
        "--screen_width", type=int, default=1920, help="Screen width"
    )
    parser.add_argument(
        "--screen_height", type=int, default=1080, help="Screen height"
    )
    # Dart specific parameters
    parser.add_argument("--dart_api_key", type=str, default="", help="Dart API key")
    parser.add_argument("--dart_base_url", type=str, default="", help="Dart base URL")
    parser.add_argument("--max_images", type=int, default=5, help="Maximum number of images in prompt history")
    parser.add_argument("--max_texts", type=int, default=35, help="Maximum number of text responses in prompt history")
    # Enhanced trajectory saving
    parser.add_argument("--save_complete_trajectory", action="store_true", help="Save complete trajectory with images and detailed information")
    parser.add_argument("--use_enhanced_runner", action="store_true", help="Use enhanced Dart runner with complete trajectory saving")
    args = parser.parse_args()
    return args
args = config()  # Get command line arguments first
logger = logging.getLogger()
log_level = getattr(logging, args.log_level.upper())
logger.setLevel(log_level)
datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")
# Ensure the log directory exists before attaching file handlers —
# logging.FileHandler raises FileNotFoundError when "logs/" is missing.
os.makedirs("logs", exist_ok=True)
file_handler = logging.FileHandler(
    os.path.join("logs", "dart-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "dart-debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(log_level)
formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
# Only surface desktopenv.* records on stdout; the files receive everything.
stdout_handler.addFilter(logging.Filter("desktopenv"))
logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
# }}} Logger Configs #
logger = logging.getLogger("desktopenv.experiment")
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    """Flatten {domain: [example_id, ...]} into a list of (domain, example_id)."""
    return [
        (domain, example_id)
        for domain, examples in test_all_meta.items()
        for example_id in examples
    ]
def process_signal_handler(signum, frame, env_idx):
    """Signal handler for child processes to gracefully shut down their environments.

    NOTE(review): this inspects the interrupted frame's locals to find
    'active_environments', which only works if the signal lands while a
    function holding that local is on top of the stack — fragile; confirm.
    """
    logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...")
    # Get the active_environments from the caller's frame
    local_vars = frame.f_locals
    active_environments = local_vars.get('active_environments', [])
    # Close environment in the current process context
    for env in active_environments:
        if env is not None:
            try:
                logger.info(f"Process {env_idx + 1} closing environment...")
                env.close()
                logger.info(f"Process {env_idx + 1} environment closed successfully")
            except Exception as e:
                logger.error(f"Process {env_idx + 1} error closing environment: {e}")
    logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.")
    sys.exit(0)
def save_complete_trajectory_with_images(example_result_dir: str, task_info: dict, reward: float,
                                       messages: list, all_images: list = None):
    """
    Save the complete trajectory, including image paths.

    Args:
        example_result_dir: Directory to write results into.
        task_info: Task metadata (domain, example_id, instruction).
        reward: Final reward score; success is defined as reward > 0.
        messages: Full conversation messages.
        all_images: Optional list of image payloads (PIL images or bytes),
            matched to image_url content entries in order.
    """
    import datetime
    # Build the complete trajectory record
    complete_trajectory = {
        "task_info": {
            "domain": task_info.get("domain", "unknown"),
            "example_id": task_info.get("example_id", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "timestamp": datetime.datetime.now().isoformat()
        },
        "evaluation": {
            "reward": reward,
            "success": reward > 0
        },
        "trajectory": {
            "messages": [],
            "image_paths": [],
            "step_count": 0
        }
    }
    # Walk the messages, persisting any referenced images along the way
    image_counter = 0
    step_counter = 0
    for msg_idx, message in enumerate(messages):
        processed_message = {
            "step": step_counter,
            "role": message.get("role", "unknown"),
            "content": message.get("content", []),
            "timestamp": message.get("timestamp", ""),
            "image_files": []
        }
        # Inspect image content carried by this message
        if isinstance(message.get("content"), list):
            for content_item in message["content"]:
                if content_item.get("type") == "image_url":
                    # If we have matching image data, persist it to a file
                    if all_images and image_counter < len(all_images):
                        image_filename = f"step_{step_counter}_image_{image_counter}.png"
                        image_path = os.path.join(example_result_dir, image_filename)
                        try:
                            # Persist the image
                            if hasattr(all_images[image_counter], 'save'):
                                # PIL Image object
                                all_images[image_counter].save(image_path)
                            elif isinstance(all_images[image_counter], bytes):
                                # Raw binary data
                                with open(image_path, 'wb') as f:
                                    f.write(all_images[image_counter])
                            else:
                                logger.warning(f"Unknown image format for image {image_counter}")
                                # NOTE(review): this `continue` skips the
                                # image_counter increment below, which can
                                # misalign later images — confirm intended.
                                continue
                            processed_message["image_files"].append(image_filename)
                            complete_trajectory["trajectory"]["image_paths"].append(image_path)
                            logger.info(f"Saved image: {image_filename}")
                        except Exception as e:
                            logger.error(f"Failed to save image {image_counter}: {e}")
                        image_counter += 1
                    # Point the content's image reference at the local file
                    if processed_message["image_files"]:
                        content_item["local_path"] = processed_message["image_files"][-1]
        complete_trajectory["trajectory"]["messages"].append(processed_message)
        # Each assistant reply counts as one step
        if message.get("role") == "assistant":
            step_counter += 1
    complete_trajectory["trajectory"]["step_count"] = step_counter
    # Write the full trajectory JSON
    trajectory_file = os.path.join(example_result_dir, "complete_trajectory.json")
    try:
        with open(trajectory_file, 'w', encoding='utf-8') as f:
            json.dump(complete_trajectory, f, indent=2, ensure_ascii=False)
        logger.info(f"Complete trajectory saved to: {trajectory_file}")
        # Also write a condensed summary for quick inspection
        summary_file = os.path.join(example_result_dir, "trajectory_summary.json")
        summary = {
            "task_id": task_info.get("example_id", "unknown"),
            "domain": task_info.get("domain", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "reward": reward,
            "success": reward > 0,
            "total_steps": step_counter,
            "total_images": len(complete_trajectory["trajectory"]["image_paths"]),
            "image_files": [os.path.basename(path) for path in complete_trajectory["trajectory"]["image_paths"]],
            "timestamp": complete_trajectory["task_info"]["timestamp"]
        }
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        logger.info(f"Trajectory summary saved to: {summary_file}")
    except Exception as e:
        logger.error(f"Failed to save complete trajectory: {e}")
def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
    """Worker-process entry point: own one DesktopEnv + DartAgent and drain task_queue.

    Each queue item is a (domain, example_id) tuple. Scores are appended
    to the manager-backed shared_scores list; the environment is always
    closed in the finally block.
    """
    active_environments = []
    env = None
    try:
        # Initialize proxy configuration if enabled
        # if hasattr(args, 'proxy_host') and args.proxy_host and args.proxy_port:
        #     from desktop_env.providers.aws.proxy_pool import get_global_proxy_pool
        #     proxy_pool = get_global_proxy_pool()
        #     proxy_pool.add_proxy(
        #         host=args.proxy_host,
        #         port=args.proxy_port,
        #         protocol=args.proxy_protocol
        #     )
        #     logger.info(f"Added proxy: {args.proxy_host}:{args.proxy_port} ({args.proxy_protocol})")
        # elif hasattr(args, 'proxy_config') and args.proxy_config and os.path.exists(args.proxy_config):
        #     from desktop_env.providers.aws.proxy_pool import init_proxy_pool
        #     init_proxy_pool(args.proxy_config)
        #     logger.info(f"Initialized proxy pool from {args.proxy_config}")
        # Configure environment based on provider
        if args.provider_name == "aws":
            from desktop_env.providers.aws.manager import IMAGE_ID_MAP
            REGION = args.region
            screen_size = (args.screen_width, args.screen_height)
            # Fall back to the 1920x1080 AMI when the requested size has no image.
            ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
            env = DesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                region=REGION,
                snapshot_name=ami_id,
                screen_size=screen_size,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        else:
            # For non-AWS providers (docker, virtualbox, etc.)
            env = DesktopEnv(
                path_to_vm=args.path_to_vm,
                action_space=args.action_space,
                provider_name=args.provider_name,
                headless=args.headless,
                os_type="Ubuntu",
                require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"]
            )
        active_environments.append(env)
        args.max_trajectory_length = args.max_steps
        # Dart specific runtime configuration
        if args.infer_mode == "dart_mode":
            runtime_conf: dict = {
                "infer_mode": args.infer_mode,
                "prompt_style": args.prompt_style,
                "input_swap": args.input_swap,
                "language": args.language,
                "history_n": args.history_n,
                "max_pixels": args.max_pixels,
                "min_pixels": args.min_pixels,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "max_tokens": args.max_tokens,
                "max_images": args.max_images,
                "max_texts": args.max_texts,
                "dart_api_key": args.dart_api_key,
                "dart_base_url": args.dart_base_url
            }
        elif args.infer_mode == "qwen2vl_user":
            # Fixed settings for the qwen2vl_user mode (CLI values ignored).
            runtime_conf: dict = {
                "infer_mode": "qwen2vl_user",
                "prompt_style": "qwen2vl_user",
                "input_swap": args.input_swap,
                "language": args.language,
                "history_n": 5,
                "max_pixels": 2116800,
                "min_pixels": 3136,
                "temperature": 0.0,
                "top_k": -1,
                "top_p": 0.9,
                "max_tokens": 1000
            }
        else:
            raise ValueError(f"Unknown infer_mode: {args.infer_mode}")
        agent = DartAgent(
            model=args.model,
            action_space=args.action_space,
            observation_type=args.observation_type,
            max_trajectory_length=args.max_trajectory_length,
            model_type=args.model_type,
            runtime_conf=runtime_conf
        )
        logger.info(f"Process {current_process().name} started with Dart configuration.")
        while True:
            try:
                item = task_queue.get(timeout=5)
            except Exception:
                # Queue drained (or manager gone): this worker is done.
                break
            domain, example_id = item
            try:
                config_file = os.path.join(
                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
                )
                with open(config_file, "r", encoding="utf-8") as f:
                    example = json.load(f)
                logger.info(f"[{current_process().name}][Domain]: {domain}")
                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
                example_result_dir = os.path.join(
                    args.result_dir,
                    args.action_space,
                    args.observation_type,
                    args.model,
                    domain,
                    example_id,
                )
                os.makedirs(example_result_dir, exist_ok=True)
                try:
                    # Create a temporary list to capture the score
                    temp_scores = []
                    # Pick the runner according to the CLI flags
                    # NOTE(review): both branches invoke the same function
                    # with identical arguments; the "enhanced" path only
                    # differs in logging — confirm intended.
                    if args.use_enhanced_runner or args.save_complete_trajectory:
                        # Enhanced runner with complete trajectory saving
                        logger.info(f"Using enhanced Dart runner for {domain}/{example_id}")
                        lib_run_single.run_single_example(
                            agent,
                            env,
                            example,
                            args.max_steps,
                            example["instruction"],
                            args,
                            example_result_dir,
                            temp_scores,
                        )
                    else:
                        # Standard runner
                        lib_run_single.run_single_example(
                            agent,
                            env,
                            example,
                            args.max_steps,
                            example["instruction"],
                            args,
                            example_result_dir,
                            temp_scores,
                        )
                    # Add domain info to the score
                    if temp_scores:
                        shared_scores.append({
                            'domain': domain,
                            'example_id': example_id,
                            'score': temp_scores[-1]
                        })
                except Exception as e:
                    import traceback
                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
                    logger.error(traceback.format_exc())
                    # Best-effort: stop the screen recording and log the error
                    # into the trajectory file so the failure is visible.
                    try:
                        env.controller.end_recording(
                            os.path.join(example_result_dir, "recording.mp4")
                        )
                    except Exception as rec_e:
                        logger.error(f"Failed to end recording: {rec_e}")
                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                        f.write(
                            json.dumps(
                                {"Error": f"{domain}/{example_id} - {e}"}
                            )
                        )
                        f.write("\n")
            except Exception as e:
                logger.error(f"Task-level error in {current_process().name}: {e}")
                import traceback
                logger.error(traceback.format_exc())
    except Exception as e:
        logger.error(f"Process-level error in {current_process().name}: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Always release the VM/container, even on fatal errors.
        logger.info(f"{current_process().name} cleaning up environment...")
        try:
            if env:
                env.close()
                logger.info(f"{current_process().name} environment closed successfully")
        except Exception as e:
            logger.error(f"{current_process().name} error during environment cleanup: {e}")
def signal_handler(signum, frame):
    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.

    Closes main-process environments, then terminates child processes
    (SIGTERM first, SIGKILL after a grace period) before exiting.
    """
    global is_terminating, active_environments, processes
    # Avoid duplicate handling
    if is_terminating:
        return
    is_terminating = True
    logger.info(f"Received signal {signum}. Gracefully shutting down...")
    # Close all registered environments in the main process
    for env in active_environments:
        try:
            logger.info(f"Closing environment...")
            env.close()
            logger.info(f"Environment closed successfully")
        except Exception as e:
            logger.error(f"Error closing environment: {e}")
    # Send termination signal to all child processes first
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Sending termination signal to process {p.name}...")
                p.terminate()
            except Exception as e:
                logger.error(f"Error sending termination signal to process: {e}")
    # Allow a short time for processes to handle their own cleanup
    time.sleep(1)
    # Forcefully terminate any processes that didn't exit
    for p in processes:
        if p.is_alive():
            try:
                logger.info(f"Forcefully terminating process {p.name}...")
                import signal as sig
                os.kill(p.pid, sig.SIGKILL)
            except Exception as e:
                logger.error(f"Error forcefully terminating process: {e}")
    logger.info("Shutdown complete. Exiting.")
    sys.exit(0)
def _report_statistics(scores):
    """Log a summary of collected scores: overall, per-domain, distribution.

    Each entry in *scores* is either a dict with 'domain' and 'score' keys,
    or (legacy format) a bare number treated as domain 'unknown'.
    """
    if not scores:
        logger.warning("⚠️ No scores collected during evaluation!")
        return
    # Extract numeric scores for overall statistics
    numeric_scores = []
    domain_stats = {}
    for score_entry in scores:
        if isinstance(score_entry, dict):
            domain = score_entry.get('domain', 'unknown')
            score = score_entry.get('score', 0)
        else:
            # Handle legacy numeric scores
            domain = 'unknown'
            score = score_entry
        numeric_scores.append(score)
        # Domain statistics
        if domain not in domain_stats:
            domain_stats[domain] = {'total': 0, 'success': 0, 'scores': []}
        domain_stats[domain]['total'] += 1
        domain_stats[domain]['scores'].append(score)
        if score > 0:
            domain_stats[domain]['success'] += 1
    # Overall statistics
    total_tasks = len(numeric_scores)
    successful_tasks = sum(1 for score in numeric_scores if score > 0)
    average_score = sum(numeric_scores) / total_tasks
    success_rate = (successful_tasks / total_tasks) * 100
    logger.info("=" * 60)
    logger.info("📊 DART EVALUATION RESULTS SUMMARY")
    logger.info("=" * 60)
    logger.info(f"📈 Overall Statistics:")
    logger.info(f"  • Total tasks executed: {total_tasks}")
    logger.info(f"  • Successful tasks (score > 0): {successful_tasks}")
    logger.info(f"  • Success rate: {success_rate:.1f}%")
    logger.info(f"  • Average score: {average_score:.3f}")
    # Domain-specific statistics (only shown when more than one domain ran)
    if domain_stats and len(domain_stats) > 1:
        logger.info(f"\n🏷️ Domain-specific Results:")
        for domain, stats in sorted(domain_stats.items()):
            domain_success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0
            domain_avg_score = sum(stats['scores']) / len(stats['scores']) if stats['scores'] else 0
            logger.info(f"  {domain}:")
            logger.info(f"    - Tasks: {stats['total']}")
            logger.info(f"    - Successful: {stats['success']}")
            logger.info(f"    - Success rate: {domain_success_rate:.1f}%")
            logger.info(f"    - Average score: {domain_avg_score:.3f}")
    # Score distribution buckets
    score_ranges = {
        'Perfect (1.0)': sum(1 for s in numeric_scores if s == 1.0),
        'High (0.8-0.99)': sum(1 for s in numeric_scores if 0.8 <= s < 1.0),
        'Medium (0.5-0.79)': sum(1 for s in numeric_scores if 0.5 <= s < 0.8),
        'Low (0.1-0.49)': sum(1 for s in numeric_scores if 0.1 <= s < 0.5),
        'Failed (0.0)': sum(1 for s in numeric_scores if s == 0.0)
    }
    logger.info(f"\n📊 Score Distribution:")
    for range_name, count in score_ranges.items():
        if count > 0:
            percentage = (count / total_tasks) * 100
            logger.info(f"  {range_name}: {count} tasks ({percentage:.1f}%)")
    logger.info("=" * 60)
def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    """Run the full evaluation: fan tasks out to worker processes and report.

    Spawns ``args.num_envs`` worker processes that consume tasks from a
    shared queue, restarts workers that die while work remains, waits for
    completion, and finally logs aggregated score statistics.

    Args:
        args: Parsed command-line configuration.
        test_all_meta: Mapping of domain -> list of example ids to execute.
    """
    global processes
    logger.info("Args: %s", args)
    all_tasks = distribute_tasks(test_all_meta)
    logger.info(f"Total tasks: {len(all_tasks)}")
    with Manager() as manager:
        shared_scores = manager.list()
        task_queue = manager.Queue()
        for item in all_tasks:
            task_queue.put(item)
        num_envs = args.num_envs
        processes = []
        for i in range(num_envs):
            p = Process(
                target=run_env_tasks,
                args=(task_queue, args, shared_scores),
                name=f"DartEnvProcess-{i+1}"
            )
            p.daemon = True
            p.start()
            processes.append(p)
            logger.info(f"Started Dart process {p.name} with PID {p.pid}")
        try:
            # Supervision loop: restart dead workers, stop once the queue
            # drains or every worker has died.
            while True:
                alive_count = 0
                for idx, p in enumerate(processes):
                    if not p.is_alive():
                        logger.warning(f"Process {p.name} died, restarting...")
                        new_p = Process(
                            target=run_env_tasks,
                            args=(task_queue, args, shared_scores),
                            name=f"DartEnvProcess-Restart-{idx+1}"
                        )
                        new_p.daemon = True
                        new_p.start()
                        processes[idx] = new_p
                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
                    else:
                        alive_count += 1
                if task_queue.empty():
                    logger.info("All tasks finished.")
                    break
                if alive_count == 0:
                    logger.error("All processes died, exiting.")
                    break
                time.sleep(5)
            for p in processes:
                p.join()
        except KeyboardInterrupt:
            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
            raise
        except Exception as e:
            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
            for p in processes:
                if p.is_alive():
                    try:
                        logger.info(f"Terminating process {p.name} due to error...")
                        p.terminate()
                    except Exception as term_e:
                        logger.error(f"Error terminating process {p.name}: {term_e}")
            raise
        # Detailed statistics reporting (must run inside the Manager context
        # because shared_scores is a manager proxy).
        _report_statistics(list(shared_scores))
def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    """Filter *total_file_json* down to the tasks that still need to run.

    An example counts as finished only when its result folder contains a
    ``result.txt``.  Folders without one are treated as aborted runs: their
    partial artifacts are deleted so the task can be retried cleanly.

    Args:
        action_space: Action-space name (first path component under result_dir).
        use_model: Model name (third path component).
        observation_type: Observation type (second path component).
        result_dir: Root directory where results are stored.
        total_file_json: Mapping of domain -> list of example ids to run.

    Returns:
        The mapping with already-finished example ids removed (mutated in place).
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        return total_file_json
    finished = {}
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        # Skip stray files at this level (e.g. the args.json that the runner
        # writes next to the domain folders).
        if not os.path.isdir(domain_path):
            continue
        finished[domain] = []
        for example_id in os.listdir(domain_path):
            if example_id == "onboard":
                continue
            example_path = os.path.join(domain_path, example_id)
            if os.path.isdir(example_path):
                if "result.txt" not in os.listdir(example_path):
                    # Aborted run: empty all files under example_id so that
                    # the task is re-executed from scratch.
                    for file in os.listdir(example_path):
                        os.remove(os.path.join(example_path, file))
                else:
                    finished[domain].append(example_id)
    if not finished:
        return total_file_json
    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]
    return total_file_json
def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    """Collect all recorded scores and print the current success rate.

    Walks every <domain>/<example_id> folder under the result directory and
    reads its ``result.txt``.  Missing, unreadable, or malformed result files
    count as 0.0.

    Args:
        action_space: Action-space name (first path component under result_dir).
        use_model: Model name (third path component).
        observation_type: Observation type (second path component).
        result_dir: Root directory where results are stored.
        total_file_json: Unused; kept for signature parity with get_unfinished.

    Returns:
        List of float scores, or None when no results exist yet.
    """
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None
    all_result = []
    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        try:
                            # Use a context manager so the handle is closed;
                            # a malformed result counts as a failure (0.0).
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except (OSError, ValueError):
                            all_result.append(0.0)
    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result
def clear_cache_directory():
    """Reset the local ``cache`` directory to an empty state.

    Removes everything downloaded by a previous run by deleting the whole
    tree and recreating an empty directory; creates the directory if it does
    not exist yet.  Failures are logged, never raised.
    """
    cache_dir = "cache"
    if not os.path.exists(cache_dir):
        logger.info("Cache directory does not exist, creating it")
        os.makedirs(cache_dir, exist_ok=True)
        return
    logger.info(f"Clearing cache directory: {cache_dir}")
    try:
        import shutil
        # Wipe the whole tree, then put back an empty directory.
        shutil.rmtree(cache_dir)
        os.makedirs(cache_dir, exist_ok=True)
        logger.info("Cache directory cleared successfully")
    except Exception as e:
        logger.error(f"Failed to clear cache directory: {e}")
def cleanup_docker_containers():
    """Remove all running Docker containers except the monitor container.

    Lists running containers, skips any whose line mentions
    ``monitor-monitor-1`` (that one must stay up), and force-removes the
    rest.  All failures are logged but never raised, so startup proceeds
    regardless.
    """
    logger.info("Cleaning up Docker containers...")
    # Imported before the try block so the `except subprocess.TimeoutExpired`
    # clauses below can never hit a NameError.
    import subprocess
    try:
        # List running containers as "<id> <name>" lines; filtering happens
        # in Python rather than a shell grep/awk pipeline, so no shell=True.
        result = subprocess.run(
            ["docker", "ps", "--format", "{{.ID}} {{.Names}}"],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode == 0 and result.stdout.strip():
            container_ids = [
                line.split()[0]
                for line in result.stdout.strip().split("\n")
                if line.strip() and "monitor-monitor-1" not in line
            ]
            if container_ids:
                logger.info(f"Found {len(container_ids)} containers to remove: {container_ids}")
                # Force-remove (stops the container first if still running).
                for container_id in container_ids:
                    try:
                        rm_result = subprocess.run(
                            ["docker", "rm", "-f", container_id],
                            capture_output=True,
                            text=True,
                            timeout=10,
                        )
                        if rm_result.returncode == 0:
                            logger.info(f"Successfully removed container: {container_id}")
                        else:
                            logger.warning(f"Failed to remove container {container_id}: {rm_result.stderr}")
                    except subprocess.TimeoutExpired:
                        logger.warning(f"Timeout removing container: {container_id}")
                    except Exception as e:
                        logger.error(f"Error removing container {container_id}: {e}")
                logger.info("Docker container cleanup completed")
            else:
                logger.info("No containers found to remove")
        else:
            logger.info("No containers found or error getting container list")
    except subprocess.TimeoutExpired:
        logger.error("Timeout during Docker container cleanup")
    except Exception as e:
        logger.error(f"Failed to cleanup Docker containers: {e}")
if __name__ == "__main__":
    ####### Dart Version - Complete evaluation runner #######
    # Keep tokenizer libraries single-threaded; they misbehave when combined
    # with the forked worker processes below.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    # Register signal handlers for graceful termination
    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal
    try:
        args = config()
        # Remove Docker containers left over from the previous run.
        # NOTE: keep these containers around when running/debugging manually.
        cleanup_docker_containers()
        # Empty the cache directory of files downloaded by the previous run.
        clear_cache_directory()
        logger.info("Starting Dart evaluation runner...")
        # save args to json in result_dir/action_space/observation_type/model/args.json
        path_to_args = os.path.join(
            args.result_dir,
            args.action_space,
            args.observation_type,
            args.model,
            "args.json",
        )
        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
        with open(path_to_args, "w", encoding="utf-8") as f:
            json.dump(vars(args), f, indent=4)
        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
            test_all_meta = json.load(f)
        if args.domain != "all":
            test_all_meta = {args.domain: test_all_meta[args.domain]}
        # Drop examples that already have a result.txt from a previous run.
        test_file_list = get_unfinished(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        left_info = ""
        for domain in test_file_list:
            left_info += f"{domain}: {len(test_file_list[domain])}\n"
        logger.info(f"Left tasks:\n{left_info}")
        # Print the success rate accumulated so far (no-op on a fresh run).
        get_result(
            args.action_space,
            args.model,
            args.observation_type,
            args.result_dir,
            test_all_meta,
        )
        test(args, test_file_list)
    except KeyboardInterrupt:
        logger.info("Main process received KeyboardInterrupt.")
        # Signal handler will take care of cleanup
    except Exception as e:
        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
        # Also trigger cleanup for unhandled exceptions
        signal_handler(signal.SIGTERM, None)
    finally:
        # Final cleanup in case any environments or processes remain
        logger.info("Main process final cleanup...")
        for env in active_environments:
            if env is not None:
                try:
                    logger.info(f"Closing environment in final cleanup...")
                    env.close()
                    logger.info(f"Environment closed successfully in final cleanup")
                except Exception as e:
                    logger.error(f"Error during final environment cleanup: {e}")
        # First try gentle termination
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Terminating process {p.name}...")
                    p.terminate()
                except Exception as e:
                    logger.error(f"Error terminating process: {e}")
        # Wait a moment for processes to terminate
        time.sleep(1)
        # Then force kill if needed (SIGKILL cannot be caught by the child)
        for p in processes:
            if p is not None and p.is_alive():
                try:
                    logger.info(f"Force killing process {p.name}...")
                    os.kill(p.pid, signal.SIGKILL)
                    logger.info(f"Process {p.name} force killed")
                except Exception as e:
                    logger.error(f"Error force killing process: {e}")