feat/dart_gui (#371)

This commit is contained in:
Pengxiang-Li
2025-11-07 21:50:01 +08:00
committed by GitHub
parent 6d43dbc532
commit 00b6468eb7
8 changed files with 2499 additions and 4 deletions

View File

@@ -0,0 +1,161 @@
# Prompt templates and failure heuristics for the DART GUI agent.
#
# The *_PROMPT constants are str.format() templates (placeholders:
# {instruction}, {language}, {action_space}); the *_ACTION_SPACE constants are
# plain text fragments inserted into {action_space}.  NOTE: the template text
# is part of the model contract — do not reword it casually.

# Default computer-use system prompt (no call_user action available).
COMPUTER_USE_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""
# Variant of the prompt above that additionally permits call_user() so the
# model can give up / ask for help.
COMPUTER_USE_PROMPT_WITH_CALL_USER = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
- My computer's password is 'password', feel free to use it when you need sudo rights.
## User Instruction
{instruction}
"""
# UI-TARS action space: finished() takes no content argument.
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
"""
# UI-TARS action space plus call_user().
UITARS_CALL_USR_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
"""
# UI-TARS action space with a finished(content='...') summary argument.
UITARS_NORMAL_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""
# UI-TARS user prompt without a Thought section (Action only).
UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## User Instruction
{instruction}
"""
# UI-TARS user prompt with a Thought section; {action_space} is filled with
# one of the *_ACTION_SPACE constants above.
UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
# Chinese substrings that mark a finished() message as a failure report.
# A "finished" response containing any of these is downgraded to FAIL
# (see parsing_response_to_pyautogui_code in utils.py).
FAILURE_INDICATORS = [
    # Direct inability expressions
    "无法", "不能", "不可以", "做不到", "实现不了", "完成不了","没法",
    # Regret/apology expressions
    "遗憾", "抱歉", "很抱歉", "非常抱歉", "对不起",
    # Not supported/available
    "不直接支持", "不支持", "不提供", "不具备", "没有权限", "权限不足", "不在这里面","不符合",#"不存在",
    # Cannot access/handle
    "无权访问", "访问不了", "处理不了", "操作不了", "执行不了", "没找到", "空空如也",
    # Not possible/feasible
    "不可能", "无法实现", "实现不了", "办不到", "做不了","找不到","存在技术限制","没有找到","没有内置",
    # System limitations
    "超出范围", "不在我的能力范围", "能力有限", "功能限制","没有成功","没成功","硬件的问题",
    # Refusal indicators
    "拒绝", "不允许", "禁止", "不合适", "不恰当",
    # Trying Restart
    "从头开始", "藏在", "浪费时间","一个更合理的思路","正确的方向","没有意义",#, "重新","重启",
]

View File

@@ -0,0 +1,202 @@
import asyncio
from typing import List, Optional, Union, Dict, Any
import json
import os
import hashlib
from pathlib import Path
from omegaconf import DictConfig
from dataclasses import dataclass, asdict
import copy
import logging
import random
from prompts import COMPUTER_USE_PROMPT, COMPUTER_USE_PROMPT_WITH_CALL_USER
from log_config import setup_logging
# Configure the application-wide (unified) logging setup before creating loggers.
setup_logging()
logger = logging.getLogger(__name__)
class TaskLoader:
    """Loads OSWorld task definitions from a JSON index file.

    The index file maps task_type -> [task_id, ...]; each (type, id) pair
    resolves to a per-task JSON config under ``osworld_root``.  When
    ``resume`` is enabled, tasks that already have a ``reward.txt`` recorded
    under ``storage_root`` are filtered out.
    """
    def __init__(self, task_cfg: DictConfig, storage_root):
        # Path to the JSON index file listing task ids per task type.
        self.task_file = Path(task_cfg.task_file)
        #self.task_root = Path(task_cfg.task_root)
        # NOTE(review): task_root is commented out above, but _find_latest_json
        # still reads self.task_root — calling it would raise AttributeError.
        self.osworld_root = Path(task_cfg.osworld_root)
        # SHA-1 of the last loaded index (only used by the legacy _bak path).
        self._latest_sha: Optional[str] = None
        self.storage_root = storage_root
        self.resume = task_cfg.resume
    def poll_for_tasks(self) -> List[Dict]:
        """Reload the task index and return all (remaining) tasks.

        Returns a randomly shuffled list of TaskInfo dicts
        (see TaskInfo.to_dict()).
        """
        self._maybe_refresh_dataset()
        tasks_list = [task.to_dict() for task in self._tasks]
        random.shuffle(tasks_list)
        return tasks_list
    def _maybe_refresh_dataset_bak(self):
        """Legacy refresh: reload the newest *.json under task_root only when
        its SHA-1 changed.

        NOTE(review): appears to be dead code — it depends on the
        commented-out self.task_root via _find_latest_json.
        """
        # check new json
        latest_json = self._find_latest_json()
        if latest_json is None:
            return False # no json file
        sha = self._calc_sha1(latest_json)
        if sha == self._latest_sha:
            return False # no change
        with open(latest_json) as f:
            data = json.load(f)
        # Flatten {task_type: [task_id, ...]} into one record per task.
        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]
        self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
        self._latest_sha = sha
        logger.info(f"当前任务文件: {str(latest_json)}")
        logger.info(f"任务总数: {len(raw_tasks)}")
        return True
    def _maybe_refresh_dataset(self):
        """Reload self._tasks from self.task_file; in resume mode, drop tasks
        that already have a completion record under storage_root."""
        latest_json = self.task_file
        print("Current tasks file: ", str(latest_json))
        with open(latest_json) as f:
            data = json.load(f)
        # Flatten {task_type: [task_id, ...]} into one record per task.
        raw_tasks = [
            {"task_type": task_type, "task_id": task_id}
            for task_type, task_ids in data.items()
            for task_id in task_ids
        ]
        if self.resume:
            # Filter out tasks that are already finished or whose stored
            # task_type does not match the index entry.
            filtered_tasks = []
            storage_root = Path(self.storage_root)
            for raw in raw_tasks:
                task_id = str(raw["task_id"])
                task_type_expected = raw["task_type"]
                # Collect every subdirectory whose name starts with task_id
                # (multiple runs/versions of the same task are allowed).
                candidate_dirs = [
                    d for d in storage_root.iterdir()
                    if d.is_dir() and d.name.startswith(task_id)
                ]
                # Assume unfinished until a completion record is found.
                task_finished = False
                for d in candidate_dirs:
                    cfg_path = d / "task_config.json"
                    if not cfg_path.exists():
                        print("找不到config文件")
                        continue
                    try:
                        with cfg_path.open("r", encoding="utf-8") as cf:
                            cfg = json.load(cf)
                    except Exception:
                        print("配置损坏,忽略此目录")
                        continue
                    # Different task_type => not the same task; skip this dir.
                    if cfg.get("raw", {}).get("task_type") != task_type_expected:
                        continue
                    # Same task_type: a reward.txt marks the task as finished.
                    if (d / "reward.txt").exists():
                        task_finished = True
                        break # completion record found; no need to scan further
                if not task_finished:
                    filtered_tasks.append(raw)
            self._tasks = [build_task(raw, self.osworld_root) for raw in filtered_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}, Remained:{len(filtered_tasks)}")
        else:
            self._tasks = [build_task(raw, self.osworld_root) for raw in raw_tasks]
            print(f"Total number of tasks: {len(raw_tasks)}")
        return True
    def _find_latest_json(self) -> Optional[Path]:
        """Return the most recently modified *.json under task_root, or None.

        NOTE(review): self.task_root is never assigned (commented out in
        __init__), so this raises AttributeError if called.
        """
        files = list(self.task_root.glob("*.json"))
        return max(files, key=lambda p: p.stat().st_mtime) if files else None
    @staticmethod
    def _calc_sha1(fp: Path, chunk_size=2<<20) -> str:
        """Return the SHA-1 hex digest of file *fp*, streamed in 2 MiB chunks."""
        h = hashlib.sha1()
        with fp.open("rb") as f:
            for chunk in iter(lambda: f.read(chunk_size), b""):
                h.update(chunk)
        return h.hexdigest()
