feat/dart_gui (#371)
This commit is contained in:
161
mm_agents/dart_gui/prompts.py
Normal file
161
mm_agents/dart_gui/prompts.py
Normal file
@@ -0,0 +1,161 @@
|
||||
COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
- My computer's password is 'password', feel free to use it when you need sudo rights.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
- My computer's password is 'password', feel free to use it when you need sudo rights.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
UITARS_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
"""
|
||||
|
||||
UITARS_CALL_USR_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
"""
|
||||
|
||||
UITARS_NORMAL_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
"""
|
||||
|
||||
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
## Output Format
|
||||
```
|
||||
Action: ...
|
||||
```
|
||||
## Action Space
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished()
|
||||
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
|
||||
# Substrings (Chinese) that, when present in a final finished(content=...)
# response, cause parsing_response_to_pyautogui_code to mark the episode as
# "FAIL" instead of "DONE". Entries commented out were deliberately excluded.
FAILURE_INDICATORS = [
    # Direct inability expressions
    "无法", "不能", "不可以", "做不到", "实现不了", "完成不了","没法",

    # Regret/apology expressions
    "遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起",

    # Not supported/available
    "不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面","不符合",#"不存在",

    # Cannot access/handle
    "无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也",

    # Not possible/feasible
    "不可能", "无法实现", "实现不了", "办不到", "做不了","找不到","存在技术限制","没有找到","没有内置",

    # System limitations
    "超出范围", "不在我的能力范围", "能力有限", "功能限制","没有成功","没成功","硬件的问题",

    # Refusal indicators
    "拒绝", "不允许", "禁止", "不合适", "不恰当",

    # Trying Restart
    "从头开始", "藏在", "浪费时间","一个更合理的思路","正确的方向","没有意义",#, "重新","重启",
]
|
||||
202
mm_agents/dart_gui/task_loader.py
Normal file
202
mm_agents/dart_gui/task_loader.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
import json
|
||||
import os
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
from omegaconf import DictConfig
|
||||
from dataclasses import dataclass, asdict
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
|
||||
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
|
||||
from log_config import setup_logging
|
||||
|
||||
# 设置统一的日志系统
|
||||
setup_logging()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TaskLoader:
    """Loads OSWorld task definitions from a JSON index file.

    The index file maps task_type -> [task_id, ...]; each referenced task is
    materialized into a TaskInfo via build_task(). When ``resume`` is on,
    tasks that already have a reward.txt under ``storage_root`` are skipped.
    """

    def __init__(self, task_cfg: "DictConfig", storage_root):
        # Path to the JSON index file listing task ids per task type.
        self.task_file = Path(task_cfg.task_file)
        # Root directory holding the per-type task definition JSONs.
        self.osworld_root = Path(task_cfg.osworld_root)

        # SHA-1 of the last index loaded by the legacy polling path.
        self._latest_sha: Optional[str] = None
        self.storage_root = storage_root
        self.resume = task_cfg.resume

    def poll_for_tasks(self) -> List[Dict]:
        """Reload the task index and return the tasks as shuffled dicts.

        Returns a (possibly empty) list of TaskInfo dicts; shuffled so
        repeated runs do not always start with the same task.
        """
        self._maybe_refresh_dataset()

        tasks_list = [task.to_dict() for task in self._tasks]
        random.shuffle(tasks_list)

        return tasks_list

    def _maybe_refresh_dataset_bak(self):
        """Legacy refresh path: watch a directory for the newest index JSON
        and reload only when its content hash changes.

        Returns True when a new index was loaded, False otherwise.
        """
        # check new json
        latest_json = self._find_latest_json()

        if latest_json is None:
            return False  # no json file

        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False  # no change

        # Fix: read explicitly as UTF-8 (task files are UTF-8; the platform
        # default encoding is not guaranteed to be).
        with open(latest_json, encoding="utf-8") as f:
            data = json.load(f)

        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha

        logger.info(f"当前任务文件: {str(latest_json)}")
        logger.info(f"任务总数: {len(raw_tasks)}")

        return True

    def _maybe_refresh_dataset(self):
        """Reload self._tasks from self.task_file; in resume mode, drop
        tasks already completed under storage_root. Always returns True."""
        latest_json = self.task_file
        print("Current tasks file: ", str(latest_json))

        with open(latest_json, encoding="utf-8") as f:
            data = json.load(f)

        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]

        if self.resume:
            # Filter out finished tasks (or stored runs of a different type).
            filtered_tasks = []
            storage_root = Path(self.storage_root)

            for raw in raw_tasks:
                task_id = str(raw["task_id"])
                task_type_expected = raw["task_type"]

                # All subdirectories whose name starts with task_id
                # (multiple runs/versions of the same task may coexist).
                candidate_dirs = [
                    d for d in storage_root.iterdir()
                    if d.is_dir() and d.name.startswith(task_id)
                ]

                # Assume unfinished until a completion record is found.
                task_finished = False

                for d in candidate_dirs:
                    cfg_path = d / "task_config.json"
                    if not cfg_path.exists():
                        print("找不到config文件")
                        continue

                    try:
                        with cfg_path.open("r", encoding="utf-8") as cf:
                            cfg = json.load(cf)
                    except Exception:
                        print("配置损坏,忽略此目录")
                        continue

                    # Different task_type => not the same task; skip dir.
                    if cfg.get("raw", {}).get("task_type") != task_type_expected:
                        continue

                    # Same task_type: a reward.txt marks the task as done.
                    if (d / "reward.txt").exists():
                        task_finished = True
                        break  # completion record found; stop scanning
                if not task_finished:
                    filtered_tasks.append(raw)
            self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}, Remained:{len(filtered_tasks)}")

        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}")

        return True

    def _find_latest_json(self) -> Optional[Path]:
        """Return the most recently modified *.json beside the task file.

        Fix: the original read ``self.task_root``, which was never assigned
        (the assignment in __init__ was commented out), so any call raised
        AttributeError. The directory is now derived from ``task_file``.
        """
        files = list(self.task_file.parent.glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None

    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2<<20) -> str:
        """SHA-1 hex digest of a file, streamed in chunk_size-byte chunks."""
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()
|
||||
|
||||
|
||||
@dataclass
class TaskInfo:
    """A fully prepared task, ready to hand to the agent."""

    # Initial conversation messages (system role + formatted user prompt).
    messages: List
    # Final instruction text (may include an appended human plan hint).
    instruction: str
    # Full task definition loaded from the OSWorld JSON, plus a "raw" key
    # recording {"task_type", "task_id"} provenance.
    task_config: Dict

    def to_dict(self):
        """Return this dataclass as a plain (recursively copied) dict."""
        return asdict(self)
|
||||
|
||||
|
||||
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
    """Build a TaskInfo from a {"task_type": ..., "task_id": ...} record.

    Reads <osworld_root>/<task_type>/<task_id>.json, records provenance under
    the "raw" key (used by the resume filter in TaskLoader), optionally
    appends a human-authored single-action plan to the instruction, and wraps
    everything into the initial chat messages.

    Args:
        raw: dict with "task_type" and "task_id" keys.
        osworld_root: root directory of the OSWorld task-definition JSONs.
        use_call_user: if True, use the prompt variant that exposes the
            call_user() action.

    Returns:
        TaskInfo with the initial messages, final instruction text, and the
        full task config.
    """
    task_type = raw["task_type"]
    task_id = raw["task_id"]
    # Fix: join with pathlib (consistent with the rest of the loader) and
    # read explicitly as UTF-8 so task files parse identically on every
    # platform, instead of relying on the locale's default encoding.
    task_path = Path(osworld_root) / task_type / f"{task_id}.json"
    with open(task_path, encoding="utf-8") as f:
        task_data = json.load(f)

    # Keep provenance so resume logic can match stored runs to index entries.
    task_data["raw"] = {
        "task_type": task_type,
        "task_id": task_id
    }

    instruction = task_data["instruction"]

    # If a human-annotated single-action plan exists, append it as a hint.
    if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]:
        plan = task_data["human-ground-truth"]["single-action"]
        plan_text = "\n".join(plan)
        instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text

    system_prompt = COMPUTER_USE_PROMPT if not use_call_user else COMPUTER_USE_PROMPT_WITH_CALL_USER
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt.format(
                        instruction=instruction,
                        language="English"
                    )}
            ]
        }
    ]

    return TaskInfo(
        messages=messages,
        instruction=instruction,
        task_config=task_data
    )
|
||||
511
mm_agents/dart_gui/utils.py
Normal file
511
mm_agents/dart_gui/utils.py
Normal file
@@ -0,0 +1,511 @@
|
||||
import ast
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import xml.etree.ElementTree as ET
|
||||
from io import BytesIO
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
from requests.exceptions import SSLError
|
||||
from mm_agents.dart_gui.prompts import FAILURE_INDICATORS
|
||||
|
||||
# 设置日志系统
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Sentinel action names recognized in parsed model output.
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

# Image-resizing constraints (Qwen-VL style): dimensions are snapped to
# multiples of IMAGE_FACTOR and the pixel count kept within
# [MIN_PIXELS, MAX_PIXELS] by smart_resize/linear_resize.
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
# Maximum aspect ratio smart_resize accepts before raising ValueError.
MAX_RATIO = 200

# Observation settings that carry no screenshots (text-only).
pure_text_settings = ["a11y_tree"]

# Accessibility-tree XML namespaces, per platform.
# NOTE(review): attributes_ns_ubuntu uses the "windows" host like
# attributes_ns_windows — verify against OSWorld's desktop_env definitions.
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
|
||||
|
||||
# 定义一个函数来解析每个 action
|
||||
def parse_action(action_str):
    """Parse one action-call string, e.g. "click(start_box='(1,2)')".

    Returns {'function': name, 'args': {kwarg: value, ...}} on success, or
    None (after logging) when the string is not a single well-formed call.
    Only keyword arguments are captured; non-constant values map to None and
    positional arguments are ignored.
    """
    try:
        # Parse the string as a single Python expression.
        node = ast.parse(action_str, mode='eval')

        # Must be an expression (always true for mode='eval'; kept as guard).
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")

        call = node.body

        # The expression must be a function call.
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Resolve the function name (plain name or attribute access).
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # Collect keyword arguments, assuming constant values.
        # Fix: dropped the legacy ast.Str branch — ast.Constant has covered
        # string literals since Python 3.8 and ast.Str is deprecated.
        kwargs = {}
        for kw in call.keywords:
            key = kw.arg
            if isinstance(kw.value, ast.Constant):
                value = kw.value.value
            else:
                value = None
            kwargs[key] = value

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        logger.error(f"Failed to parse action '{action_str}': {e}")
        return None
|
||||
|
||||
def escape_single_quotes(text):
    """Backslash-escape every single quote in *text* that is not already
    escaped (quotes preceded by a backslash are left untouched)."""
    # Negative lookbehind: match ' only when the previous char is not '\'.
    return re.sub(r"(?<!\\)'", r"\\'", text)
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' nearest to 'number'."""
    return factor * round(number / factor)
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is >= 'number'."""
    whole = math.ceil(number / factor)
    return whole * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is <= 'number'."""
    whole = math.floor(number / factor)
    return whole * factor
|
||||
|
||||
def linear_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Linearly rescale (height, width) so the pixel count falls inside
    [min_pixels, max_pixels].

    The scale factor is the square root of the pixel ratio, so the aspect
    ratio is preserved and relative coordinates can be reused without any
    conversion. Unlike smart_resize, the result is NOT snapped to a
    multiple of `factor` (the parameter is accepted for API symmetry only).
    """
    if width * height > max_pixels:
        # Too many pixels: shrink uniformly down to the cap.
        scale = math.sqrt(max_pixels / (width * height))
        width, height = int(width * scale), int(height * scale)
    if width * height < min_pixels:
        # Too few pixels: grow uniformly up to the floor.
        scale = math.sqrt(min_pixels / (width * height))
        width, height = math.ceil(width * scale), math.ceil(height * scale)

    return height, width
|
||||
|
||||
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Rescale (height, width) so that:

    1. both dimensions are divisible by 'factor';
    2. the total pixel count lies within ['min_pixels', 'max_pixels'];
    3. the aspect ratio is preserved as closely as possible.

    Raises ValueError when the aspect ratio exceeds MAX_RATIO.
    """
    aspect = max(height, width) / min(height, width)
    if aspect > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {aspect}"
        )
    # Snap each side to the nearest multiple of `factor` (at least `factor`).
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Over the cap: shrink uniformly, rounding down to stay under it.
        scale = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / scale, factor)
        w_bar = floor_by_factor(width / scale, factor)
    elif h_bar * w_bar < min_pixels:
        # Under the floor: grow uniformly, rounding up to reach it.
        scale = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * scale, factor)
        w_bar = ceil_by_factor(width * scale, factor)
    return h_bar, w_bar
