From 00b6468eb7ac1a5f690997c5d699ced079d0d6b3 Mon Sep 17 00:00:00 2001 From: Pengxiang-Li <51403237+Pengxiang-Li@users.noreply.github.com> Date: Fri, 7 Nov 2025 21:50:01 +0800 Subject: [PATCH] feat/dart_gui (#371) --- .gitignore | 3 +- mm_agents/dart_gui/prompts.py | 161 ++++++ mm_agents/dart_gui/task_loader.py | 202 +++++++ mm_agents/dart_gui/utils.py | 511 +++++++++++++++++ mm_agents/dart_gui_agent.py | 686 ++++++++++++++++++++++ monitor/.env | 6 +- run_dart_gui.sh | 18 + run_multienv_dart_gui.py | 916 ++++++++++++++++++++++++++++++ 8 files changed, 2499 insertions(+), 4 deletions(-) create mode 100644 mm_agents/dart_gui/prompts.py create mode 100644 mm_agents/dart_gui/task_loader.py create mode 100644 mm_agents/dart_gui/utils.py create mode 100644 mm_agents/dart_gui_agent.py create mode 100644 run_dart_gui.sh create mode 100644 run_multienv_dart_gui.py diff --git a/.gitignore b/.gitignore index 8a90002..506e0bd 100644 --- a/.gitignore +++ b/.gitignore @@ -204,4 +204,5 @@ reference/ draft/ manual_examine.py run_human_examine.sh -quick_start.py \ No newline at end of file +quick_start.py +result_multi_apps_pengxiang_transformers12 \ No newline at end of file diff --git a/mm_agents/dart_gui/prompts.py b/mm_agents/dart_gui/prompts.py new file mode 100644 index 0000000..50f9df1 --- /dev/null +++ b/mm_agents/dart_gui/prompts.py @@ -0,0 +1,161 @@ +COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space + +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. 
+scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. +- My computer's password is 'password', feel free to use it when you need sudo rights. + +## User Instruction +{instruction} +""" + +COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space + +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. +call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help. + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. +- My computer's password is 'password', feel free to use it when you need sudo rights. 
+ +## User Instruction +{instruction} +""" + +UITARS_ACTION_SPACE = """ +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished() +""" + +UITARS_CALL_USR_ACTION_SPACE = """ +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished() +call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help. +""" + +UITARS_NORMAL_ACTION_SPACE = """ +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. 
+finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. +""" + +UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. +## Output Format +``` +Action: ... +``` +## Action Space +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished() +call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help. +## User Instruction +{instruction} +""" + +UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space +{action_space} + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. 
+ +## User Instruction +{instruction} +""" + + +FAILURE_INDICATORS = [ + # Direct inability expressions + "无法", "不能", "不可以", "做不到", "实现不了", "完成不了","没法", + + # Regret/apology expressions + "遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起", + + # Not supported/available + "不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面","不符合",#"不存在", + + # Cannot access/handle + "无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也", + + # Not possible/feasible + "不可能", "无法实现", "实现不了", "办不到", "做不了","找不到","存在技术限制","没有找到","没有内置", + + # System limitations + "超出范围", "不在我的能力范围", "能力有限", "功能限制","没有成功","没成功","硬件的问题", + + # Refusal indicators + "拒绝", "不允许", "禁止", "不合适", "不恰当", + + # Trying Restart + "从头开始", "藏在", "浪费时间","一个更合理的思路","正确的方向","没有意义",#, "重新","重启", +] diff --git a/mm_agents/dart_gui/task_loader.py b/mm_agents/dart_gui/task_loader.py new file mode 100644 index 0000000..c13dbe0 --- /dev/null +++ b/mm_agents/dart_gui/task_loader.py @@ -0,0 +1,202 @@ +import asyncio +from typing import List, Optional, Union, Dict, Any +import json +import os +import hashlib +from pathlib import Path +from omegaconf import DictConfig +from dataclasses import dataclass, asdict +import copy +import logging +import random + +from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER +from log_config import setup_logging + +# 设置统一的日志系统 +setup_logging() +logger = logging.getLogger(__name__) + +class TaskLoader: + def __init__(self, task_cfg: DictConfig, storage_root): + self.task_file = Path(task_cfg.task_file) + #self.task_root = Path(task_cfg.task_root) + self.osworld_root = Path(task_cfg.osworld_root) + + self._latest_sha: Optional[str] = None + self.storage_root = storage_root + self.resume = task_cfg.resume + + def poll_for_tasks(self) -> List[Dict]: + """find new tasks json file + return list of TaskInfo dict if there is new json + else return [] + """ + self._maybe_refresh_dataset() + + tasks_list = [task.to_dict() for task in self._tasks] + random.shuffle(tasks_list) + + return tasks_list + + 
def _maybe_refresh_dataset_bak(self): + + # check new json + latest_json = self._find_latest_json() + + if latest_json is None: + return False # no json file + + sha = self._calc_sha1(latest_json) + if sha == self._latest_sha: + return False # no change + + with open(latest_json) as f: + data = json.load(f) + + raw_tasks = [ + {"task_type": task_type, "task_id": task_id} + for task_type, task_ids in data.items() + for task_id in task_ids + ] + + self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks] + self._latest_sha = sha + + logger.info(f"当前任务文件: {str(latest_json)}") + logger.info(f"任务总数: {len(raw_tasks)}") + + return True + + def _maybe_refresh_dataset(self): + + latest_json = self.task_file + print("Current tasks file: ", str(latest_json)) + + with open(latest_json) as f: + data = json.load(f) + + raw_tasks = [ + {"task_type": task_type, "task_id": task_id} + for task_type, task_ids in data.items() + for task_id in task_ids + ] + + if self.resume: + # 过滤已完成或类型不匹配的任务 + filtered_tasks = [] + storage_root = Path(self.storage_root) + + for raw in raw_tasks: + task_id = str(raw["task_id"]) + task_type_expected = raw["task_type"] + + # 找到所有以 task_id 开头的子目录(允许有多个版本) + candidate_dirs = [ + d for d in storage_root.iterdir() + if d.is_dir() and d.name.startswith(task_id) + ] + + # 默认认为任务未完成 + task_finished = False + + for d in candidate_dirs: + cfg_path = d / "task_config.json" + if not cfg_path.exists(): + print("找不到config文件") + continue + + try: + with cfg_path.open("r", encoding="utf-8") as cf: + cfg = json.load(cf) + except Exception: + print("配置损坏,忽略此目录") + continue + + # 3.1 task_type 不同 => 不是同一个任务,直接跳过这目录 + if cfg.get("raw", {}).get("task_type") != task_type_expected: + continue + + # 3.2 task_type 相同,检查 reward.txt + if (d / "reward.txt").exists(): + task_finished = True + break # 已找到完成记录,无需再看其他目录 + if not task_finished: + filtered_tasks.append(raw) + self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks] + print(f"Total 
number of tasks: {len(raw_tasks)}, Remained:{len(filtered_tasks)}") + + else: + self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks] + print(f"Total number of tasks: {len(raw_tasks)}") + + return True + + def _find_latest_json(self) -> Optional[Path]: + files = list(self.task_root.glob("*.json")) + return max(files, key=lambda p: p.stat().st_mtime) if files else None + + @staticmethod + def _calc_sha1(fp: Path, chunk_size=2<<20) -> str: + h = hashlib.sha1() + with fp.open("rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + h.update(chunk) + return h.hexdigest() + + +@dataclass +class TaskInfo: + messages: List + instruction: str + task_config: Dict + + def to_dict(self): + return asdict(self) + + +def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo: + + task_type = raw["task_type"] + task_id = raw["task_id"] + task_path = os.path.join(osworld_root, task_type, task_id + ".json") + with open(task_path) as f: + task_data = json.load(f) + + task_data["raw"] = { + "task_type": task_type, + "task_id": task_id + } + + instruction = task_data["instruction"] + + if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]: + plan = task_data["human-ground-truth"]["single-action"] + plan_text = "\n".join(plan) + instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text + + system_prompt = COMPUTER_USE_PROMPT if not use_call_user else COMPUTER_USE_PROMPT_WITH_CALL_USER + messages = [ + { + "role": "system", + "content": "You are a helpful assistant." 
+ }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": system_prompt.format( + instruction=instruction, + language="English" + )} + ] + } + ] + + + return TaskInfo( + messages = messages, + instruction = instruction, + task_config = task_data + ) \ No newline at end of file diff --git a/mm_agents/dart_gui/utils.py b/mm_agents/dart_gui/utils.py new file mode 100644 index 0000000..e94206e --- /dev/null +++ b/mm_agents/dart_gui/utils.py @@ -0,0 +1,511 @@ +import ast +import base64 +import logging +import math +import re +import xml.etree.ElementTree as ET +from io import BytesIO +from typing import Dict, List + +import numpy as np +import openai + +from openai import OpenAI +from PIL import Image +from requests.exceptions import SSLError +from mm_agents.dart_gui.prompts import FAILURE_INDICATORS + +# 设置日志系统 +logger = logging.getLogger(__name__) + +FINISH_WORD = "finished" +WAIT_WORD = "wait" +ENV_FAIL_WORD = "error_env" +CALL_USER = "call_user" + +IMAGE_FACTOR = 28 +MIN_PIXELS = 100 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +pure_text_settings = ["a11y_tree"] + +attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes" +attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes" +state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state" +state_ns_windows = "https://accessibility.windows.example.org/ns/state" +component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component" +component_ns_windows = "https://accessibility.windows.example.org/ns/component" +value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value" +value_ns_windows = "https://accessibility.windows.example.org/ns/value" +class_ns_windows = "https://accessibility.windows.example.org/ns/class" +# More namespaces defined in OSWorld, please check desktop_env/server/main.py + +# 定义一个函数来解析每个 action +def parse_action(action_str): + try: + # 解析字符串为 AST 节点 + node = ast.parse(action_str, mode='eval') 
+ + # 确保节点是一个表达式 + if not isinstance(node, ast.Expression): + raise ValueError("Not an expression") + + # 获取表达式的主体 + call = node.body + + # 确保主体是一个函数调用 + if not isinstance(call, ast.Call): + raise ValueError("Not a function call") + + # 获取函数名 + if isinstance(call.func, ast.Name): + func_name = call.func.id + elif isinstance(call.func, ast.Attribute): + func_name = call.func.attr + else: + func_name = None + + # 获取关键字参数 + kwargs = {} + for kw in call.keywords: + key = kw.arg + # 处理不同类型的值,这里假设都是常量 + if isinstance(kw.value, ast.Constant): + value = kw.value.value + elif isinstance(kw.value, ast.Str): # 兼容旧版本 Python + value = kw.value.s + else: + value = None + kwargs[key] = value + + return { + 'function': func_name, + 'args': kwargs + } + + except Exception as e: + logger.error(f"Failed to parse action '{action_str}': {e}") + return None + +def escape_single_quotes(text): + # 匹配未转义的单引号(不匹配 \\') + pattern = r"(? int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + +def linear_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + if width * height > max_pixels: + """ + 如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用 + """ + resize_factor = math.sqrt(max_pixels / (width * height)) + width, height = int(width * resize_factor), int(height * resize_factor) + if width * height < min_pixels: + resize_factor = math.sqrt(min_pixels / (width * height)) 
+ width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor) + + return height, width + +def smart_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + +def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28): + text = text.strip() + if model_type == "qwen25vl": + smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels) + + # 正则表达式匹配 Action 字符串 + if text.startswith("Thought:"): + thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)" + thought_hint = "Thought: " + elif text.startswith("Reflection:"): + thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)" + thought_hint = "Reflection: " + elif text.startswith("Action_Summary:"): + thought_pattern = 
r"Action_Summary: (.+?)(?=\s*Action:|$)" + thought_hint = "Action_Summary: " + else: + thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)" + thought_hint = "Thought: " + reflection, thought = None, None + thought_match = re.search(thought_pattern, text, re.DOTALL) + if thought_match: + if len(thought_match.groups()) == 1: + thought = thought_match.group(1).strip() + elif len(thought_match.groups()) == 2: + thought = thought_match.group(2).strip() + reflection = thought_match.group(1).strip() + assert "Action:" in text + action_str = text.split("Action:")[-1] + + tmp_all_action = action_str.split("\n\n") + all_action = [] + for action_str in tmp_all_action: + if "type(content" in action_str: + # 正则表达式匹配 content 中的字符串并转义单引号 + def escape_quotes(match): + content = match.group(1) # 获取 content 的值 + return content + + # 使用正则表达式进行替换 + pattern = r"type\(content='(.*?)'\)" # 匹配 type(content='...') + content = re.sub(pattern, escape_quotes, action_str) + + # 处理字符串 + action_str = escape_single_quotes(content) + action_str = "type(content='" + action_str + "')" + + if "finished(content" in action_str: + # 正则表达式匹配 content 中的字符串并转义单引号 + def escape_quotes(match): + content = match.group(1) # 获取 content 的值 + return content + + # 使用正则表达式进行替换 + pattern = r"finished\(content='(.*?)'\)" # 匹配 type(content='...') + content = re.sub(pattern, escape_quotes, action_str) + + # 处理字符串 + action_str = escape_single_quotes(content) + action_str = "finished(content='" + action_str + "')" + all_action.append(action_str) + + parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action] + actions = [] + for action_instance, raw_str in zip(parsed_actions, all_action): + if action_instance == None: + logger.error(f"Action can't parse: {raw_str}") + # raise ValueError(f"Action can't parse: {raw_str}") + continue + action_type = action_instance["function"] + params = action_instance["args"] + + # import pdb; pdb.set_trace() + action_inputs = {} + for param_name, param in 
params.items(): + if param == "": continue + param = param.lstrip() # 去掉引号和多余的空格 + # 处理start_box或者end_box参数格式 'x1 y1 x2 y2' + action_inputs[param_name.strip()] = param + + if "start_box" in param_name or "end_box" in param_name: + ori_box = param + # Remove parentheses and split the string by commas + numbers = ori_box.replace("(", "").replace(")", "").split(",") + + # Convert to float and scale by 1000 + # Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates + if model_type == "qwen25vl": + float_numbers = [] + for num_idx, num in enumerate(numbers): + num = float(num) + if (num_idx + 1) % 2 == 0: + float_numbers.append(float(num/smart_resize_height)) + else: + float_numbers.append(float(num/smart_resize_width)) + else: + float_numbers = [float(num) / factor for num in numbers] + + if len(float_numbers) == 2: + float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]] + action_inputs[param_name.strip()] = str(float_numbers) + + # import pdb; pdb.set_trace() + actions.append( + { + "reflection": reflection, + "thought": thought, + "action_type": action_type, + "action_inputs": action_inputs, + "text": text + }) + return actions + +def parsing_response_to_pyautogui_code(responses, image_height: int, image_width:int, input_swap:bool=True) -> str: + ''' + 将M模型的输出解析为OSWorld中的action,生成pyautogui代码字符串 + 参数: + response: 包含模型输出的字典,结构类似于: + { + "action_type": "hotkey", + "action_inputs": { + "hotkey": "v ctrl", + "start_box": None, + "end_box": None + } + } + 返回: + 生成的pyautogui代码字符串 + ''' + + pyautogui_code = "import pyautogui\nimport time\n" + if isinstance(responses, dict): + responses = [responses] + for response_id, response in enumerate(responses): + if "observation" in response: + observation = response["observation"] + else: + observation = "" + + if "thought" in response: + thought = response["thought"] + else: + thought = "" + + if response_id == 0: + pyautogui_code += 
f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n" + else: + pyautogui_code += "\ntime.sleep(1)\n" + + action_dict = response + response_text = action_dict.get("text", "") + action_type = action_dict.get("action_type") + action_inputs = action_dict.get("action_inputs", {}) + + if action_type == "hotkey": + # Parsing hotkey action + if "key" in action_inputs: + hotkey = action_inputs.get("key", "") + else: + hotkey = action_inputs.get("hotkey", "") + + if hotkey == "arrowleft": + hotkey = "left" + + elif hotkey == "arrowright": + hotkey = "right" + + elif hotkey == "arrowup": + hotkey = "up" + + elif hotkey == "arrowdown": + hotkey = "down" + + if hotkey: + # Handle other hotkeys + keys = hotkey.split() # Split the keys by space + convert_keys = [] + for key in keys: + if key == "space": + key = ' ' + convert_keys.append(key) + pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})" + + elif action_type == "press": + # Parsing press action + if "key" in action_inputs: + key_to_press = action_inputs.get("key", "") + else: + key_to_press = action_inputs.get("press", "") + + if hotkey == "arrowleft": + hotkey = "left" + + elif hotkey == "arrowright": + hotkey = "right" + + elif hotkey == "arrowup": + hotkey = "up" + + elif hotkey == "arrowdown": + hotkey = "down" + + elif hotkey == "space": + hotkey = " " + + if key_to_press: + # Simulate pressing a single key + pyautogui_code += f"\npyautogui.press({repr(key_to_press)})" + + elif action_type == "keyup": + key_to_up = action_inputs.get("key", "") + pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})" + + elif action_type == "keydown": + key_to_down = action_inputs.get("key", "") + pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})" + + elif action_type == "type": + # Parsing typing action using clipboard + content = action_inputs.get("content", "") + content = escape_single_quotes(content) + stripped_content = content + if content.endswith("\n") or 
content.endswith("\\n"): + stripped_content = stripped_content.rstrip("\\n").rstrip("\n") + if content: + if input_swap: + pyautogui_code += "\nimport pyperclip" + pyautogui_code += f"\npyperclip.copy('{stripped_content}')" + pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')" + pyautogui_code += "\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += "\npyautogui.press('enter')" + else: + pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" + pyautogui_code += "\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += "\npyautogui.press('enter')" + + + elif action_type in ["drag", "select"]: + # Parsing drag or select action based on start and end_boxes + start_box = action_inputs.get("start_box") + end_box = action_inputs.get("end_box") + if start_box and end_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + sx = round(float((x1 + x2) / 2) * image_width, 3) + sy = round(float((y1 + y2) / 2) * image_height, 3) + x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2] + ex = round(float((x1 + x2) / 2) * image_width, 3) + ey = round(float((y1 + y2) / 2) * image_height, 3) + pyautogui_code += ( + f"\npyautogui.moveTo({sx}, {sy})\n" + f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n" + ) + + elif action_type == "scroll": + # Parsing scroll action + start_box = action_inputs.get("start_box") + if start_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + + # # 先点对应区域,再滚动 + # pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')" + else: + x = None + y = None + direction = action_inputs.get("direction", "") + + if x == None: + if "up" in direction.lower(): + pyautogui_code += "\npyautogui.scroll(5)" + elif "down" in direction.lower(): + pyautogui_code += "\npyautogui.scroll(-5)" + else: + 
if "up" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" + elif "down" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})" + + elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]: + # Parsing mouse click actions + start_box = action_inputs.get("start_box") + start_box = str(start_box) + if start_box: + start_box = eval(start_box) + if start_box is None: + logger.warning(f"[Warning] start_box is None and wired condition:\n{action_inputs}") + + if len(start_box) == 4: + x1, y1, x2, y2 = start_box # Assuming box is in [x1, y1, x2, y2] + elif len(start_box) == 2: + x1, y1 = start_box + x2 = x1 + y2 = y1 + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + if action_type == "left_single" or action_type == "click": + pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')" + elif action_type == "left_double": + pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')" + elif action_type == "right_single": + pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')" + elif action_type == "hover": + pyautogui_code += f"\npyautogui.moveTo({x}, {y})" + + elif action_type in ["finished"]: + pyautogui_code = "DONE" + print(f"FINISHED:response_text: {response_text}") + print(f"FINISHED:response: {str(response)}") + for failure_indicator in FAILURE_INDICATORS: + if failure_indicator in response_text: + pyautogui_code = "FAIL" + break + elif action_type in ["wait"]: + pyautogui_code = "WAIT" + + elif action_type in ["call_user"]: + pyautogui_code = "FAIL" + else: + pyautogui_code += f"\n# Unrecognized action type: {action_type}" + + return pyautogui_code + +def add_box_token(input_string): + # Step 1: Split the string into individual actions + if "Action: " in input_string and "start_box=" in input_string: + suffix = input_string.split("Action: ")[0] + "Action: " + actions = input_string.split("Action: ")[1:] + 
processed_actions = [] + for action in actions: + action = action.strip() + # Step 2: Extract coordinates (start_box or end_box) using regex + coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action) + + updated_action = action # Start with the original action + for coord_type, x, y in coordinates: + # Convert x and y to integers + updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'") + processed_actions.append(updated_action) + + # Step 5: Reconstruct the final string + final_string = suffix + "\n\n".join(processed_actions) + else: + final_string = input_string + # print(f"Input string: {input_string}") + # print(f"Final string: {final_string}") + return [{"type": "text", "text": final_string}] + +def pil_to_base64(image): + """Convert PIL Image or bytes to base64 string""" + if isinstance(image, bytes): + # If it's already bytes, just encode to base64 + return base64.b64encode(image).decode("utf-8") + else: + # If it's a PIL Image, convert it + buffer = BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") \ No newline at end of file diff --git a/mm_agents/dart_gui_agent.py b/mm_agents/dart_gui_agent.py new file mode 100644 index 0000000..2f8fa6c --- /dev/null +++ b/mm_agents/dart_gui_agent.py @@ -0,0 +1,686 @@ +""" +Dart Agent - Custom agent for GUI automation using Dart models +Based on UITARSAgent structure but using Dart-specific utilities and prompts +""" +import ast +import base64 +import logging +import math +import os +import re +import time +from io import BytesIO +from typing import Dict, List, Any +from PIL import Image +from openai import OpenAI +import backoff +import openai +import requests +from requests.exceptions import SSLError +from google.api_core.exceptions import ( + BadRequest, + InternalServerError, + InvalidArgument, + ResourceExhausted, +) + +# Import Dart-specific utilities and prompts +from 
    def __init__(
        self,
        model: str,
        runtime_conf: Dict,
        platform="ubuntu",
        max_tokens=1000,
        top_p=0.9,
        top_k=1.0,
        temperature=0.0,
        action_space="pyautogui",
        observation_type="screenshot",
        max_trajectory_length=50,
        model_type="qwen25vl",
        **kwargs
    ):
        """Configure the Dart GUI agent.

        Args:
            model: Model name sent with every inference request.
            runtime_conf: Runtime overrides; any sampling/limit keyword here
                takes precedence over the constructor defaults below.
            platform: Target OS label (stored only; not read elsewhere in
                this class).
            max_tokens / top_p / top_k / temperature: Sampling defaults,
                each overridable via ``runtime_conf``.
            action_space: Action format produced for the environment
                ("pyautogui").
            observation_type: "screenshot" or "screenshot_a11y_tree".
            max_trajectory_length: Step budget; once this many responses
                have been produced, ``predict`` forces a FAIL.
            model_type: Backbone family tag forwarded to the action parser.
            **kwargs: Ignored; accepted for interface compatibility.
        """
        self.model = model
        self.platform = platform
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.model_type = model_type
        self.runtime_conf = runtime_conf

        # Extract runtime configuration parameters (runtime_conf wins over
        # the constructor defaults).
        self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
        self.top_p = self.runtime_conf.get("top_p", top_p)
        self.top_k = self.runtime_conf.get("top_k", top_k)
        self.temperature = self.runtime_conf.get("temperature", temperature)
        self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
        self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
        self.input_swap = self.runtime_conf.get("input_swap", False)
        self.language = self.runtime_conf.get("language", "English")
        self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
        self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
        self.history_n = self.runtime_conf.get("history_n", 5)

        # Dart specific configurations: hard caps enforced by _trim() on the
        # prompt history (images and assistant texts respectively).
        self.max_images = self.runtime_conf.get("max_images", 5)
        self.max_texts = self.runtime_conf.get("max_texts", 35)

        # Initialize OpenAI client - use Dart API if provided
        dart_api_key = self.runtime_conf.get("dart_api_key", "")
        dart_base_url = self.runtime_conf.get("dart_base_url", "")

        if dart_base_url:
            # Check whether this is a direct generation endpoint
            # (URL contains /generate).
            if '/generate' in dart_base_url:
                # Use the provided URL verbatim; do not append /v1.
                logger.info(f"使用直接生成端点: {dart_base_url}")
                self.dart_direct_url = dart_base_url
                self.vlm = None  # the OpenAI client is not used in this mode
            else:
                # Conventional OpenAI-compatible endpoint; make sure the
                # base URL ends with /v1.
                if not dart_base_url.endswith('/v1'):
                    dart_base_url = dart_base_url.rstrip('/') + '/v1'

                self.vlm = OpenAI(
                    base_url=dart_base_url,
                    api_key=dart_api_key,
                )
                self.dart_direct_url = None
        else:
            # Fallback to environment variables
            base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
            if base_url:
                if '/generate' in base_url:
                    # direct generation endpoint
                    self.dart_direct_url = base_url
                    self.vlm = None
                else:
                    if not base_url.endswith('/v1'):
                        base_url = base_url.rstrip('/') + '/v1'
                    self.vlm = OpenAI(
                        base_url=base_url,
                        api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
                    )
                    self.dart_direct_url = None
            else:
                # No endpoint configured at all; _call_model will fail until
                # one is supplied.
                self.vlm = None
                self.dart_direct_url = None

        # Initialize trajectory storage - similar to trajectory_runner.py
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        # Message handling similar to trajectory_runner.py
        self.base_messages = []  # for model client (with base64 images)
        self.base_messages_for_save = []  # for storage (with file paths)
        self.prompt_dialogue = []  # for model client
        self.save_dialogue = []  # for storage (trimmed alongside prompt_dialogue)
        self.save_dialogue_full = []  # for full storage (keeps every image path)
        self.image_refs = []  # record image position for trimming

        # All image paths storage - to keep track of all images even when trimmed
        self.all_image_paths = []

        # Current screenshot file path for proper saving
        self.current_screenshot_path = None

        # Configure prompt and action space based on mode
        if self.infer_mode == "dart_mode":
            self.prompt_action_space = UITARS_ACTION_SPACE
            self.prompt_template = COMPUTER_USE_PROMPT
        else:
            # For qwen2vl_user mode
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
            if self.prompt_style == "qwen2vl_user":
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT
            elif self.prompt_style == "qwen2vl_no_thought":
                self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
            else:
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        # Coordinate grid used when parsing model output back to pixels
        # (model emits coordinates on a 0-1000 grid).
        self.action_parse_res_factor = 1000

        logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
    def predict(
        self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
    ) -> tuple:
        """
        Predict the next action(s) based on the current observation.

        Args:
            instruction: Task instruction; only used the first time, to build
                the base system/user messages.
            obs: Observation dict containing at least a "screenshot" entry.
            last_action_after_obs: Unused here; kept for interface
                compatibility with other agents.

        Returns: (response_text, actions_list)
            response_text is the raw model output (or an error description);
            actions_list contains pyautogui code strings, or a single
            sentinel such as ["DONE"] / ["FAIL"].
        """
        # Initialize base messages if not set
        if not self.base_messages:
            self.set_base_messages(instruction)

        # Store current observation
        self._add_observation(obs)

        # For first step, set the first frame (attached to the base user
        # message rather than the rolling dialogue)
        if len(self.observations) == 1:
            self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
        else:
            # For subsequent steps, add the new image to dialogue
            # This represents the result of the previous action
            self._add_image(obs["screenshot"], self.current_screenshot_path)

        # Build prompt messages (base_messages + prompt_dialogue)
        messages = self._build_messages()

        # Call model to get response; None means the client kept failing
        prediction = self._call_model(messages)
        if prediction is None:
            return "client error", ["DONE"]

        # Store response and parse actions
        self._add_text(prediction)

        # Parse response to actions
        try:
            image_size = self._get_current_image_size()
            actions = self._parse_and_convert_actions(prediction, image_size)

            # Check for terminal actions (finished / wait / failure)
            terminal_action = self._check_terminal_actions(actions)
            if terminal_action:
                self.actions.append(actions)
                return prediction, [terminal_action]

        except Exception as e:
            logger.error(f"Parsing action error: {prediction}, error: {e}")
            return f"Parsing action error: {prediction}, error: {e}", ["DONE"]

        self.actions.append(actions)
        # Check max steps: once the response budget is exhausted, force FAIL
        if len(self.history_responses) >= self.max_trajectory_length:
            actions = ["FAIL"]

        return prediction, actions