@dataclass
class TaskInfo:
    """One runnable task: its seeded dialogue, instruction text, and config."""
    messages: List      # initial system/user messages for the model
    instruction: str    # (possibly augmented) task instruction
    task_config: Dict   # full task config loaded from disk
    def to_dict(self):
        """Serialize this record into a plain dict (recursive, via asdict)."""
        return asdict(self)
def build_task(raw: Dict, osworld_root: Path, use_call_user: bool = False) -> TaskInfo:
    """Load one OSWorld task config and build its initial prompt messages.

    Args:
        raw: {"task_type": ..., "task_id": ...} entry from the task index.
        osworld_root: root directory holding <task_type>/<task_id>.json configs.
        use_call_user: select the prompt variant that permits call_user().

    Returns:
        TaskInfo with the seeded message list, the (possibly augmented)
        instruction, and the full task config (index identifiers attached
        under the "raw" key).
    """
    task_type = raw["task_type"]
    task_id = raw["task_id"]
    # Path-based join (consistent with the rest of the module); f-string also
    # tolerates non-str task ids, unlike `task_id + ".json"`.
    task_path = Path(osworld_root) / task_type / f"{task_id}.json"
    with open(task_path, encoding="utf-8") as f:
        task_data = json.load(f)
    # Keep the index identifiers with the config (used by resume filtering).
    task_data["raw"] = {
        "task_type": task_type,
        "task_id": task_id
    }
    instruction = task_data["instruction"]
    # Append the human-annotated single-action plan as a hint, when present.
    if "human-ground-truth" in task_data and "single-action" in task_data["human-ground-truth"]:
        plan = task_data["human-ground-truth"]["single-action"]
        plan_text = "\n".join(plan)
        instruction = instruction.strip() + "\nHere is an instruction to help you complete the task: \n" + plan_text
    system_prompt = COMPUTER_USE_PROMPT if not use_call_user else COMPUTER_USE_PROMPT_WITH_CALL_USER
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": system_prompt.format(
                        instruction=instruction,
                        language="English"
                    )}
            ]
        }
    ]
    return TaskInfo(
        messages = messages,
        instruction = instruction,
        task_config = task_data
    )

511
mm_agents/dart_gui/utils.py Normal file
View File

@@ -0,0 +1,511 @@
import ast
import base64
import logging
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import numpy as np
import openai
from openai import OpenAI
from PIL import Image
from requests.exceptions import SSLError
from mm_agents.dart_gui.prompts import FAILURE_INDICATORS
# Module logger (configured by the application's logging setup).
logger = logging.getLogger(__name__)
# Sentinel action words (imported and used by the dart_gui agent).
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"
# Image-resize constraints: dimensions are snapped to multiples of
# IMAGE_FACTOR and the total pixel count is kept in [MIN_PIXELS, MAX_PIXELS].
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
# Observation settings that carry only text (no screenshot).
pure_text_settings = ["a11y_tree"]
# Accessibility-tree XML namespaces.
# NOTE(review): attributes_ns_ubuntu points at the "windows" namespace URL —
# confirm this is intentional.
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py
# Parse a single model-emitted action string into its name and kwargs.
def parse_action(action_str):
    """Parse one action call string (e.g. "click(start_box='(1,2)')") into
    {'function': name, 'args': {kwarg: value}}.

    Only keyword arguments with constant values are captured; positional
    arguments are ignored.  Returns None (and logs the error) when the string
    is not a single well-formed call expression.
    """
    try:
        # Parse the string as a single Python expression.
        node = ast.parse(action_str, mode='eval')
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")
        call = node.body
        # The expression must be a function call.
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")
        # Extract the function name (plain name or attribute access).
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None
        # Collect keyword arguments; only constant values are supported.
        kwargs = {}
        for kw in call.keywords:
            key = kw.arg
            if isinstance(kw.value, ast.Constant):
                value = kw.value.value
            elif hasattr(ast, "Str") and isinstance(kw.value, ast.Str):
                # Compatibility with pre-3.8 ASTs; guarded with hasattr because
                # ast.Str is deprecated and slated for removal.
                value = kw.value.s
            else:
                value = None
            kwargs[key] = value
        return {
            'function': func_name,
            'args': kwargs
        }
    except Exception as e:
        logger.error(f"Failed to parse action '{action_str}': {e}")
        return None