|
||||
|
||||
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
    """Parse raw model output ("Thought: ...\\nAction: ...") into structured actions.

    Returns a list of dicts with keys "reflection", "thought", "action_type",
    "action_inputs" and "text" (the full original response). start_box/end_box
    arguments are converted to relative [x1, y1, x2, y2] floats, stored as
    their str() representation.

    Args:
        text: raw model output; must contain an "Action:" section (asserted).
        factor: divisor used to normalize coordinates for non-qwen25vl models.
        origin_resized_height / origin_resized_width: size of the screenshot
            the model saw, before smart resizing.
        model_type: "qwen25vl" emits absolute pixel coordinates (normalized
            by the smart-resized dimensions); other models emit
            factor-scaled relative coordinates.
        max_pixels / min_pixels: bounds forwarded to smart_resize.

    Unparsable sub-actions are logged and skipped rather than raised.
    """
    text = text.strip()
    if model_type == "qwen25vl":
        # qwen2.5-VL coordinates are absolute w.r.t. the smart-resized image,
        # so compute that size for the normalization below.
        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)

    # Pick the regex that captures the thought/reflection prefix.
    if text.startswith("Thought:"):
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    elif text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Reflection: "
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Action_Summary: "
    else:
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            # Reflection pattern: group 1 = reflection, group 2 = summary.
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text
    action_str = text.split("Action:")[-1]

    # Multiple actions are separated by blank lines.
    tmp_all_action = action_str.split("\n\n")
    all_action = []
    for action_str in tmp_all_action:
        if "type(content" in action_str:
            # Extract the payload of type(content='...') and escape any
            # unescaped single quotes so the call later parses as Python.
            def escape_quotes(match):
                content = match.group(1)  # the content value
                return content

            # Unwrap via regex replacement.
            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
            content = re.sub(pattern, escape_quotes, action_str)

            # Re-wrap with escaped quotes.
            action_str = escape_single_quotes(content)
            action_str = "type(content='" + action_str + "')"

        if "finished(content" in action_str:
            # Same quoting treatment for finished(content='...').
            def escape_quotes(match):
                content = match.group(1)  # the content value
                return content

            pattern = r"finished\(content='(.*?)'\)"  # matches finished(content='...')
            content = re.sub(pattern, escape_quotes, action_str)

            action_str = escape_single_quotes(content)
            action_str = "finished(content='" + action_str + "')"
        all_action.append(action_str)

    parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance == None:
            logger.error(f"Action can't parse: {raw_str}")
            # raise ValueError(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]

        # import pdb; pdb.set_trace()
        action_inputs = {}
        for param_name, param in params.items():
            if param == "": continue
            param = param.lstrip()  # drop leading whitespace
            # Keep the raw value; box params are overwritten below.
            action_inputs[param_name.strip()] = param

            # Normalize start_box/end_box parameters of form '(x1,y1[,x2,y2])'.
            if "start_box" in param_name or "end_box" in param_name:
                ori_box = param
                # Remove parentheses and split the string by commas
                numbers = ori_box.replace("(", "").replace(")", "").split(",")

                # Convert to float and scale by 1000
                # Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
                if model_type == "qwen25vl":
                    float_numbers = []
                    for num_idx, num in enumerate(numbers):
                        num = float(num)
                        # Indices 1, 3 are y (divide by height); 0, 2 are x.
                        if (num_idx + 1) % 2 == 0:
                            float_numbers.append(float(num/smart_resize_height))
                        else:
                            float_numbers.append(float(num/smart_resize_width))
                else:
                    float_numbers = [float(num) / factor for num in numbers]

                # A bare point (x, y) becomes the degenerate box (x, y, x, y).
                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)

        # import pdb; pdb.set_trace()
        actions.append(
            {
                "reflection": reflection,
                "thought": thought,
                "action_type": action_type,
                "action_inputs": action_inputs,
                "text": text
            })
    return actions