number of observations and actions should be the same." + + def _add_observation(self, obs: Dict): + """Process observation and add to history""" + # Store observation + if self.observation_type in ["screenshot", "screenshot_a11y_tree"]: + base64_image = obs["screenshot"] + try: + # Handle accessibility tree if needed + linearized_accessibility_tree = None + if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs: + # For now, we'll skip accessibility tree processing in Dart mode + linearized_accessibility_tree = None + except: + linearized_accessibility_tree = None + + if self.observation_type == "screenshot_a11y_tree": + self.observations.append({ + "screenshot": base64_image, + "accessibility_tree": linearized_accessibility_tree, + }) + else: + self.observations.append({ + "screenshot": base64_image, + "accessibility_tree": None + }) + else: + raise ValueError("Invalid observation_type type: " + self.observation_type) + + + def _build_messages(self) -> List[Dict]: + """Build messages for model API call - similar to trajectory_runner._build_messages""" + return self.base_messages + self.prompt_dialogue + + def _call_model(self, messages: List[Dict]) -> str: + """Call model with retry logic""" + try_times = 3 + while try_times > 0: + try: + # 如果使用直接生成端点 + if hasattr(self, 'dart_direct_url') and self.dart_direct_url: + prediction = self._call_direct_generate_endpoint(messages) + else: + # 使用标准 OpenAI 客户端 + response = self.vlm.chat.completions.create( + model=self.model, + messages=messages, + frequency_penalty=1, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p + ) + prediction = response.choices[0].message.content + + logger.info(f"Model response: {prediction}") + return prediction + + except Exception as e: + logger.error(f"Error when fetching response from client: {e}") + try_times -= 1 + if try_times <= 0: + logger.error("Reach max retry times to fetch response from client") + return None + return None 
+ + def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str: + """直接调用生成端点""" + try: + + + # 构建请求数据 + payload = { + "messages": messages, + "model": self.model, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "frequency_penalty": 1 + } + + # 添加 API key 到 headers + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}" + } + + + # 重试机制:最多重试3次,每次推理60秒 + max_retries = 3 + response = None + + for attempt in range(max_retries): + try: + logger.info(f"尝试第 {attempt + 1} 次请求...") + response = requests.post( + self.dart_direct_url, + json=payload, + headers=headers, + timeout=60 + ) + response.raise_for_status() + break # 成功则跳出重试循环 + except Exception as e: + logger.warning(f"第 {attempt + 1} 次请求失败: {e}") + if attempt == max_retries - 1: # 最后一次重试失败 + logger.error(f"所有 {max_retries} 次重试都失败了") + raise e + else: + logger.info(f"等待后重试...") + import time + time.sleep(2) # 等待2秒后重试 + + # 解析响应 + result = response.json() + + # 尝试多种可能的响应格式 + if 'choices' in result and len(result['choices']) > 0: + # OpenAI 兼容格式 + return result['choices'][0]['message']['content'] + elif 'response' in result: + # 简单的 response 字段 + return result['response'] + elif 'text' in result: + # text 字段 + return result['text'] + elif 'content' in result: + # content 字段 + return result['content'] + else: + # 如果找不到标准字段,返回整个响应的字符串 + logger.warning(f"未知的响应格式: {result}") + return str(result) + + except Exception as e: + logger.error(f"直接端点调用失败: {e}") + raise e + + def _add_text(self, assistant_txt: str): + """Add text response to history - similar to trajectory_runner.py""" + self.history_responses.append(assistant_txt) + self.thoughts.append(assistant_txt) + + # Add to dialogue similar to trajectory_runner._add_text + msg = { + "role": "assistant", + "content": add_box_token(assistant_txt) + } + self.prompt_dialogue.append(msg) + self.save_dialogue.append(msg) + 
    def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
        """Attach the first screenshot to the initial user message.

        Unlike later frames (which go into the rolling dialogue), the first
        frame is appended to ``base_messages[1]`` - the user message built
        by ``set_base_messages``. The model-facing copy carries the base64
        data URL; the save-facing copy carries only the file path.

        Args:
            obs_img: Raw screenshot (bytes or PIL image, per pil_to_base64).
            frame_path: Optional on-disk path to record for this frame.
        """
        self.base_messages[1]["content"].append(
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
            }
        )

        # Use actual frame path if provided, otherwise use current_screenshot_path or placeholder
        if frame_path:
            first_frame_path = frame_path
        elif self.current_screenshot_path:
            first_frame_path = self.current_screenshot_path
        else:
            first_frame_path = "first_frame.png"

        # Store in all_image_paths (complete record, never trimmed)
        self.all_image_paths.append(first_frame_path)

        self.base_messages_for_save[1]["content"].append(
            {
                "type": "image_url",
                "image_url": first_frame_path
            }
        )

        # Register the image so _trim() can locate and evict it later.
        self.image_refs.append(
            {"source": "base", "msg_idx": 1,
             "content_idx": len(self.base_messages[1]["content"]) - 1}
        )
"content": [{ + "type": "image_url", + "image_url": image_url + }] + }) + + self.image_refs.append( + {"source": "dialogue", "msg_idx": len(self.prompt_dialogue) - 1, + "content_idx": None} + ) + + self._trim() + + def _trim(self): + """Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim""" + img_cnt = len(self.image_refs) + txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue) + + while img_cnt > self.max_images or txt_cnt > self.max_texts: + # 图片超限:最早一张 + if img_cnt > self.max_images: + ref = self.image_refs.pop(0) + if ref["source"] == "base": + self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"]) + else: # dialogue 图 + self._remove_dialogue_msg(ref["msg_idx"]) + img_cnt -= 1 + continue + + # 文本超限:最早 assistant 文本 + if txt_cnt > self.max_texts: + for i, m in enumerate(self.prompt_dialogue): + if m["role"] == "assistant": + self._remove_dialogue_msg(i) + txt_cnt -= 1 + break + + def _remove_dialogue_msg(self, idx: int): + """Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg""" + self.prompt_dialogue.pop(idx) + self.save_dialogue.pop(idx) + # Note: save_dialogue_full is never trimmed, so we don't remove from it + + # 更新 image_refs + self.image_refs = [ + r if not (r["source"] == "dialogue" and r["msg_idx"] == idx) + else None # 同一条被删掉的图引用直接丢弃 + for r in self.image_refs + ] + self.image_refs = [ + ( + {**r, "msg_idx": r["msg_idx"] - 1} + if r and r["source"] == "dialogue" and r["msg_idx"] > idx # idx后的图片索引均-1 + else r + ) + for r in self.image_refs + if r # 剔除 None + ] + + def _get_current_image_size(self) -> tuple: + """Get current image size for coordinate conversion""" + if len(self.observations) > 0: + try: + current_image_bytes = self.observations[-1]["screenshot"] + if isinstance(current_image_bytes, bytes): + current_image = Image.open(BytesIO(current_image_bytes)) + return (current_image.height, current_image.width) + except Exception 
as e: + logger.warning(f"Error getting image size: {e}") + + # Fallback to default screen size + return (1080, 1920) + + def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]: + """Parse response and convert to pyautogui actions - similar to trajectory_runner._parse""" + image_height, image_width = image_size + + # Parse the response to structured actions + parsed_responses = parse_action_to_structure_output( + prediction, + factor=self.action_parse_res_factor, + origin_resized_height=image_height, + origin_resized_width=image_width, + model_type=self.model_type, + max_pixels=self.max_pixels, + min_pixels=self.min_pixels + ) + + # Convert parsed responses to pyautogui actions + actions = [] + for parsed_response in parsed_responses: + try: + pyautogui_code = parsing_response_to_pyautogui_code( + parsed_response, + image_height=image_height, + image_width=image_width, + input_swap=self.input_swap + ) + + + + actions.append(pyautogui_code) + + except Exception as e: + logger.error(f"Error generating pyautogui code: {e}") + actions.append("FAIL") + + return actions + + + + def _check_terminal_actions(self, actions: List[str]) -> str: + """Check if any action is terminal and return appropriate code""" + for action in actions: + if isinstance(action, dict) and "action_type" in action: + action_type = action["action_type"] + if action_type == FINISH_WORD: + return "DONE" + elif action_type == WAIT_WORD: + return "WAIT" + elif action_type == ENV_FAIL_WORD: + return "FAIL" + elif action_type == CALL_USER: + return "FAIL" + return None diff --git a/monitor/.env b/monitor/.env index 4bdadf9..c6c1036 100644 --- a/monitor/.env +++ b/monitor/.env @@ -2,13 +2,13 @@ # Do not write any secret keys or sensitive information here. 
# Monitor configuration -TASK_CONFIG_PATH=../evaluation_examples/test_all.json +TASK_CONFIG_PATH=../evaluation_examples/test_nogdrive.json EXAMPLES_BASE_PATH=../evaluation_examples/examples -RESULTS_BASE_PATH=../results +RESULTS_BASE_PATH=../result_multi_apps_pengxiang_transformers12 # ACTION_SPACE=pyautogui # OBSERVATION_TYPE=screenshot # MODEL_NAME=computer-use-preview # MAX_STEPS=150 -FLASK_PORT=80 +FLASK_PORT=9001 FLASK_HOST=0.0.0.0 FLASK_DEBUG=false diff --git a/run_dart_gui.sh b/run_dart_gui.sh new file mode 100644 index 0000000..2c01ef9 --- /dev/null +++ b/run_dart_gui.sh @@ -0,0 +1,18 @@ +# export HF_ENDPOINT=https://hf-mirror.com +python run_multienv_dart_gui.py \ + --dart_base_url http://0.0.0.0:6006/v1 \ + --provider_name docker \ + --test_all_meta_path evaluation_examples/test_nogdrive.json \ + --path_to_vm docker_vm_data/Ubuntu.qcow2 \ + --headless \ + --max_steps 30 \ + --domain all \ + --num_envs 2 \ + --log_level INFO \ + --temperature 1.0 \ + --save_complete_trajectory \ + --use_enhanced_runner \ + --model dart-gui \ + --model_type qwen25vl \ + --infer_mode dart_mode \ + --result_dir ./result_multi_apps_pengxiang_transformers12 | tee run_20251103_multi_apps_pengxiang_transformers12.log \ No newline at end of file diff --git a/run_multienv_dart_gui.py b/run_multienv_dart_gui.py new file mode 100644 index 0000000..499012b --- /dev/null +++ b/run_multienv_dart_gui.py @@ -0,0 +1,916 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations +import argparse +import datetime +import json +import logging +import os +import sys +import signal +import time +from typing import List +from multiprocessing import Process, Manager, Queue +from multiprocessing import current_process + +from numpy import True_ +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.dart_gui_agent import DartAgent +import os + +# Global variables for signal handling +active_environments = [] +processes = [] +is_terminating = False + +# load 
the environment variables from .env file +if os.path.exists(".env"): + from dotenv import load_dotenv + load_dotenv() + +# Logger Configs {{{ # +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark - Dart Version" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--sleep_after_execution", type=float, default=5.0) + parser.add_argument("--max_steps", type=int, default=15) + + # evaluation config + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config - Dart specific configurations + parser.add_argument("--model", type=str, default="dart-uitars", help="Model name for Dart") + parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"]) + parser.add_argument("--infer_mode", type=str, default="dart_mode", choices=["dart_mode", "qwen2vl_user"]) + parser.add_argument("--prompt_style", type=str, default="dart_style") + parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content") + parser.add_argument("--language", type=str, default="English") + parser.add_argument("--max_pixels", type=float, default=16384*28*28) + parser.add_argument("--min_pixels", type=float, default=100*28*28) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top_p", type=float, default=1.0) + parser.add_argument("--top_k", type=int, default=-1) + parser.add_argument("--history_n", type=int, default=5) + parser.add_argument("--max_tokens", 
type=int, default=500) + parser.add_argument("--stop_token", type=str, default=None) + + parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.") + parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.") + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="password", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) + + # Dart specific parameters + parser.add_argument("--dart_api_key", type=str, default="", help="Dart API key") + parser.add_argument("--dart_base_url", type=str, default="", help="Dart base URL") + parser.add_argument("--max_images", type=int, default=5, help="Maximum number of images in prompt history") + parser.add_argument("--max_texts", type=int, default=35, help="Maximum number of text responses in prompt history") + + # Enhanced trajectory saving + parser.add_argument("--save_complete_trajectory", action="store_true", help="Save complete trajectory with images 
def distribute_tasks(test_all_meta: dict) -> List[tuple]:
    """Flatten a ``{domain: [example_id, ...]}`` mapping into a task list.

    Returns one ``(domain, example_id)`` tuple per example, preserving the
    mapping's iteration order.
    """
    return [
        (domain, example_id)
        for domain, examples in test_all_meta.items()
        for example_id in examples
    ]
def save_complete_trajectory_with_images(example_result_dir: str, task_info: dict, reward: float,
                                         messages: list, all_images: list = None):
    """Persist the full trajectory (messages plus screenshots) for one example.

    Writes ``complete_trajectory.json`` (every message, with image payloads
    dumped as PNG files alongside it) and ``trajectory_summary.json`` (a
    condensed view) into *example_result_dir*.

    Args:
        example_result_dir: Directory the result files are written into.
        task_info: Task metadata (domain / example_id / instruction).
        reward: Final evaluation score; success means reward > 0.
        messages: Complete conversation message list.
        all_images: Optional raw image data (PIL images or bytes), matched
            positionally against the image_url entries in *messages*.
    """
    # FIX: removed the redundant function-local `import datetime`; the
    # module already imports datetime at the top.

    # Assemble the complete trajectory record.
    complete_trajectory = {
        "task_info": {
            "domain": task_info.get("domain", "unknown"),
            "example_id": task_info.get("example_id", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "timestamp": datetime.datetime.now().isoformat()
        },
        "evaluation": {
            "reward": reward,
            "success": reward > 0
        },
        "trajectory": {
            "messages": [],
            "image_paths": [],
            "step_count": 0
        }
    }

    # Walk the messages, dumping image payloads to disk as we go.
    image_counter = 0
    step_counter = 0

    for msg_idx, message in enumerate(messages):
        processed_message = {
            "step": step_counter,
            "role": message.get("role", "unknown"),
            "content": message.get("content", []),
            "timestamp": message.get("timestamp", ""),
            "image_files": []
        }

        # Look for image parts inside this message's content list.
        if isinstance(message.get("content"), list):
            for content_item in message["content"]:
                if content_item.get("type") == "image_url":
                    # Persist the matching raw image, when one was provided.
                    if all_images and image_counter < len(all_images):
                        image_filename = f"step_{step_counter}_image_{image_counter}.png"
                        image_path = os.path.join(example_result_dir, image_filename)

                        try:
                            if hasattr(all_images[image_counter], 'save'):
                                # PIL Image object
                                all_images[image_counter].save(image_path)
                            elif isinstance(all_images[image_counter], bytes):
                                # raw binary data
                                with open(image_path, 'wb') as f:
                                    f.write(all_images[image_counter])
                            else:
                                logger.warning(f"Unknown image format for image {image_counter}")
                                continue

                            processed_message["image_files"].append(image_filename)
                            complete_trajectory["trajectory"]["image_paths"].append(image_path)
                            logger.info(f"Saved image: {image_filename}")

                        except Exception as e:
                            logger.error(f"Failed to save image {image_counter}: {e}")

                        image_counter += 1

                    # Point the content entry at the locally saved copy.
                    if processed_message["image_files"]:
                        content_item["local_path"] = processed_message["image_files"][-1]

        complete_trajectory["trajectory"]["messages"].append(processed_message)

        # Each assistant reply marks the end of one step.
        if message.get("role") == "assistant":
            step_counter += 1

    complete_trajectory["trajectory"]["step_count"] = step_counter

    # Write the full trajectory JSON.
    trajectory_file = os.path.join(example_result_dir, "complete_trajectory.json")
    try:
        with open(trajectory_file, 'w', encoding='utf-8') as f:
            json.dump(complete_trajectory, f, indent=2, ensure_ascii=False)
        logger.info(f"Complete trajectory saved to: {trajectory_file}")

        # Also write a condensed summary for quick inspection.
        summary_file = os.path.join(example_result_dir, "trajectory_summary.json")
        summary = {
            "task_id": task_info.get("example_id", "unknown"),
            "domain": task_info.get("domain", "unknown"),
            "instruction": task_info.get("instruction", ""),
            "reward": reward,
            "success": reward > 0,
            "total_steps": step_counter,
            "total_images": len(complete_trajectory["trajectory"]["image_paths"]),
            "image_files": [os.path.basename(path) for path in complete_trajectory["trajectory"]["image_paths"]],
            "timestamp": complete_trajectory["task_info"]["timestamp"]
        }

        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2, ensure_ascii=False)
        logger.info(f"Trajectory summary saved to: {summary_file}")

    except Exception as e:
        logger.error(f"Failed to save complete trajectory: {e}")
+ env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"] + ) + active_environments.append(env) + args.max_trajectory_length = args.max_steps + + # Dart specific runtime configuration + if args.infer_mode == "dart_mode": + runtime_conf: dict = { + "infer_mode": args.infer_mode, + "prompt_style": args.prompt_style, + "input_swap": args.input_swap, + "language": args.language, + "history_n": args.history_n, + "max_pixels": args.max_pixels, + "min_pixels": args.min_pixels, + "temperature": args.temperature, + "top_k": args.top_k, + "top_p": args.top_p, + "max_tokens": args.max_tokens, + "max_images": args.max_images, + "max_texts": args.max_texts, + "dart_api_key": args.dart_api_key, + "dart_base_url": args.dart_base_url + } + elif args.infer_mode == "qwen2vl_user": + runtime_conf: dict = { + "infer_mode": "qwen2vl_user", + "prompt_style": "qwen2vl_user", + "input_swap": args.input_swap, + "language": args.language, + "history_n": 5, + "max_pixels": 2116800, + "min_pixels": 3136, + "temperature": 0.0, + "top_k": -1, + "top_p": 0.9, + "max_tokens": 1000 + } + else: + raise ValueError(f"Unknown infer_mode: {args.infer_mode}") + + agent = DartAgent( + model=args.model, + action_space=args.action_space, + observation_type=args.observation_type, + max_trajectory_length=args.max_trajectory_length, + model_type=args.model_type, + runtime_conf=runtime_conf + ) + + logger.info(f"Process {current_process().name} started with Dart configuration.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + 
logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + # Create a temporary list to capture the score + temp_scores = [] + + # 根据参数选择使用哪个运行函数 + if args.use_enhanced_runner or args.save_complete_trajectory: + # 使用九章专用的运行函数,支持完整轨迹保存 + logger.info(f"Using enhanced Dart runner for {domain}/{example_id}") + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + temp_scores, + ) + else: + # 使用标准运行函数 + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + temp_scores, + ) + # Add domain info to the score + if temp_scores: + shared_scores.append({ + 'domain': domain, + 'example_id': example_id, + 'score': temp_scores[-1] + }) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + 
logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") + + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes + + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. 
Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") + with Manager() as manager: + shared_scores = manager.list() + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs + processes = [] + for i in range(num_envs): + p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"DartEnvProcess-{i+1}" + ) + p.daemon = True + p.start() + processes.append(p) + logger.info(f"Started Dart process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"DartEnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. 
Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise + scores = list(shared_scores) + + # Detailed statistics reporting + if scores: + # Extract numeric scores for overall statistics + numeric_scores = [] + domain_stats = {} + + for score_entry in scores: + if isinstance(score_entry, dict): + domain = score_entry.get('domain', 'unknown') + example_id = score_entry.get('example_id', 'unknown') + score = score_entry.get('score', 0) + else: + # Handle legacy numeric scores + domain = 'unknown' + example_id = 'unknown' + score = score_entry + + numeric_scores.append(score) + + # Domain statistics + if domain not in domain_stats: + domain_stats[domain] = {'total': 0, 'success': 0, 'scores': []} + + domain_stats[domain]['total'] += 1 + domain_stats[domain]['scores'].append(score) + if score > 0: + domain_stats[domain]['success'] += 1 + + # Overall statistics + total_tasks = len(numeric_scores) + successful_tasks = sum(1 for score in numeric_scores if score > 0) + average_score = sum(numeric_scores) / total_tasks + success_rate = (successful_tasks / total_tasks) * 100 + + logger.info("=" * 60) + logger.info("📊 DART EVALUATION RESULTS SUMMARY") + logger.info("=" * 60) + logger.info(f"📈 Overall Statistics:") + logger.info(f" • Total tasks executed: {total_tasks}") + logger.info(f" • Successful tasks (score > 0): {successful_tasks}") + logger.info(f" • Success rate: {success_rate:.1f}%") + logger.info(f" • Average score: {average_score:.3f}") + + # Domain-specific statistics + if domain_stats and len(domain_stats) > 1: # Only show domain breakdown if multiple domains + logger.info(f"\n🏷️ Domain-specific Results:") + for domain, stats in 
sorted(domain_stats.items()): + domain_success_rate = (stats['success'] / stats['total']) * 100 if stats['total'] > 0 else 0 + domain_avg_score = sum(stats['scores']) / len(stats['scores']) if stats['scores'] else 0 + logger.info(f" • {domain}:") + logger.info(f" - Tasks: {stats['total']}") + logger.info(f" - Successful: {stats['success']}") + logger.info(f" - Success rate: {domain_success_rate:.1f}%") + logger.info(f" - Average score: {domain_avg_score:.3f}") + + # Score distribution + score_ranges = { + 'Perfect (1.0)': sum(1 for s in numeric_scores if s == 1.0), + 'High (0.8-0.99)': sum(1 for s in numeric_scores if 0.8 <= s < 1.0), + 'Medium (0.5-0.79)': sum(1 for s in numeric_scores if 0.5 <= s < 0.8), + 'Low (0.1-0.49)': sum(1 for s in numeric_scores if 0.1 <= s < 0.5), + 'Failed (0.0)': sum(1 for s in numeric_scores if s == 0.0) + } + + logger.info(f"\n📊 Score Distribution:") + for range_name, count in score_ranges.items(): + if count > 0: + percentage = (count / total_tasks) * 100 + logger.info(f" • {range_name}: {count} tasks ({percentage:.1f}%)") + + logger.info("=" * 60) + else: + logger.warning("⚠️ No scores collected during evaluation!") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + 
return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +def clear_cache_directory(): + """清空cache目录中的所有内容""" + cache_dir = "cache" + if os.path.exists(cache_dir): + logger.info(f"Clearing cache directory: {cache_dir}") + try: + import shutil + # 删除整个cache目录 + shutil.rmtree(cache_dir) + # 重新创建空的cache目录 + os.makedirs(cache_dir, exist_ok=True) + logger.info("Cache directory cleared successfully") + except Exception as e: + logger.error(f"Failed to clear cache directory: {e}") + else: + logger.info("Cache directory does not exist, creating it") + os.makedirs(cache_dir, exist_ok=True) + + +def cleanup_docker_containers(): + """清理Docker容器,保留monitor容器""" + logger.info("Cleaning up Docker containers...") + try: + import subprocess + + # 获取所有容器ID,排除monitor-monitor-1 + cmd = 'docker ps --format "{{.ID}} {{.Names}}" | grep -v "monitor-monitor-1" | awk \'{print 
$1}\'' + result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=30) + + if result.returncode == 0 and result.stdout.strip(): + container_ids = result.stdout.strip().split('\n') + container_ids = [cid for cid in container_ids if cid.strip()] + + if container_ids: + logger.info(f"Found {len(container_ids)} containers to remove: {container_ids}") + + # 强制删除容器 + for container_id in container_ids: + try: + rm_result = subprocess.run( + f"docker rm -f {container_id}", + shell=True, + capture_output=True, + text=True, + timeout=10 + ) + if rm_result.returncode == 0: + logger.info(f"Successfully removed container: {container_id}") + else: + logger.warning(f"Failed to remove container {container_id}: {rm_result.stderr}") + except subprocess.TimeoutExpired: + logger.warning(f"Timeout removing container: {container_id}") + except Exception as e: + logger.error(f"Error removing container {container_id}: {e}") + + logger.info("Docker container cleanup completed") + else: + logger.info("No containers found to remove") + else: + logger.info("No containers found or error getting container list") + + except subprocess.TimeoutExpired: + logger.error("Timeout during Docker container cleanup") + except Exception as e: + logger.error(f"Failed to cleanup Docker containers: {e}") + + +if __name__ == "__main__": + ####### Dart Version - Complete evaluation runner ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # 清理Docker容器 + # 清除上一次存留的docker 容器 自己跑的时候要留着 + cleanup_docker_containers() + + # 清空cache目录 清除上一次下载的文件 + clear_cache_directory() + + logger.info("Starting Dart evaluation runner...") + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + 
args.action_space, + args.observation_type, + args.model, + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + 
logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}") \ No newline at end of file