def escape_single_quotes(text):
    """Backslash-escape every single quote in *text* not already escaped."""
    # Negative lookbehind: only match quotes with no preceding backslash.
    return re.sub(r"(?<!\\)'", r"\\'", text)
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' closest to 'number'."""
    return factor * round(number / factor)
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is >= 'number'."""
    return factor * math.ceil(number / factor)
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is <= 'number'."""
    return factor * math.floor(number / factor)
def linear_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Scale (height, width) so the pixel count lies within [min_pixels, max_pixels].

    The scale factor is the square root of the pixel ratio, which preserves
    the aspect ratio so relative coordinates remain valid without conversion.
    NOTE: `factor` is accepted for signature compatibility but is not used by
    this resize strategy.
    """
    if width * height > max_pixels:
        # Shrink: sqrt of the pixel ratio keeps the aspect ratio intact.
        shrink = math.sqrt(max_pixels / (width * height))
        width = int(width * shrink)
        height = int(height * shrink)
    if width * height < min_pixels:
        # Grow: same square-root scheme, rounding each side up.
        grow = math.sqrt(min_pixels / (width * height))
        width = math.ceil(width * grow)
        height = math.ceil(height * grow)
    return height, width
def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """Rescale (height, width) such that:

    1. both dimensions are divisible by 'factor',
    2. the total pixel count lies within ['min_pixels', 'max_pixels'],
    3. the aspect ratio is preserved as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    # Snap each side to the nearest multiple of `factor`, never below one factor.
    out_h = max(factor, round_by_factor(height, factor))
    out_w = max(factor, round_by_factor(width, factor))
    if out_h * out_w > max_pixels:
        # Over budget: scale down, rounding each side down to a factor multiple.
        scale = math.sqrt((height * width) / max_pixels)
        out_h = floor_by_factor(height / scale, factor)
        out_w = floor_by_factor(width / scale, factor)
    elif out_h * out_w < min_pixels:
        # Under the floor: scale up, rounding each side up to a factor multiple.
        scale = math.sqrt(min_pixels / (height * width))
        out_h = ceil_by_factor(height * scale, factor)
        out_w = ceil_by_factor(width * scale, factor)
    return out_h, out_w
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
    """Parse a raw model response into a list of structured action dicts.

    Splits the response into an optional Thought/Reflection preamble and one
    or more Action strings, normalizes single-quoting inside type()/finished()
    content, and converts box coordinates to relative [x1, y1, x2, y2] floats.

    Args:
        text: full model output ("Thought: ... Action: ...").
        factor: coordinate divisor for model types other than "qwen25vl".
        origin_resized_height / origin_resized_width: screenshot size as fed
            to the model.
        model_type: "qwen25vl" emits absolute pixel coordinates (rescaled via
            smart_resize); other model types emit coordinates scaled by factor.

    Returns:
        List of dicts: {"reflection", "thought", "action_type",
        "action_inputs", "text"}; unparseable actions are logged and skipped.
    """
    text = text.strip()
    if model_type == "qwen25vl":
        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
    # Pick the regex matching the preamble style of this response.
    if text.startswith("Thought:"):
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    elif text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Reflection: "
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Action_Summary: "
    else:
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            # Reflection style: group 1 is the reflection, group 2 the summary.
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text
    action_str = text.split("Action:")[-1]
    # Multiple actions are separated by blank lines.
    tmp_all_action = action_str.split("\n\n")
    all_action = []
    for action_str in tmp_all_action:
        if "type(content" in action_str:
            # Pull out the content argument, then re-escape unescaped quotes
            # so the string stays parseable as a Python literal.
            def escape_quotes(match):
                content = match.group(1)  # the raw content value
                return content
            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
            content = re.sub(pattern, escape_quotes, action_str)
            action_str = escape_single_quotes(content)
            action_str = "type(content='" + action_str + "')"
        if "finished(content" in action_str:
            # Same single-quote normalization for finished(content='...').
            def escape_quotes(match):
                content = match.group(1)  # the raw content value
                return content
            pattern = r"finished\(content='(.*?)'\)"  # matches finished(content='...')
            content = re.sub(pattern, escape_quotes, action_str)
            action_str = escape_single_quotes(content)
            action_str = "finished(content='" + action_str + "')"
        all_action.append(action_str)
    parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance == None:
            logger.error(f"Action can't parse: {raw_str}")
            # raise ValueError(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]
        action_inputs = {}
        for param_name, param in params.items():
            if param == "": continue
            param = param.lstrip()  # strip leading whitespace
            action_inputs[param_name.strip()] = param
            if "start_box" in param_name or "end_box" in param_name:
                # Convert "(x1,y1[,x2,y2])" into relative [x1, y1, x2, y2].
                ori_box = param
                # Remove parentheses and split the string by commas
                numbers = ori_box.replace("(", "").replace(")", "").split(",")
                # Qwen2.5-VL outputs absolute coordinates; qwen2vl-style
                # models output coordinates scaled by `factor`.
                if model_type == "qwen25vl":
                    float_numbers = []
                    for num_idx, num in enumerate(numbers):
                        num = float(num)
                        # Even positions (1-based) are y values, odd are x.
                        if (num_idx + 1) % 2 == 0:
                            float_numbers.append(float(num/smart_resize_height))
                        else:
                            float_numbers.append(float(num/smart_resize_width))
                else:
                    float_numbers = [float(num) / factor for num in numbers]
                if len(float_numbers) == 2:
                    # A bare point becomes a degenerate box.
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)
        actions.append(
            {
                "reflection": reflection,
                "thought": thought,
                "action_type": action_type,
                "action_inputs": action_inputs,
                "text": text
            })
    return actions