|
||||
|
||||
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width:int, input_swap:bool=True) -> str:
    '''
    Convert parsed model action dict(s) into a pyautogui code string for
    OSWorld execution.

    Args:
        responses: a single action dict, or a list of them, shaped like:
            {
                "action_type": "hotkey",
                "action_inputs": {
                    "hotkey": "v ctrl",
                    "start_box": None,
                    "end_box": None
                }
            }
        image_height / image_width: screen size used to de-normalize the
            relative box coordinates produced by the parser.
        input_swap: if True, type() is implemented as a clipboard paste
            (pyperclip + ctrl-v) rather than pyautogui.write.

    Returns:
        The generated pyautogui code string, or one of the sentinel strings
        "DONE" / "FAIL" / "WAIT".
    '''

    pyautogui_code = "import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        observation = response["observation"] if "observation" in response else ""
        thought = response["thought"] if "thought" in response else ""

        if response_id == 0:
            # First action: embed observation/thought as a comment block.
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            # Subsequent actions: small pause between them.
            pyautogui_code += "\ntime.sleep(1)\n"

        action_dict = response
        response_text = action_dict.get("text", "")
        action_type = action_dict.get("action_type")
        action_inputs = action_dict.get("action_inputs", {})

        if action_type == "hotkey":
            # The key combo may arrive under "key" or "hotkey".
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            # Normalize arrow-key aliases to pyautogui names.
            if hotkey == "arrowleft":
                hotkey = "left"
            elif hotkey == "arrowright":
                hotkey = "right"
            elif hotkey == "arrowup":
                hotkey = "up"
            elif hotkey == "arrowdown":
                hotkey = "down"

            if hotkey:
                keys = hotkey.split()  # space-separated key names
                convert_keys = []
                for key in keys:
                    if key == "space":
                        key = ' '
                    convert_keys.append(key)
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"

        elif action_type == "press":
            # The key may arrive under "key" or "press".
            if "key" in action_inputs:
                key_to_press = action_inputs.get("key", "")
            else:
                key_to_press = action_inputs.get("press", "")

            # Fix: this normalization chain previously read the undefined
            # variable `hotkey` (copy-paste from the hotkey branch), raising
            # NameError whenever a press action appeared before any hotkey
            # action — and never normalizing key_to_press at all.
            if key_to_press == "arrowleft":
                key_to_press = "left"
            elif key_to_press == "arrowright":
                key_to_press = "right"
            elif key_to_press == "arrowup":
                key_to_press = "up"
            elif key_to_press == "arrowdown":
                key_to_press = "down"
            elif key_to_press == "space":
                key_to_press = " "

            if key_to_press:
                # Simulate pressing a single key.
                pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"

        elif action_type == "keyup":
            key_to_up = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"

        elif action_type == "keydown":
            key_to_down = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"

        elif action_type == "type":
            # Typing via clipboard paste (input_swap) or direct write.
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            stripped_content = content
            # A trailing newline (literal or escaped) means "submit": strip
            # it from the typed text and press Enter afterwards.
            if content.endswith("\n") or content.endswith("\\n"):
                stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
            if content:
                if input_swap:
                    pyautogui_code += "\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
                    pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"

        elif action_type in ["drag", "select"]:
            # Drag or select between the centers of start_box and end_box.
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                # NOTE(review): boxes are str() of lists produced by our own
                # parser; eval() is kept for compatibility, but
                # ast.literal_eval would be safer if input is untrusted.
                x1, y1, x2, y2 = eval(start_box)  # box is [x1, y1, x2, y2]
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = eval(end_box)  # box is [x1, y1, x2, y2]
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )

        elif action_type == "scroll":
            # Scroll at the center of start_box when given, else anywhere.
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = eval(start_box)  # box is [x1, y1, x2, y2]
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
            else:
                x = None
                y = None
            direction = action_inputs.get("direction", "")

            # Only vertical scrolling is emitted; left/right is ignored.
            if x == None:
                if "up" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            # Mouse actions at the center of start_box.
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = eval(start_box)
                if start_box is None:
                    # NOTE(review): when start_box is None the len() below
                    # still raises TypeError; behavior intentionally kept.
                    logger.warning(f"[Warning] start_box is None and wired condition:\n{action_inputs}")

                if len(start_box) == 4:
                    x1, y1, x2, y2 = start_box  # box is [x1, y1, x2, y2]
                elif len(start_box) == 2:
                    # A bare point: treat as a degenerate box.
                    x1, y1 = start_box
                    x2 = x1
                    y2 = y1
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

        elif action_type in ["finished"]:
            # Terminal action: success unless the message contains a known
            # failure phrase.
            pyautogui_code = "DONE"
            print(f"FINISHED:response_text: {response_text}")
            print(f"FINISHED:response: {str(response)}")
            for failure_indicator in FAILURE_INDICATORS:
                if failure_indicator in response_text:
                    pyautogui_code = "FAIL"
                    break
        elif action_type in ["wait"]:
            pyautogui_code = "WAIT"

        elif action_type in ["call_user"]:
            pyautogui_code = "FAIL"
        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code
