# sci-gui-agent-benchmark/mm_agents/dart_gui_agent.py
"""
Dart Agent - Custom agent for GUI automation using Dart models
Based on UITARSAgent structure but using Dart-specific utilities and prompts
"""
import logging
import os
import time
from io import BytesIO
from typing import Dict, List
from PIL import Image
from openai import OpenAI
import backoff
import openai
import requests
from requests.exceptions import SSLError
from google.api_core.exceptions import (
BadRequest,
InternalServerError,
InvalidArgument,
ResourceExhausted,
)
# Import Dart-specific utilities and prompts
from mm_agents.dart_gui.utils import (
    pil_to_base64,
    parse_action_to_structure_output,
    parsing_response_to_pyautogui_code,
    add_box_token,
    MIN_PIXELS,
    MAX_PIXELS,
    FINISH_WORD,
    WAIT_WORD,
    ENV_FAIL_WORD,
    CALL_USER
)
from mm_agents.dart_gui.prompts import (
COMPUTER_USE_PROMPT,
COMPUTER_USE_PROMPT_WITH_CALL_USER,
UITARS_ACTION_SPACE,
UITARS_CALL_USR_ACTION_SPACE,
UITARS_USR_PROMPT_THOUGHT,
UITARS_USR_PROMPT_NOTHOUGHT
)
logger = logging.getLogger("desktopenv.agent")
class DartAgent:
def __init__(
self,
model: str,
runtime_conf: Dict,
platform="ubuntu",
max_tokens=1000,
top_p=0.9,
top_k=1.0,
temperature=0.0,
action_space="pyautogui",
observation_type="screenshot",
max_trajectory_length=50,
model_type="qwen25vl",
**kwargs
):
self.model = model
self.platform = platform
self.action_space = action_space
self.observation_type = observation_type
self.max_trajectory_length = max_trajectory_length
self.model_type = model_type
self.runtime_conf = runtime_conf
# Extract runtime configuration parameters
self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
self.top_p = self.runtime_conf.get("top_p", top_p)
self.top_k = self.runtime_conf.get("top_k", top_k)
self.temperature = self.runtime_conf.get("temperature", temperature)
self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
self.input_swap = self.runtime_conf.get("input_swap", False)
self.language = self.runtime_conf.get("language", "English")
self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
self.history_n = self.runtime_conf.get("history_n", 5)
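        # NOTE: history_n is read from runtime_conf but unused in this file;
        # history trimming is governed by max_images / max_texts below.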
# Dart specific configurations
self.max_images = self.runtime_conf.get("max_images", 5)
self.max_texts = self.runtime_conf.get("max_texts", 35)
# Initialize OpenAI client - use Dart API if provided
dart_api_key = self.runtime_conf.get("dart_api_key", "")
dart_base_url = self.runtime_conf.get("dart_base_url", "")
if dart_base_url:
            # Check whether this is a direct generate endpoint (URL contains /generate)
            if '/generate' in dart_base_url:
                # Use the provided URL as-is; do not append /v1
                logger.info(f"Using direct generate endpoint: {dart_base_url}")
                self.dart_direct_url = dart_base_url
                self.vlm = None  # the OpenAI client is not used in this mode
            else:
                # Conventional OpenAI-compatible endpoint; ensure it ends with /v1
if not dart_base_url.endswith('/v1'):
dart_base_url = dart_base_url.rstrip('/') + '/v1'
self.vlm = OpenAI(
base_url=dart_base_url,
api_key=dart_api_key,
)
self.dart_direct_url = None
else:
# Fallback to environment variables
base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
if base_url:
if '/generate' in base_url:
                    # Direct generate endpoint
self.dart_direct_url = base_url
self.vlm = None
else:
if not base_url.endswith('/v1'):
base_url = base_url.rstrip('/') + '/v1'
self.vlm = OpenAI(
base_url=base_url,
api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
)
self.dart_direct_url = None
else:
self.vlm = None
self.dart_direct_url = None
# Initialize trajectory storage - similar to trajectory_runner.py
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
# Message handling similar to trajectory_runner.py
self.base_messages = [] # for model client (with base64 images)
self.base_messages_for_save = [] # for storage (with file paths)
self.prompt_dialogue = [] # for model client
self.save_dialogue = [] # for storage
        self.save_dialogue_full = []  # for full storage (keeps every image path, never trimmed)
self.image_refs = [] # record image position
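        # Each ref has the form {"source": "base" | "dialogue", "msg_idx": int,
        # "content_idx": int or None}; see _set_first_frame and _add_image.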
# All image paths storage - to keep track of all images even when trimmed
self.all_image_paths = []
# Current screenshot file path for proper saving
self.current_screenshot_path = None
# Configure prompt and action space based on mode
if self.infer_mode == "dart_mode":
self.prompt_action_space = UITARS_ACTION_SPACE
self.prompt_template = COMPUTER_USE_PROMPT
else:
# For qwen2vl_user mode
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
if self.prompt_style == "qwen2vl_user":
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
elif self.prompt_style == "qwen2vl_no_thought":
self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
else:
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
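        # The model emits coordinates on a 0-1000 normalized grid; they are
        # rescaled to pixel coordinates in _parse_and_convert_actions.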
self.action_parse_res_factor = 1000
logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
def reset(self, runtime_logger=None):
"""Reset the agent state"""
self.thoughts = []
self.actions = []
self.observations = []
self.history_images = []
self.history_responses = []
# Reset message handling
self.base_messages = []
self.base_messages_for_save = []
self.prompt_dialogue = []
self.save_dialogue = []
self.save_dialogue_full = []
self.image_refs = []
self.all_image_paths = []
self.current_screenshot_path = None
logger.info("DartAgent reset")
    def set_base_messages(self, instruction: str):
        """Initialize base messages similar to task_loader.py"""
        # Use the prompt template selected in __init__ (COMPUTER_USE_PROMPT by default)
        system_prompt = self.prompt_template
self.base_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": [
{
"type": "text",
"text": system_prompt.format(
instruction=instruction,
language=self.language
)
}
]
}
]
# Copy for save version
from copy import deepcopy
self.base_messages_for_save = deepcopy(self.base_messages)
def set_current_screenshot_path(self, screenshot_path: str):
"""Set the current screenshot file path for proper saving"""
self.current_screenshot_path = screenshot_path
def predict(
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
) -> tuple:
"""
Predict the next action(s) based on the current observation.
Returns: (response_text, actions_list)
"""
# Initialize base messages if not set
if not self.base_messages:
self.set_base_messages(instruction)
# Store current observation
self._add_observation(obs)
# For first step, set the first frame
if len(self.observations) == 1:
self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
else:
# For subsequent steps, add the new image to dialogue
# This represents the result of the previous action
self._add_image(obs["screenshot"], self.current_screenshot_path)
# Build prompt messages (base_messages + prompt_dialogue)
messages = self._build_messages()
# Call model to get response
prediction = self._call_model(messages)
if prediction is None:
return "client error", ["DONE"]
# Store response and parse actions
self._add_text(prediction)
        # Parse response to actions
        try:
            image_size = self._get_current_image_size()
            actions, parsed_responses = self._parse_and_convert_actions(prediction, image_size)
            # Check for terminal actions on the structured (dict) output,
            # not on the generated pyautogui code strings
            terminal_action = self._check_terminal_actions(parsed_responses)
            if terminal_action:
                self.actions.append(actions)
                return prediction, [terminal_action]
        except Exception as e:
            logger.error(f"Parsing action error: {prediction}, error: {e}")
            return f"Parsing action error: {prediction}, error: {e}", ["DONE"]
self.actions.append(actions)
# Check max steps
if len(self.history_responses) >= self.max_trajectory_length:
actions = ["FAIL"]
return prediction, actions
@backoff.on_exception(
backoff.constant,
(
# General exceptions
SSLError,
# OpenAI exceptions
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
# Google exceptions
InvalidArgument,
ResourceExhausted,
InternalServerError,
BadRequest,
),
interval=30,
max_tries=10,
)
def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
"""Predict with backoff for rate limiting and temporary errors"""
return self.predict(instruction, obs, last_action_after_obs)
def get_trajectory(self) -> List[Dict]:
"""Get the current trajectory for saving"""
trajectory = []
for i in range(len(self.observations)):
trajectory.append({
"observation": self.observations[i],
"thought": self.thoughts[i] if i < len(self.thoughts) else "",
"action": self.actions[i] if i < len(self.actions) else []
})
return trajectory
def get_full_messages(self) -> List[Dict]:
"""Get the complete conversation messages for saving (including base messages and dialogue)"""
# Combine base_messages_for_save with save_dialogue_full to get complete conversation
full_messages = []
# Add base messages (system prompt and initial user message)
full_messages.extend(self.base_messages_for_save)
# Add dialogue messages (user images + assistant responses) with all images
full_messages.extend(self.save_dialogue_full)
return full_messages
def get_all_image_paths(self) -> List[str]:
"""Get all image paths that have been used throughout the conversation"""
return self.all_image_paths.copy()
# ========== Private Methods ==========
def _validate_trajectory(self):
"""Validate trajectory consistency"""
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
            self.thoughts
        ), "The numbers of observations, actions, and thoughts should be equal."
def _add_observation(self, obs: Dict):
"""Process observation and add to history"""
# Store observation
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = obs["screenshot"]
try:
# Handle accessibility tree if needed
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
# For now, we'll skip accessibility tree processing in Dart mode
linearized_accessibility_tree = None
            except Exception:
                linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree,
})
else:
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": None
})
else:
raise ValueError("Invalid observation_type type: " + self.observation_type)
def _build_messages(self) -> List[Dict]:
"""Build messages for model API call - similar to trajectory_runner._build_messages"""
return self.base_messages + self.prompt_dialogue
def _call_model(self, messages: List[Dict]) -> str:
"""Call model with retry logic"""
try_times = 3
while try_times > 0:
try:
                # Use the direct generate endpoint if one was configured
if hasattr(self, 'dart_direct_url') and self.dart_direct_url:
prediction = self._call_direct_generate_endpoint(messages)
else:
                    # Otherwise use the standard OpenAI-compatible client
response = self.vlm.chat.completions.create(
model=self.model,
messages=messages,
frequency_penalty=1,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p
)
prediction = response.choices[0].message.content
logger.info(f"Model response: {prediction}")
return prediction
except Exception as e:
logger.error(f"Error when fetching response from client: {e}")
try_times -= 1
if try_times <= 0:
logger.error("Reach max retry times to fetch response from client")
return None
return None
def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
"""直接调用生成端点"""
try:
            # Build the request payload
payload = {
"messages": messages,
"model": self.model,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"top_p": self.top_p,
"frequency_penalty": 1
}
            # Attach the API key to the request headers
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
}
            # Retry logic: up to 3 attempts, 60-second timeout per request
max_retries = 3
response = None
for attempt in range(max_retries):
try:
logger.info(f"尝试第 {attempt + 1} 次请求...")
response = requests.post(
self.dart_direct_url,
json=payload,
headers=headers,
timeout=60
)
response.raise_for_status()
                    break  # success: exit the retry loop
                except Exception as e:
                    logger.warning(f"Request attempt {attempt + 1} failed: {e}")
                    if attempt == max_retries - 1:  # the final attempt also failed
                        logger.error(f"All {max_retries} attempts failed")
                        raise
                    else:
                        logger.info("Waiting before retry...")
                        time.sleep(2)  # wait 2 seconds before retrying
            # Parse the response
            result = response.json()
            # Try several possible response formats
            if 'choices' in result and len(result['choices']) > 0:
                # OpenAI-compatible format
                return result['choices'][0]['message']['content']
            elif 'response' in result:
                # plain "response" field
                return result['response']
            elif 'text' in result:
                # "text" field
                return result['text']
            elif 'content' in result:
                # "content" field
                return result['content']
            else:
                # No recognized field; fall back to the stringified response
                logger.warning(f"Unknown response format: {result}")
                return str(result)
        except Exception as e:
            logger.error(f"Direct endpoint call failed: {e}")
            raise
def _add_text(self, assistant_txt: str):
"""Add text response to history - similar to trajectory_runner.py"""
self.history_responses.append(assistant_txt)
self.thoughts.append(assistant_txt)
# Add to dialogue similar to trajectory_runner._add_text
msg = {
"role": "assistant",
"content": add_box_token(assistant_txt)
}
self.prompt_dialogue.append(msg)
self.save_dialogue.append(msg)
self.save_dialogue_full.append(msg)
self._trim()
def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
"""Set first frame in base_messages - similar to trajectory_runner._set_first_frame"""
self.base_messages[1]["content"].append(
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64," + pil_to_base64(obs_img)}
}
)
# Use actual frame path if provided, otherwise use current_screenshot_path or placeholder
if frame_path:
first_frame_path = frame_path
elif self.current_screenshot_path:
first_frame_path = self.current_screenshot_path
else:
first_frame_path = "first_frame.png"
# Store in all_image_paths
self.all_image_paths.append(first_frame_path)
self.base_messages_for_save[1]["content"].append(
{
"type": "image_url",
"image_url": first_frame_path
}
)
self.image_refs.append(
{"source": "base", "msg_idx": 1,
"content_idx": len(self.base_messages[1]["content"]) - 1}
)
def _add_image(self, img_bytes: bytes, frame_path: str = None):
"""Add image to dialogue - similar to trajectory_runner._add_image"""
self.prompt_dialogue.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": {"url": "data:image/png;base64," + pil_to_base64(img_bytes)}
}]
})
# Use actual frame path if provided, otherwise use current_screenshot_path
if frame_path:
image_url = frame_path
elif self.current_screenshot_path:
image_url = self.current_screenshot_path
else:
# Fallback to a placeholder - this should rarely happen in practice
image_url = f"frame_{len(self.save_dialogue)}.png"
# Store in all_image_paths for complete record
self.all_image_paths.append(image_url)
# Add to save_dialogue (trimmed version)
self.save_dialogue.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": image_url
}]
})
# Add to save_dialogue_full (complete version - never trimmed)
self.save_dialogue_full.append({
"role": "user",
"content": [{
"type": "image_url",
"image_url": image_url
}]
})
self.image_refs.append(
{"source": "dialogue", "msg_idx": len(self.prompt_dialogue) - 1,
"content_idx": None}
)
self._trim()
def _trim(self):
"""Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim"""
img_cnt = len(self.image_refs)
txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)
while img_cnt > self.max_images or txt_cnt > self.max_texts:
            # Too many images: drop the earliest one
if img_cnt > self.max_images:
ref = self.image_refs.pop(0)
if ref["source"] == "base":
self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
                else:  # image inside the dialogue
self._remove_dialogue_msg(ref["msg_idx"])
img_cnt -= 1
continue
            # Too many texts: drop the earliest assistant message
if txt_cnt > self.max_texts:
for i, m in enumerate(self.prompt_dialogue):
if m["role"] == "assistant":
self._remove_dialogue_msg(i)
txt_cnt -= 1
break
def _remove_dialogue_msg(self, idx: int):
"""Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg"""
self.prompt_dialogue.pop(idx)
self.save_dialogue.pop(idx)
# Note: save_dialogue_full is never trimmed, so we don't remove from it
        # Update image_refs
        self.image_refs = [
            r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
            else None  # refs to the removed message are dropped
            for r in self.image_refs
        ]
        self.image_refs = [
            (
                {**r, "msg_idx": r["msg_idx"] - 1}
                if r and r["source"] == "dialogue" and r["msg_idx"] > idx  # shift indices after idx down by one
                else r
            )
            for r in self.image_refs
            if r  # filter out the Nones
        ]
def _get_current_image_size(self) -> tuple:
"""Get current image size for coordinate conversion"""
if len(self.observations) > 0:
try:
current_image_bytes = self.observations[-1]["screenshot"]
if isinstance(current_image_bytes, bytes):
current_image = Image.open(BytesIO(current_image_bytes))
return (current_image.height, current_image.width)
except Exception as e:
logger.warning(f"Error getting image size: {e}")
# Fallback to default screen size
return (1080, 1920)
    def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> tuple:
        """Parse the response and convert it to pyautogui actions - similar to trajectory_runner._parse

        Returns (actions, parsed_responses): the generated pyautogui code strings
        together with the structured dicts they were derived from, so the caller
        can inspect action types (e.g. for terminal actions).
        """
        image_height, image_width = image_size
        # Parse the response into structured actions
        parsed_responses = parse_action_to_structure_output(
            prediction,
            factor=self.action_parse_res_factor,
            origin_resized_height=image_height,
            origin_resized_width=image_width,
            model_type=self.model_type,
            max_pixels=self.max_pixels,
            min_pixels=self.min_pixels
        )
        # Convert the parsed responses to pyautogui actions
        actions = []
        for parsed_response in parsed_responses:
            try:
                pyautogui_code = parsing_response_to_pyautogui_code(
                    parsed_response,
                    image_height=image_height,
                    image_width=image_width,
                    input_swap=self.input_swap
                )
                actions.append(pyautogui_code)
            except Exception as e:
                logger.error(f"Error generating pyautogui code: {e}")
                actions.append("FAIL")
        return actions, parsed_responses
    def _check_terminal_actions(self, parsed_responses: List[Dict]) -> str:
        """Check whether any parsed (dict) action is terminal and map it to a control code"""
        for action in parsed_responses:
            if isinstance(action, dict) and "action_type" in action:
                action_type = action["action_type"]
                if action_type == FINISH_WORD:
                    return "DONE"
                elif action_type == WAIT_WORD:
                    return "WAIT"
                elif action_type == ENV_FAIL_WORD:
                    return "FAIL"
                elif action_type == CALL_USER:
                    # Calling the user is treated as a failure in this setting
                    return "FAIL"
        return None
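

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the benchmark harness). The endpoint, API
# key, and model name below are placeholders, not real values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_conf = {
        "dart_api_key": "sk-placeholder",                # hypothetical key
        "dart_base_url": "http://localhost:8000/v1",     # hypothetical endpoint
        "temperature": 0.0,
        "max_images": 5,
    }
    agent = DartAgent(model="dart-7b", runtime_conf=demo_conf)  # hypothetical model name
    agent.set_base_messages("Open the Files application.")
    # A real run would loop, feeding each fresh screenshot back in:
    #     response, actions = agent.predict(instruction, obs)
    # where obs = {"screenshot": <PNG bytes>} comes from the environment.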