def parsing_response_to_pyautogui_code(responses, image_height: int, image_width:int, input_swap:bool=True) -> str:
    '''
    Translate parsed model actions into a pyautogui code string for OSWorld.

    Args:
        responses: one parsed-action dict (or a list of them), shaped like:
            {
                "action_type": "hotkey",
                "action_inputs": {
                    "hotkey": "v ctrl",
                    "start_box": None,
                    "end_box": None
                }
            }
        image_height / image_width: screenshot size used to de-normalize the
            relative box coordinates.
        input_swap: when True, `type` actions paste via the clipboard
            (pyperclip) instead of simulating keystrokes.

    Returns:
        A pyautogui code string, or one of the sentinel strings
        "DONE" / "FAIL" / "WAIT".
    '''
    pyautogui_code = "import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        observation = response.get("observation", "")
        thought = response.get("thought", "")
        if response_id == 0:
            # Record the model's reasoning as a leading comment block.
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            # Small pause between consecutive actions.
            pyautogui_code += "\ntime.sleep(1)\n"
        action_dict = response
        response_text = action_dict.get("text", "")
        action_type = action_dict.get("action_type")
        action_inputs = action_dict.get("action_inputs", {})
        if action_type == "hotkey":
            # Parsing hotkey action
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")
            # Normalize arrow-key aliases to pyautogui key names.
            if hotkey == "arrowleft":
                hotkey = "left"
            elif hotkey == "arrowright":
                hotkey = "right"
            elif hotkey == "arrowup":
                hotkey = "up"
            elif hotkey == "arrowdown":
                hotkey = "down"
            if hotkey:
                keys = hotkey.split()  # keys are space-separated
                convert_keys = []
                for key in keys:
                    if key == "space":
                        key = ' '
                    convert_keys.append(key)
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
        elif action_type == "press":
            # Parsing press action
            if "key" in action_inputs:
                key_to_press = action_inputs.get("key", "")
            else:
                key_to_press = action_inputs.get("press", "")
            # Bug fix: the alias normalization below was previously applied to
            # the stale `hotkey` variable from the hotkey branch — a NameError
            # when no hotkey action preceded, and a no-op otherwise.
            if key_to_press == "arrowleft":
                key_to_press = "left"
            elif key_to_press == "arrowright":
                key_to_press = "right"
            elif key_to_press == "arrowup":
                key_to_press = "up"
            elif key_to_press == "arrowdown":
                key_to_press = "down"
            elif key_to_press == "space":
                key_to_press = " "
            if key_to_press:
                # Simulate pressing a single key
                pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
        elif action_type == "keyup":
            key_to_up = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
        elif action_type == "keydown":
            key_to_down = action_inputs.get("key", "")
            pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
        elif action_type == "type":
            # Typing action; optionally via clipboard paste (input_swap).
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            stripped_content = content
            if content.endswith("\n") or content.endswith("\\n"):
                stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
            if content:
                if input_swap:
                    pyautogui_code += "\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
                    pyautogui_code += "\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
                    pyautogui_code += "\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += "\npyautogui.press('enter')"
        elif action_type in ["drag", "select"]:
            # Drag/select between the centers of start_box and end_box.
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                # NOTE: eval() on upstream-parsed coordinate strings — keep the
                # parser the only producer of these values.
                x1, y1, x2, y2 = eval(start_box)  # box is [x1, y1, x2, y2]
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = eval(end_box)  # box is [x1, y1, x2, y2]
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )
        elif action_type == "scroll":
            # Scroll at the box center when given, otherwise at the cursor.
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = eval(start_box)  # box is [x1, y1, x2, y2]
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
            else:
                x = None
                y = None
            direction = action_inputs.get("direction", "")
            if x is None:
                if "up" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += "\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"
        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            # Mouse actions at the center of start_box.
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = eval(start_box)
                if start_box is None:
                    # Bug fix: previously fell through and crashed on
                    # len(None); now log and skip emitting this action.
                    logger.warning(f"[Warning] start_box is None and wired condition:\n{action_inputs}")
                else:
                    if len(start_box) == 4:
                        x1, y1, x2, y2 = start_box  # box is [x1, y1, x2, y2]
                    elif len(start_box) == 2:
                        x1, y1 = start_box
                        x2 = x1
                        y2 = y1
                    x = round(float((x1 + x2) / 2) * image_width, 3)
                    y = round(float((y1 + y2) / 2) * image_height, 3)
                    if action_type == "left_single" or action_type == "click":
                        pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                    elif action_type == "left_double":
                        pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                    elif action_type == "right_single":
                        pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                    elif action_type == "hover":
                        pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
        elif action_type in ["finished"]:
            pyautogui_code = "DONE"
            print(f"FINISHED:response_text: {response_text}")
            print(f"FINISHED:response: {str(response)}")
            # A "finished" message that admits failure is downgraded to FAIL.
            for failure_indicator in FAILURE_INDICATORS:
                if failure_indicator in response_text:
                    pyautogui_code = "FAIL"
                    break
        elif action_type in ["wait"]:
            pyautogui_code = "WAIT"
        elif action_type in ["call_user"]:
            pyautogui_code = "FAIL"
        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"
    return pyautogui_code