|
||||
|
||||
def add_box_token(input_string):
    """Wrap coordinate tuples in ``start_box``/``end_box`` arguments with
    <|box_start|>/<|box_end|> tokens.

    Example:
        "Action: click(start_box='(12,34)')" becomes
        "Action: click(start_box='<|box_start|>(12,34)<|box_end|>')"

    Returns:
        A single-element message content list:
        ``[{"type": "text", "text": final_string}]``.
    """
    # Only rewrite strings that actually carry an action with coordinates.
    if "Action: " in input_string and "start_box=" in input_string:
        # Everything up to (and including) the first "Action: " marker.
        prefix = input_string.split("Action: ")[0] + "Action: "
        actions = input_string.split("Action: ")[1:]

        processed_actions = []
        for action in actions:
            action = action.strip()
            # Substitute with backreferences so the original spacing inside
            # the tuple (e.g. "(12, 34)") is preserved.  The previous
            # implementation rebuilt the tuple as "(x,y)", which never matched
            # -- and therefore silently skipped -- coordinates written with a
            # space after the comma, even though the search regex allowed one.
            updated_action = re.sub(
                r"(start_box|end_box)='(\(\d+,\s*\d+\))'",
                r"\1='<|box_start|>\2<|box_end|>'",
                action,
            )
            processed_actions.append(updated_action)

        # Reconstruct the final string.
        final_string = prefix + "\n\n".join(processed_actions)
    else:
        final_string = input_string

    return [{"type": "text", "text": final_string}]