def add_box_token(input_string):
    """Wrap bare coordinates in model box tokens.

    Rewrites every start_box='(x,y)' / end_box='(x,y)' occurrence in the
    Action section(s) of *input_string* as
    start_box='<|box_start|>(x,y)<|box_end|>'.  Strings without an Action /
    start_box section pass through unchanged.  Returns the result as a
    single-element OpenAI text-content list.
    """
    if "Action: " in input_string and "start_box=" in input_string:
        # Everything up to (and including) the first "Action: " is a prefix.
        suffix = input_string.split("Action: ")[0] + "Action: "
        actions = input_string.split("Action: ")[1:]
        # Bug fix: the previous str.replace()-based rewrite silently skipped
        # coordinates written with a space after the comma (e.g. "(12, 34)")
        # even though the regex matched them.  re.sub handles both forms and
        # normalizes them to "(x,y)".
        box_pattern = re.compile(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'")
        processed_actions = [
            box_pattern.sub(
                lambda m: f"{m.group(1)}='<|box_start|>({m.group(2)},{m.group(3)})<|box_end|>'",
                action.strip(),
            )
            for action in actions
        ]
        # Reconstruct the final string.
        final_string = suffix + "\n\n".join(processed_actions)
    else:
        final_string = input_string
    return [{"type": "text", "text": final_string}]
def pil_to_base64(image):
    """Encode a PIL image (or raw bytes) as a base64 string."""
    if isinstance(image, bytes):
        # Raw bytes: encode directly.
        return base64.b64encode(image).decode("utf-8")
    # PIL image: serialize to an in-memory PNG first.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

686
mm_agents/dart_gui_agent.py Normal file
View File

@@ -0,0 +1,686 @@
"""
Dart Agent - Custom agent for GUI automation using Dart models
Based on UITARSAgent structure but using Dart-specific utilities and prompts
"""
import ast
import base64
import logging
import math
import os
import re
import time
from io import BytesIO
from typing import Dict, List, Any
from PIL import Image
from openai import OpenAI
import backoff
import openai
import requests
from requests.exceptions import SSLError
from google.api_core.exceptions import (
BadRequest,
InternalServerError,
InvalidArgument,
ResourceExhausted,
)
# Import Dart-specific utilities and prompts
from mm_agents.dart_gui.utils import (
pil_to_base64,
parse_action_to_structure_output,
parsing_response_to_pyautogui_code,
parse_action,
escape_single_quotes,
round_by_factor,
ceil_by_factor,
floor_by_factor,
linear_resize,
smart_resize,
add_box_token,
IMAGE_FACTOR,
MIN_PIXELS,
MAX_PIXELS,
MAX_RATIO,
FINISH_WORD,
WAIT_WORD,
ENV_FAIL_WORD,
CALL_USER
)
from mm_agents.dart_gui.prompts import (
COMPUTER_USE_PROMPT,
COMPUTER_USE_PROMPT_WITH_CALL_USER,
UITARS_ACTION_SPACE,
UITARS_CALL_USR_ACTION_SPACE,
UITARS_USR_PROMPT_THOUGHT,
UITARS_USR_PROMPT_NOTHOUGHT
)
logger = logging.getLogger("desktopenv.agent")
class DartAgent:
    def __init__(
        self,
        model: str,
        runtime_conf: Dict,
        platform="ubuntu",
        max_tokens=1000,
        top_p=0.9,
        top_k=1.0,
        temperature=0.0,
        action_space="pyautogui",
        observation_type="screenshot",
        max_trajectory_length=50,
        model_type="qwen25vl",
        **kwargs
    ):
        """Configure sampling parameters, the model endpoint, and prompts.

        Args:
            model: model name sent to the chat/generate endpoint.
            runtime_conf: dict overriding sampling/prompt/endpoint settings
                (keys like max_tokens, infer_mode, dart_base_url, ...).
            platform: target OS name.
            max_tokens/top_p/top_k/temperature: sampling defaults; each can be
                overridden via runtime_conf.
            action_space: action encoding (default "pyautogui").
            observation_type: observation encoding (default "screenshot").
            max_trajectory_length: cap on the trajectory length.
            model_type: coordinate convention of the VLM (e.g. "qwen25vl").
        """
        self.model = model
        self.platform = platform
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.model_type = model_type
        self.runtime_conf = runtime_conf
        # Extract runtime configuration parameters
        self.max_tokens = self.runtime_conf.get("max_tokens", max_tokens)
        self.top_p = self.runtime_conf.get("top_p", top_p)
        self.top_k = self.runtime_conf.get("top_k", top_k)
        self.temperature = self.runtime_conf.get("temperature", temperature)
        self.infer_mode = self.runtime_conf.get("infer_mode", "dart_mode")
        self.prompt_style = self.runtime_conf.get("prompt_style", "dart_style")
        self.input_swap = self.runtime_conf.get("input_swap", False)
        self.language = self.runtime_conf.get("language", "English")
        self.max_pixels = self.runtime_conf.get("max_pixels", MAX_PIXELS)
        self.min_pixels = self.runtime_conf.get("min_pixels", MIN_PIXELS)
        self.history_n = self.runtime_conf.get("history_n", 5)
        # Dart specific configurations
        self.max_images = self.runtime_conf.get("max_images", 5)
        self.max_texts = self.runtime_conf.get("max_texts", 35)
        # Initialize OpenAI client - use Dart API if provided
        dart_api_key = self.runtime_conf.get("dart_api_key", "")
        dart_base_url = self.runtime_conf.get("dart_base_url", "")
        if dart_base_url:
            # Check for a direct generation endpoint (URL contains /generate).
            if '/generate' in dart_base_url:
                # Use the provided URL as-is; do not append /v1.
                logger.info(f"使用直接生成端点: {dart_base_url}")
                self.dart_direct_url = dart_base_url
                self.vlm = None # the OpenAI client is not used in this mode
            else:
                # Conventional OpenAI-compatible endpoint; ensure it ends in /v1.
                if not dart_base_url.endswith('/v1'):
                    dart_base_url = dart_base_url.rstrip('/') + '/v1'
                self.vlm = OpenAI(
                    base_url=dart_base_url,
                    api_key=dart_api_key,
                )
                self.dart_direct_url = None
        else:
            # Fallback to environment variables
            base_url = os.environ.get('DART_API_URL', os.environ.get('DOUBAO_API_URL'))
            if base_url:
                if '/generate' in base_url:
                    # Direct generation endpoint.
                    self.dart_direct_url = base_url
                    self.vlm = None
                else:
                    if not base_url.endswith('/v1'):
                        base_url = base_url.rstrip('/') + '/v1'
                    self.vlm = OpenAI(
                        base_url=base_url,
                        api_key=os.environ.get('DART_API_KEY', os.environ.get('DOUBAO_API_KEY')),
                    )
                    self.dart_direct_url = None
            else:
                # No endpoint configured at all.
                self.vlm = None
                self.dart_direct_url = None
        # Initialize trajectory storage - similar to trajectory_runner.py
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []
        # Message handling similar to trajectory_runner.py
        self.base_messages = [] # for model client (with base64 images)
        self.base_messages_for_save = [] # for storage (with file paths)
        self.prompt_dialogue = [] # for model client
        self.save_dialogue = [] # for storage
        self.save_dialogue_full = [] # for full storage (keeps every image path)
        self.image_refs = [] # record image position
        # All image paths storage - to keep track of all images even when trimmed
        self.all_image_paths = []
        # Current screenshot file path for proper saving
        self.current_screenshot_path = None
        # Configure prompt and action space based on mode
        if self.infer_mode == "dart_mode":
            self.prompt_action_space = UITARS_ACTION_SPACE
            self.prompt_template = COMPUTER_USE_PROMPT
        else:
            # For qwen2vl_user mode
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
            if self.prompt_style == "qwen2vl_user":
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT
            elif self.prompt_style == "qwen2vl_no_thought":
                self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT
            else:
                self.prompt_template = UITARS_USR_PROMPT_THOUGHT
        # Divisor used when converting model coordinates to relative values.
        self.action_parse_res_factor = 1000
        logger.info(f"Initialized DartAgent with model: {self.model}, mode: {self.infer_mode}")
def reset(self, runtime_logger=None):
    """Clear all per-trajectory state so the agent can start a fresh task.

    ``runtime_logger`` is accepted for interface compatibility but unused.
    """
    # Every list-valued piece of state is wiped the same way; keep the
    # attribute names in one place so reset stays in sync with __init__.
    for attr in (
        "thoughts", "actions", "observations",
        "history_images", "history_responses",
        "base_messages", "base_messages_for_save",
        "prompt_dialogue", "save_dialogue", "save_dialogue_full",
        "image_refs", "all_image_paths",
    ):
        setattr(self, attr, [])
    self.current_screenshot_path = None
    logger.info("DartAgent reset")
def set_base_messages(self, instruction: str):
    """Seed the conversation: system prompt plus the formatted task prompt.

    Also keeps an independent deep copy in ``base_messages_for_save`` so
    the on-disk version can diverge (file paths instead of base64).
    """
    from copy import deepcopy

    # NOTE(review): this always formats COMPUTER_USE_PROMPT rather than
    # self.prompt_template — confirm that is intended for non-dart modes.
    prompt_text = COMPUTER_USE_PROMPT.format(
        instruction=instruction,
        language=self.language,
    )
    self.base_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt_text}],
        },
    ]
    # Deep copy so later image appends to one list don't leak into the other.
    self.base_messages_for_save = deepcopy(self.base_messages)