|
||||
|
||||
def pil_to_base64(image):
    """Convert a PIL Image (or raw PNG/image bytes) into a base64 string."""
    if isinstance(image, bytes):
        # Raw bytes need no conversion -- encode them directly.
        return base64.b64encode(image).decode("utf-8")
    # Serialize the PIL image to PNG in memory, then encode the buffer.
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
|
||||
686
mm_agents/dart_gui_agent.py
Normal file
686
mm_agents/dart_gui_agent.py
Normal file
@@ -0,0 +1,686 @@
|
||||
"""
|
||||
Dart Agent - Custom agent for GUI automation using Dart models
|
||||
Based on UITARSAgent structure but using Dart-specific utilities and prompts
|
||||
"""
|
||||
import ast
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Dict, List, Any
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
import backoff
|
||||
import openai
|
||||
import requests
|
||||
from requests.exceptions import SSLError
|
||||
from google.api_core.exceptions import (
|
||||
BadRequest,
|
||||
InternalServerError,
|
||||
InvalidArgument,
|
||||
ResourceExhausted,
|
||||
)
|
||||
|
||||
# Import Dart-specific utilities and prompts
|
||||
from mm_agents.dart_gui.utils import (
|
||||
pil_to_base64,
|
||||
parse_action_to_structure_output,
|
||||
parsing_response_to_pyautogui_code,
|
||||
parse_action,
|
||||
escape_single_quotes,
|
||||
round_by_factor,
|
||||
ceil_by_factor,
|
||||
floor_by_factor,
|
||||
linear_resize,
|
||||
smart_resize,
|
||||
add_box_token,
|
||||
IMAGE_FACTOR,
|
||||
MIN_PIXELS,
|
||||
MAX_PIXELS,
|
||||
MAX_RATIO,
|
||||
FINISH_WORD,
|
||||
WAIT_WORD,
|
||||
ENV_FAIL_WORD,
|
||||
CALL_USER
|
||||
)
|
||||
|
||||
from mm_agents.dart_gui.prompts import (
|
||||
COMPUTER_USE_PROMPT,
|
||||
COMPUTER_USE_PROMPT_WITH_CALL_USER,
|
||||
UITARS_ACTION_SPACE,
|
||||
UITARS_CALL_USR_ACTION_SPACE,
|
||||
UITARS_USR_PROMPT_THOUGHT,
|
||||
UITARS_USR_PROMPT_NOTHOUGHT
|
||||
)
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
class DartAgent:
    """GUI-automation agent driving a Dart vision-language model.

    Mirrors the UITARSAgent structure: keeps a trimmed message history
    (screenshots + assistant turns) and converts model responses into
    pyautogui action strings.
    """

    def __init__(
        self,
        model: str,
        runtime_conf: Dict,
        platform="ubuntu",
        max_tokens=1000,
        top_p=0.9,
        top_k=1.0,
        temperature=0.0,
        action_space="pyautogui",
        observation_type="screenshot",
        max_trajectory_length=50,
        model_type="qwen25vl",
        **kwargs
    ):
        """Configure the agent; entries in ``runtime_conf`` override the keyword defaults."""
        self.model = model
        self.platform = platform
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.model_type = model_type
        self.runtime_conf = runtime_conf

        # Extract runtime configuration parameters
        self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
        self.top_p = self.runtime_conf.get("top_p", top_p)
        self.top_k = self.runtime_conf.get("top_k", top_k)
        self.temperature = self.runtime_conf.get("temperature", temperature)
        self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
        self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
        self.input_swap = self.runtime_conf.get("input_swap", False)
        self.language = self.runtime_conf.get("language", "English")
        self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
        self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
        self.history_n = self.runtime_conf.get("history_n", 5)

        # Dart specific configurations
        self.max_images = self.runtime_conf.get("max_images", 5)
        self.max_texts = self.runtime_conf.get("max_texts", 35)

        # Initialize OpenAI client - use Dart API if provided
        dart_api_key = self.runtime_conf.get("dart_api_key", "")
        dart_base_url = self.runtime_conf.get("dart_base_url", "")

        if dart_base_url:
            # Check whether this is a direct generation endpoint (URL contains /generate)
            if '/generate' in dart_base_url:
                # Use the provided URL as-is; do not append /v1
                logger.info(f"使用直接生成端点: {dart_base_url}")
                self.dart_direct_url = dart_base_url
                self.vlm = None  # The OpenAI client is not used in this mode
            else:
                # Conventional OpenAI-compatible endpoint; ensure it ends with /v1
                if not dart_base_url.endswith('/v1'):
                    dart_base_url = dart_base_url.rstrip('/') + '/v1'

                self.vlm = OpenAI(
                    base_url=dart_base_url,
                    api_key=dart_api_key,
                )
                self.dart_direct_url = None
        else:
            # Fallback to environment variables
            base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
            if base_url:
                if '/generate' in base_url:
                    # Direct generation endpoint
                    self.dart_direct_url = base_url
                    self.vlm = None
                else:
                    if not base_url.endswith('/v1'):
                        base_url = base_url.rstrip('/') + '/v1'
                    self.vlm = OpenAI(
                        base_url=base_url,
                        api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
                    )
                    self.dart_direct_url = None
            else:
                # No endpoint configured at all
                self.vlm = None
                self.dart_direct_url = None

        # Initialize trajectory storage - similar to trajectory_runner.py
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        # Message handling similar to trajectory_runner.py
        self.base_messages = []  # for model client (with base64 images)
        self.base_messages_for_save = []  # for storage (with file paths)
        self.prompt_dialogue = []  # for model client
        self.save_dialogue = []  # for storage (trimmed alongside prompt_dialogue)
        self.save_dialogue_full = []  # for full storage (keeps every image path)
        self.image_refs = []  # record image position

        # All image paths storage - to keep track of all images even when trimmed
        self.all_image_paths = []

        # Current screenshot file path for proper saving
        self.current_screenshot_path = None

        # Configure prompt and action space based on mode
        if self.infer_mode == "dart_mode":
            self.prompt_action_space = UITARS_ACTION_SPACE
            self.prompt_template = COMPUTER_USE_PROMPT
        else:
            # For qwen2vl_user mode
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
            if self.prompt_style == "qwen2vl_user":
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT
            elif self.prompt_style == "qwen2vl_no_thought":
                self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
            else:
                # Unknown style: fall back to the thought-style prompt
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        # Model coordinates are expressed on a 0-1000 grid (see _parse_and_convert_actions)
        self.action_parse_res_factor = 1000

        logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