def set_current_screenshot_path(self, screenshot_path: str):
    """Remember where the latest screenshot lives on disk (used when saving)."""
    self.current_screenshot_path = screenshot_path
def predict(
    self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
) -> tuple:
    """
    Predict the next action(s) based on the current observation.
    Returns: (response_text, actions_list)

    First call bakes the instruction into the base messages and attaches
    the screenshot as the first frame; later calls append the screenshot
    to the dialogue as the result of the previous action.
    ``last_action_after_obs`` is accepted for interface compatibility but
    is not used here.
    """
    # Initialize base messages if not set
    if not self.base_messages:
        self.set_base_messages(instruction)
    # Store current observation
    self._add_observation(obs)
    # For first step, set the first frame
    if len(self.observations) == 1:
        self._set_first_frame(obs["screenshot"], self.current_screenshot_path)
    else:
        # For subsequent steps, add the new image to dialogue
        # This represents the result of the previous action
        self._add_image(obs["screenshot"], self.current_screenshot_path)
    # Build prompt messages (base_messages + prompt_dialogue)
    messages = self._build_messages()
    # Call model to get response
    prediction = self._call_model(messages)
    if prediction is None:
        # Client gave up after its retries; terminate the trajectory.
        return "client error", ["DONE"]
    # Store response and parse actions
    self._add_text(prediction)
    # Parse response to actions
    try:
        image_size = self._get_current_image_size()
        actions = self._parse_and_convert_actions(prediction, image_size)
        # Check for terminal actions
        # NOTE(review): _parse_and_convert_actions returns pyautogui code
        # strings, but _check_terminal_actions only recognizes dict entries
        # with an "action_type" key — confirm terminal actions
        # (DONE/WAIT/FAIL) are actually detected on this path.
        terminal_action = self._check_terminal_actions(actions)
        if terminal_action:
            self.actions.append(actions)
            return prediction, [terminal_action]
    except Exception as e:
        logger.error(f"Parsing action error: {prediction}, error: {e}")
        return f"Parsing action error: {prediction}, error: {e}", ["DONE"]
    self.actions.append(actions)
    # Check max steps
    if len(self.history_responses) >= self.max_trajectory_length:
        # Step budget exhausted: override with a forced FAIL action.
        actions = ["FAIL"]
    return prediction, actions
@backoff.on_exception(
    backoff.constant,
    (
        # General exceptions
        SSLError,
        # OpenAI exceptions
        openai.RateLimitError,
        openai.BadRequestError,
        openai.InternalServerError,
        # Google exceptions
        InvalidArgument,
        ResourceExhausted,
        InternalServerError,
        BadRequest,
    ),
    interval=30,
    max_tries=10,
)
def predict_with_backoff(self, instruction: str, obs: Dict, last_action_after_obs: Dict = None):
    """Predict with backoff for rate limiting and temporary errors.

    Wraps :meth:`predict` with a constant 30-second backoff, retrying up
    to 10 times on SSL errors, OpenAI rate-limit/server errors, and
    Google API errors.

    NOTE(review): the tuple lists both ``openai.InternalServerError`` and
    a bare ``InternalServerError`` — confirm the bare name resolves to the
    Google exception and is not shadowed by another import.
    """
    return self.predict(instruction, obs, last_action_after_obs)