|
||||
|
||||
def reset(self, runtime_logger=None):
|
||||
"""Reset the agent state"""
|
||||
self.thoughts = []
|
||||
self.actions = []
|
||||
self.observations = []
|
||||
self.history_images = []
|
||||
self.history_responses = []
|
||||
|
||||
# Reset message handling
|
||||
self.base_messages = []
|
||||
self.base_messages_for_save = []
|
||||
self.prompt_dialogue = []
|
||||
self.save_dialogue = []
|
||||
self.save_dialogue_full = []
|
||||
self.image_refs = []
|
||||
self.all_image_paths = []
|
||||
self.current_screenshot_path = None
|
||||
|
||||
logger.info("DartAgent reset")
|
||||
|
||||
    def set_base_messages(self, instruction: str):
        """Initialize base messages similar to task_loader.py.

        Builds the system turn plus the first user turn containing the
        formatted task prompt; the first screenshot is appended separately
        afterwards.
        """
        # NOTE(review): this always uses COMPUTER_USE_PROMPT even though
        # __init__ selects self.prompt_template based on infer_mode /
        # prompt_style -- confirm whether the configured template should be
        # honored here instead.
        system_prompt = COMPUTER_USE_PROMPT

        self.base_messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": system_prompt.format(
                            instruction=instruction,
                            language=self.language
                        )
                    }
                ]
            }
        ]

        # Copy for save version (deep copy so later image appends don't alias)
        from copy import deepcopy
        self.base_messages_for_save = deepcopy(self.base_messages)
|
||||
|
||||
    def set_current_screenshot_path(self, screenshot_path: str):
        """Set the current screenshot file path for proper saving.

        This path is later recorded in the save-side dialogues in place of
        the base64 payload that is sent to the model.
        """
        self.current_screenshot_path = screenshot_path
|
||||
|
||||
    def predict(
        self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
    ) -> tuple:
        """
        Predict the next action(s) based on the current observation.

        Args:
            instruction: natural-language task description.
            obs: observation dict; must contain a "screenshot" entry.
            last_action_after_obs: unused here; kept for interface parity.

        Returns: (response_text, actions_list)
        """
        # Initialize base messages if not set
        if not self.base_messages:
            self.set_base_messages(instruction)

        # Store current observation
        self._add_observation(obs)

        # For first step, set the first frame
        if len(self.observations) == 1:
            self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
        else:
            # For subsequent steps, add the new image to dialogue
            # This represents the result of the previous action
            self._add_image(obs["screenshot"], self.current_screenshot_path)

        # Build prompt messages (base_messages + prompt_dialogue)
        messages = self._build_messages()

        # Call model to get response
        prediction = self._call_model(messages)
        if prediction is None:
            # Client kept failing after retries; end the episode.
            return "client error", ["DONE"]

        # Store response and parse actions
        self._add_text(prediction)

        # Parse response to actions
        try:
            image_size = self._get_current_image_size()
            actions = self._parse_and_convert_actions(prediction, image_size)

            # Check for terminal actions
            terminal_action = self._check_terminal_actions(actions)
            if terminal_action:
                self.actions.append(actions)
                return prediction, [terminal_action]

        except Exception as e:
            logger.error(f"Parsing action error: {prediction}, error: {e}")
            return f"Parsing action error: {prediction}, error: {e}", ["DONE"]

        self.actions.append(actions)
        # Check max steps
        if len(self.history_responses) >= self.max_trajectory_length:
            # Trajectory budget exhausted: force a failure signal.
            actions = ["FAIL"]

        return prediction, actions
|
||||
|
||||
    @backoff.on_exception(
        backoff.constant,
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
        ),
        interval=30,
        max_tries=10,
    )
    def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
        """Predict with backoff for rate limiting and temporary errors.

        Retries ``predict`` at a fixed 30s interval, up to 10 tries, whenever
        one of the transient client/server errors listed above is raised.
        """
        return self.predict(instruction, obs, last_action_after_obs)
|
||||
|
||||
def get_trajectory(self) -> List[Dict]:
|
||||
"""Get the current trajectory for saving"""
|
||||
trajectory = []
|
||||
for i in range(len(self.observations)):
|
||||
trajectory.append({
|
||||
"observation": self.observations[i],
|
||||
"thought": self.thoughts[i] if i < len(self.thoughts) else "",
|
||||
"action": self.actions[i] if i < len(self.actions) else []
|
||||
})
|
||||
return trajectory
|
||||
|
||||
def get_full_messages(self) -> List[Dict]:
|
||||
"""Get the complete conversation messages for saving (including base messages and dialogue)"""
|
||||
# Combine base_messages_for_save with save_dialogue_full to get complete conversation
|
||||
full_messages = []
|
||||
|
||||
# Add base messages (system prompt and initial user message)
|
||||
full_messages.extend(self.base_messages_for_save)
|
||||
|
||||
# Add dialogue messages (user images + assistant responses) with all images
|
||||
full_messages.extend(self.save_dialogue_full)
|
||||
|
||||
return full_messages
|
||||
|
||||
def get_all_image_paths(self) -> List[str]:
|
||||
"""Get all image paths that have been used throughout the conversation"""
|
||||
return self.all_image_paths.copy()
|
||||
|
||||
# ========== Private Methods ==========
|
||||
|
||||
def _validate_trajectory(self):
|
||||
"""Validate trajectory consistency"""
|
||||
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
|
||||
self.thoughts
|
||||
), "The number of observations and actions should be the same."
|
||||
|
||||
def _add_observation(self, obs: Dict):
|
||||
"""Process observation and add to history"""
|
||||
# Store observation
|
||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||
base64_image = obs["screenshot"]
|
||||
try:
|
||||
# Handle accessibility tree if needed
|
||||
linearized_accessibility_tree = None
|
||||
if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
|
||||
# For now, we'll skip accessibility tree processing in Dart mode
|
||||
linearized_accessibility_tree = None
|
||||
except:
|
||||
linearized_accessibility_tree = None
|
||||
|
||||
if self.observation_type == "screenshot_a11y_tree":
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": linearized_accessibility_tree,
|
||||
})
|
||||
else:
|
||||
self.observations.append({
|
||||
"screenshot": base64_image,
|
||||
"accessibility_tree": None
|
||||
})
|
||||
else:
|
||||
raise ValueError("Invalid observation_type type: " + self.observation_type)
|
||||
|
||||
|
||||
def _build_messages(self) -> List[Dict]:
|
||||
"""Build messages for model API call - similar to trajectory_runner._build_messages"""
|
||||
return self.base_messages + self.prompt_dialogue
|
||||
|
||||
    def _call_model(self, messages: List[Dict]) -> str:
        """Call the model with retry logic.

        Tries up to three times; returns the response text, or None when
        every attempt fails.
        """
        try_times = 3
        while try_times > 0:
            try:
                # When a direct generation endpoint is configured, bypass the
                # OpenAI client and POST to it directly.
                if hasattr(self, 'dart_direct_url') and self.dart_direct_url:
                    prediction = self._call_direct_generate_endpoint(messages)
                else:
                    # Standard OpenAI-compatible client path.
                    response = self.vlm.chat.completions.create(
                        model=self.model,
                        messages=messages,
                        frequency_penalty=1,
                        max_tokens=self.max_tokens,
                        temperature=self.temperature,
                        top_p=self.top_p
                    )
                    prediction = response.choices[0].message.content

                logger.info(f"Model response: {prediction}")
                return prediction

            except Exception as e:
                logger.error(f"Error when fetching response from client: {e}")
                try_times -= 1
                if try_times <= 0:
                    logger.error("Reach max retry times to fetch response from client")
                    return None
        return None
|
||||
|
||||
def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
|
||||
"""直接调用生成端点"""
|
||||
try:
|
||||
|
||||
|
||||
# 构建请求数据
|
||||
payload = {
|
||||
"messages": messages,
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"frequency_penalty": 1
|
||||
}
|
||||
|
||||
# 添加 API key 到 headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
|
||||
}
|
||||
|
||||
|
||||
# 重试机制:最多重试3次,每次推理60秒
|
||||
max_retries = 3
|
||||
response = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(f"尝试第 {attempt + 1} 次请求...")
|
||||
response = requests.post(
|
||||
self.dart_direct_url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=60
|
||||
)
|
||||
response.raise_for_status()
|
||||
break # 成功则跳出重试循环
|
||||
except Exception as e:
|
||||
logger.warning(f"第 {attempt + 1} 次请求失败: {e}")
|
||||
if attempt == max_retries - 1: # 最后一次重试失败
|
||||
logger.error(f"所有 {max_retries} 次重试都失败了")
|
||||
raise e
|
||||
else:
|
||||
logger.info(f"等待后重试...")
|
||||
import time
|
||||
time.sleep(2) # 等待2秒后重试
|
||||
|
||||
# 解析响应
|
||||
result = response.json()
|
||||
|
||||
# 尝试多种可能的响应格式
|
||||
if 'choices' in result and len(result['choices']) > 0:
|
||||
# OpenAI 兼容格式
|
||||
return result['choices'][0]['message']['content']
|
||||
elif 'response' in result:
|
||||
# 简单的 response 字段
|
||||
return result['response']
|
||||
elif 'text' in result:
|
||||
# text 字段
|
||||
return result['text']
|
||||
elif 'content' in result:
|
||||
# content 字段
|
||||
return result['content']
|
||||
else:
|
||||
# 如果找不到标准字段,返回整个响应的字符串
|
||||
logger.warning(f"未知的响应格式: {result}")
|
||||
return str(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"直接端点调用失败: {e}")
|
||||
raise e
|
||||
|
||||
    def _add_text(self, assistant_txt: str):
        """Add an assistant text response to history - similar to trajectory_runner.py."""
        self.history_responses.append(assistant_txt)
        self.thoughts.append(assistant_txt)

        # Add to dialogue similar to trajectory_runner._add_text.
        # add_box_token wraps coordinates with <|box_start|>/<|box_end|> and
        # returns the message content list.
        msg = {
            "role": "assistant",
            "content": add_box_token(assistant_txt)
        }
        # The same dict object is shared by all three dialogues; _trim only
        # pops whole messages and never mutates them, so aliasing is safe.
        self.prompt_dialogue.append(msg)
        self.save_dialogue.append(msg)
        self.save_dialogue_full.append(msg)
        self._trim()
|
||||
|
||||
    def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
        """Attach the first screenshot to the initial user message.

        The model-facing copy (base_messages) gets the base64 payload; the
        save-facing copy records the file path instead.  Similar to
        trajectory_runner._set_first_frame.
        """
        self.base_messages[1]["content"].append(
            {
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
            }
        )

        # Use actual frame path if provided, otherwise use current_screenshot_path or placeholder
        if frame_path:
            first_frame_path = frame_path
        elif self.current_screenshot_path:
            first_frame_path = self.current_screenshot_path
        else:
            first_frame_path = "first_frame.png"

        # Store in all_image_paths
        self.all_image_paths.append(first_frame_path)

        self.base_messages_for_save[1]["content"].append(
            {
                "type": "image_url",
                "image_url": first_frame_path
            }
        )

        # Record where this image lives so _trim can drop it oldest-first.
        self.image_refs.append(
            {"source": "base", "msg_idx": 1,
             "content_idx": len(self.base_messages[1]["content"]) - 1}
        )