def get_trajectory(self) -> List[Dict]:
    """Pair each observation with its thought and action for saving.

    Missing thoughts default to ``""`` and missing actions to ``[]`` so
    the result always has one entry per observation.
    """
    steps = []
    for idx, observation in enumerate(self.observations):
        thought = self.thoughts[idx] if idx < len(self.thoughts) else ""
        action = self.actions[idx] if idx < len(self.actions) else []
        steps.append({
            "observation": observation,
            "thought": thought,
            "action": action,
        })
    return steps
def get_full_messages(self) -> List[Dict]:
    """Return the complete, untrimmed conversation for saving.

    Concatenates the saved base messages (system prompt + initial user
    turn) with the full dialogue copy that is never trimmed.
    """
    return [*self.base_messages_for_save, *self.save_dialogue_full]
def get_all_image_paths(self) -> List[str]:
    """Return a defensive copy of every image path used in the conversation."""
    return list(self.all_image_paths)
# ========== Private Methods ==========
def _validate_trajectory(self):
"""Validate trajectory consistency"""
assert len(self.observations) == len(self.actions) and len(self.actions) == len(
self.thoughts
), "The number of observations and actions should be the same."
def _add_observation(self, obs: Dict):
"""Process observation and add to history"""
# Store observation
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
base64_image = obs["screenshot"]
try:
# Handle accessibility tree if needed
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree" and "accessibility_tree" in obs:
# For now, we'll skip accessibility tree processing in Dart mode
linearized_accessibility_tree = None
except:
linearized_accessibility_tree = None
if self.observation_type == "screenshot_a11y_tree":
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": linearized_accessibility_tree,
})
else:
self.observations.append({
"screenshot": base64_image,
"accessibility_tree": None
})
else:
raise ValueError("Invalid observation_type type: " + self.observation_type)
def _build_messages(self) -> List[Dict]:
"""Build messages for model API call - similar to trajectory_runner._build_messages"""
return self.base_messages + self.prompt_dialogue
def _call_model(self, messages: List[Dict]) -> str:
    """Query the model, retrying up to three times on any client error.

    Routes to the direct /generate endpoint when one is configured,
    otherwise uses the OpenAI-compatible client. Returns the response
    text, or ``None`` after all retries fail.
    """
    for _attempt in range(3):
        try:
            if getattr(self, 'dart_direct_url', None):
                # A raw generation endpoint was configured at init time.
                prediction = self._call_direct_generate_endpoint(messages)
            else:
                # Standard OpenAI-compatible chat completion.
                completion = self.vlm.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    frequency_penalty=1,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
                prediction = completion.choices[0].message.content
            logger.info(f"Model response: {prediction}")
            return prediction
        except Exception as e:
            logger.error(f"Error when fetching response from client: {e}")
    logger.error("Reach max retry times to fetch response from client")
    return None
def _call_direct_generate_endpoint(self, messages: List[Dict]) -> str:
    """POST the chat payload directly to a raw generation endpoint.

    Retries the HTTP request up to three times (60s timeout each), then
    parses the JSON body, accepting several known response schemas.

    Fixes vs. the original: ``import time`` is hoisted out of the retry
    loop, and ``raise e`` is replaced with a bare ``raise`` so the
    original traceback is preserved.

    Raises:
        Exception: the last request error if every retry fails, or any
            response-parsing error (re-raised after logging).
    """
    import time

    try:
        # Build the request payload (OpenAI-compatible field names).
        payload = {
            "messages": messages,
            "model": self.model,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "frequency_penalty": 1
        }
        # Attach the API key to the headers.
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.runtime_conf.get('dart_api_key', '')}"
        }
        # Retry loop: up to 3 attempts, 60-second timeout per request.
        max_retries = 3
        response = None
        for attempt in range(max_retries):
            try:
                logger.info(f"尝试第 {attempt + 1} 次请求...")
                response = requests.post(
                    self.dart_direct_url,
                    json=payload,
                    headers=headers,
                    timeout=60
                )
                response.raise_for_status()
                break  # success: leave the retry loop
            except Exception as e:
                logger.warning(f"第 {attempt + 1} 次请求失败: {e}")
                if attempt == max_retries - 1:
                    # Final attempt failed: surface the original error.
                    logger.error(f"所有 {max_retries} 次重试都失败了")
                    raise
                logger.info(f"等待后重试...")
                time.sleep(2)  # brief pause before the next attempt
        # Parse the response, probing the known schema variants in order.
        result = response.json()
        if 'choices' in result and len(result['choices']) > 0:
            # OpenAI-compatible format
            return result['choices'][0]['message']['content']
        elif 'response' in result:
            # Plain "response" field
            return result['response']
        elif 'text' in result:
            # Plain "text" field
            return result['text']
        elif 'content' in result:
            # Plain "content" field
            return result['content']
        else:
            # Unknown schema: log it and fall back to the stringified body.
            logger.warning(f"未知的响应格式: {result}")
            return str(result)
    except Exception as e:
        logger.error(f"直接端点调用失败: {e}")
        raise
def _add_text(self, assistant_txt: str):
    """Record an assistant response in every history and dialogue structure.

    The raw text goes into the response/thought histories; the dialogue
    copies get a shared assistant message with box tokens re-added. Ends
    by trimming the dialogue to the configured image/text limits.
    """
    self.history_responses.append(assistant_txt)
    self.thoughts.append(assistant_txt)
    message = {
        "role": "assistant",
        "content": add_box_token(assistant_txt)
    }
    # The same message object is shared across all three dialogue lists,
    # matching how the model-facing and saved copies stay in sync.
    for dialogue in (self.prompt_dialogue, self.save_dialogue, self.save_dialogue_full):
        dialogue.append(message)
    self._trim()
def _set_first_frame(self, obs_img: bytes, frame_path: str = None):
    """Attach the first screenshot to the initial user message.

    The model-facing copy gets the base64 data URL; the save copy gets a
    file path (explicit ``frame_path``, then the tracked current path,
    then a placeholder). Registers the image in ``image_refs`` so
    trimming can locate it later.
    """
    encoded = pil_to_base64(obs_img)
    self.base_messages[1]["content"].append(
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64," + encoded}
        }
    )
    # Path preference: explicit argument > tracked path > placeholder.
    first_frame_path = frame_path or self.current_screenshot_path or "first_frame.png"
    # Keep the complete record of every image path used.
    self.all_image_paths.append(first_frame_path)
    self.base_messages_for_save[1]["content"].append(
        {
            "type": "image_url",
            "image_url": first_frame_path
        }
    )
    self.image_refs.append(
        {
            "source": "base",
            "msg_idx": 1,
            "content_idx": len(self.base_messages[1]["content"]) - 1,
        }
    )
def _add_image(self, img_bytes: bytes, frame_path: str = None):
    """Append a screenshot as a new user turn in all dialogue copies.

    The model-facing dialogue gets the base64 data URL; both save
    dialogues get a file path. Registers the turn in ``image_refs`` and
    trims the dialogue afterwards.
    """
    data_url = "data:image/png;base64," + pil_to_base64(img_bytes)
    self.prompt_dialogue.append({
        "role": "user",
        "content": [{
            "type": "image_url",
            "image_url": {"url": data_url}
        }]
    })
    # Path preference: explicit argument > tracked path > synthetic name.
    # The synthetic name uses save_dialogue's length BEFORE this append.
    image_url = (
        frame_path
        or self.current_screenshot_path
        or f"frame_{len(self.save_dialogue)}.png"
    )
    # Keep the complete record of every image path used.
    self.all_image_paths.append(image_url)

    def _saved_turn():
        # Fresh dict per list so trimming one copy can't affect the other.
        return {
            "role": "user",
            "content": [{
                "type": "image_url",
                "image_url": image_url
            }]
        }

    self.save_dialogue.append(_saved_turn())          # trimmed version
    self.save_dialogue_full.append(_saved_turn())     # never trimmed
    self.image_refs.append({
        "source": "dialogue",
        "msg_idx": len(self.prompt_dialogue) - 1,
        "content_idx": None,
    })
    self._trim()
def _trim(self):
"""Ensure image num ≤ max_images and assistant text num ≤ max_texts - similar to trajectory_runner._trim"""
img_cnt = len(self.image_refs)
txt_cnt = sum(m["role"] == "assistant" for m in self.prompt_dialogue)
while img_cnt > self.max_images or txt_cnt > self.max_texts:
# 图片超限:最早一张
if img_cnt > self.max_images:
ref = self.image_refs.pop(0)
if ref["source"] == "base":
self.base_messages[ref["msg_idx"]]["content"].pop(ref["content_idx"])
else: # dialogue 图
self._remove_dialogue_msg(ref["msg_idx"])
img_cnt -= 1
continue
# 文本超限:最早 assistant 文本
if txt_cnt > self.max_texts:
for i, m in enumerate(self.prompt_dialogue):
if m["role"] == "assistant":
self._remove_dialogue_msg(i)
txt_cnt -= 1
break
def _remove_dialogue_msg(self, idx: int):
"""Remove dialogue message and update refs - similar to trajectory_runner._remove_dialogue_msg"""
self.prompt_dialogue.pop(idx)
self.save_dialogue.pop(idx)
# Note: save_dialogue_full is never trimmed, so we don't remove from it
# 更新 image_refs
self.image_refs = [
r if not (r["source"] == "dialogue" and r["msg_idx"] == idx)
else None # 同一条被删掉的图引用直接丢弃
for r in self.image_refs
]
self.image_refs = [
(
{**r, "msg_idx": r["msg_idx"] - 1}
if r and r["source"] == "dialogue" and r["msg_idx"] > idx # idx后的图片索引均-1
else r
)
for r in self.image_refs
if r # 剔除 None
]
def _get_current_image_size(self) -> tuple:
"""Get current image size for coordinate conversion"""
if len(self.observations) > 0:
try:
current_image_bytes = self.observations[-1]["screenshot"]
if isinstance(current_image_bytes, bytes):
current_image = Image.open(BytesIO(current_image_bytes))
return (current_image.height, current_image.width)
except Exception as e:
logger.warning(f"Error getting image size: {e}")
# Fallback to default screen size
return (1080, 1920)
def _parse_and_convert_actions(self, prediction: str, image_size: tuple) -> List[str]:
    """Parse a model response into pyautogui code snippets.

    ``image_size`` is (height, width) of the current screenshot. Each
    parsed action is converted to executable pyautogui code; conversion
    failures are logged and replaced by the sentinel "FAIL".
    """
    height, width = image_size
    # First pass: raw text -> structured action dicts.
    structured = parse_action_to_structure_output(
        prediction,
        factor=self.action_parse_res_factor,
        origin_resized_height=height,
        origin_resized_width=width,
        model_type=self.model_type,
        max_pixels=self.max_pixels,
        min_pixels=self.min_pixels
    )
    # Second pass: structured actions -> pyautogui code strings.
    converted = []
    for item in structured:
        try:
            code = parsing_response_to_pyautogui_code(
                item,
                image_height=height,
                image_width=width,
                input_swap=self.input_swap
            )
        except Exception as e:
            logger.error(f"Error generating pyautogui code: {e}")
            code = "FAIL"
        converted.append(code)
    return converted
def _check_terminal_actions(self, actions: List[str]) -> str:
"""Check if any action is terminal and return appropriate code"""
for action in actions:
if isinstance(action, dict) and "action_type" in action:
action_type = action["action_type"]
if action_type == FINISH_WORD:
return "DONE"
elif action_type == WAIT_WORD:
return "WAIT"
elif action_type == ENV_FAIL_WORD:
return "FAIL"
elif action_type == CALL_USER:
return "FAIL"
return None