|
||||
|
||||
    def _add_image(self, img_bytes: bytes, frame_path: str = None):
        """Append a new screenshot as a user turn - similar to trajectory_runner._add_image.

        The model-facing dialogue stores the base64 payload; both save-side
        dialogues store the file path (save_dialogue is trimmed together with
        the prompt, save_dialogue_full never is).
        """
        self.prompt_dialogue.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": {"url": "data:image/png;base64," + pil_to_base64(img_bytes)}
            }]
        })

        # Use actual frame path if provided, otherwise use current_screenshot_path
        if frame_path:
            image_url = frame_path
        elif self.current_screenshot_path:
            image_url = self.current_screenshot_path
        else:
            # Fallback to a placeholder - this should rarely happen in practice
            image_url = f"frame_{len(self.save_dialogue)}.png"

        # Store in all_image_paths for complete record
        self.all_image_paths.append(image_url)

        # Add to save_dialogue (trimmed version)
        self.save_dialogue.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": image_url
            }]
        })

        # Add to save_dialogue_full (complete version - never trimmed)
        self.save_dialogue_full.append({
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": image_url
            }]
        })

        # Track the image's dialogue position so _trim can drop it oldest-first.
        self.image_refs.append(
            {"source": "dialogue", "msg_idx": len(self.prompt_dialogue) - 1,
             "content_idx": None}
        )

        self._trim()
|
||||
|
||||
    def _trim(self):
        """Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim."""
        img_cnt = len(self.image_refs)
        txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)

        while img_cnt > self.max_images or txt_cnt > self.max_texts:
            # Too many images: drop the oldest one first.
            if img_cnt > self.max_images:
                ref = self.image_refs.pop(0)
                if ref["source"] == "base":
                    # Image attached to the initial user message.
                    self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
                else:  # Image carried by a dialogue message.
                    self._remove_dialogue_msg(ref["msg_idx"])
                img_cnt -= 1
                continue

            # Too many texts: drop the earliest assistant message.
            if txt_cnt > self.max_texts:
                for i, m in enumerate(self.prompt_dialogue):
                    if m["role"] == "assistant":
                        self._remove_dialogue_msg(i)
                        txt_cnt -= 1
                        break
|
||||
|
||||
    def _remove_dialogue_msg(self, idx: int):
        """Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg."""
        self.prompt_dialogue.pop(idx)
        self.save_dialogue.pop(idx)
        # Note: save_dialogue_full is never trimmed, so we don't remove from it

        # Update image_refs in two passes:
        # 1) mark refs that pointed at the removed message,
        self.image_refs = [
            r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
            else None  # references into the deleted message are discarded
            for r in self.image_refs
        ]
        # 2) shift msg_idx down by one for dialogue refs after idx, and drop
        #    the None placeholders produced in step 1.
        self.image_refs = [
            (
                {**r, "msg_idx": r["msg_idx"] - 1}
                if r and r["source"] == "dialogue" and r["msg_idx"] > idx
                else r
            )
            for r in self.image_refs
            if r  # filter out the None placeholders
        ]
|
||||
|
||||
def _get_current_image_size(self) -> tuple:
|
||||
"""Get current image size for coordinate conversion"""
|
||||
if len(self.observations) > 0:
|
||||
try:
|
||||
current_image_bytes = self.observations[-1]["screenshot"]
|
||||
if isinstance(current_image_bytes, bytes):
|
||||
current_image = Image.open(BytesIO(current_image_bytes))
|
||||
return (current_image.height, current_image.width)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting image size: {e}")
|
||||
|
||||
# Fallback to default screen size
|
||||
return (1080, 1920)
|
||||
|
||||
    def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
        """Parse response and convert to pyautogui actions - similar to trajectory_runner._parse.

        Args:
            prediction: raw model response text.
            image_size: (height, width) of the current screenshot.

        Returns:
            A list of pyautogui code strings ("FAIL" for entries that could
            not be converted).
        """
        image_height, image_width = image_size

        # Parse the response to structured actions
        parsed_responses = parse_action_to_structure_output(
            prediction,
            factor=self.action_parse_res_factor,
            origin_resized_height=image_height,
            origin_resized_width=image_width,
            model_type=self.model_type,
            max_pixels=self.max_pixels,
            min_pixels=self.min_pixels
        )

        # Convert parsed responses to pyautogui actions
        actions = []
        for parsed_response in parsed_responses:
            try:
                pyautogui_code = parsing_response_to_pyautogui_code(
                    parsed_response,
                    image_height=image_height,
                    image_width=image_width,
                    input_swap=self.input_swap
                )

                actions.append(pyautogui_code)

            except Exception as e:
                # One bad action shouldn't abort the rest of the batch.
                logger.error(f"Error generating pyautogui code: {e}")
                actions.append("FAIL")

        return actions
|
||||
|
||||
|
||||
|
||||
def _check_terminal_actions(self, actions: List[str]) -> str:
|
||||
"""Check if any action is terminal and return appropriate code"""
|
||||
for action in actions:
|
||||
if isinstance(action, dict) and "action_type" in action:
|
||||
action_type = action["action_type"]
|
||||
if action_type == FINISH_WORD:
|
||||
return "DONE"
|
||||
elif action_type == WAIT_WORD:
|
||||
return "WAIT"
|
||||
elif action_type == ENV_FAIL_WORD:
|
||||
return "FAIL"
|
||||
elif action_type == CALL_USER:
|
||||
return "FAIL"
|
||||
return None
|
||||
Reference in New Issue
Block a user