import ast
import base64
import logging
import math
import re
import xml.etree.ElementTree as ET
from io import BytesIO
from typing import Dict, List
import os
import backoff
import numpy as np
from PIL import Image
from requests.exceptions import SSLError
import openai
from openai import OpenAI
from google.api_core.exceptions import (
    BadRequest,
    InternalServerError,
    InvalidArgument,
    ResourceExhausted,
)

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import (
    filter_nodes,
)
from mm_agents.prompts import (
    UITARS_ACTION_SPACE,
    UITARS_CALL_USR_ACTION_SPACE,
    UITARS_USR_PROMPT_NOTHOUGHT,
    UITARS_USR_PROMPT_THOUGHT,
)


from loguru import logger

FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

pure_text_settings = ["a11y_tree"]

attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
attributes_ns_windows = "https://accessibility.windows.example.org/ns/attributes"
state_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/state"
state_ns_windows = "https://accessibility.windows.example.org/ns/state"
component_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/component"
component_ns_windows = "https://accessibility.windows.example.org/ns/component"
value_ns_ubuntu = "https://accessibility.ubuntu.example.org/ns/value"
value_ns_windows = "https://accessibility.windows.example.org/ns/value"
class_ns_windows = "https://accessibility.windows.example.org/ns/class"
# More namespaces defined in OSWorld, please check desktop_env/server/main.py


# Parse a single action string into its function name and keyword arguments
def parse_action(action_str):
    try:
        # Parse the string into an AST node
        node = ast.parse(action_str, mode='eval')

        # Make sure the node is an expression
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")

        # Get the body of the expression
        call = node.body

        # Make sure the body is a function call
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Get the function name
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # Collect the keyword arguments
        kwargs = {}
        for kw in call.keywords:
            key = kw.arg
            # Handle the different value types; constants are assumed here
            if isinstance(kw.value, ast.Constant):
                value = kw.value.value
            elif isinstance(kw.value, ast.Str):  # compatibility with older Python versions
                value = kw.value.s
            else:
                value = None
            kwargs[key] = value

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        print(f"Failed to parse action '{action_str}': {e}")
        return None
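
# Illustrative example (my annotation, not part of the original module):
#   parse_action("click(start_box='(100,200)')")
# returns
#   {'function': 'click', 'args': {'start_box': '(100,200)'}}
# and returns None when the string is not a single function call.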


def escape_single_quotes(text):
    # Match unescaped single quotes (i.e. skip \\')
    pattern = r"(?<!\\)'"
    return re.sub(pattern, r"\\'", text)


def parse_action_qwen2vl(text, factor, image_height, image_width):
    text = text.strip()
    # Regex patterns for matching the Action string
    if text.startswith("Thought:"):
        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
        thought_hint = "Thought: "
    elif text.startswith("Reflection:"):
        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Reflection: "
    elif text.startswith("Action_Summary:"):
        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
        thought_hint = "Action_Summary: "
    else:
        # Fix: when there is no explicit "Thought:" marker, treat everything before "Action:" as the thought
        thought_pattern = r"(.+?)(?=\s*Action:|$)"
        thought_hint = ""

    reflection, thought = None, None
    thought_match = re.search(thought_pattern, text, re.DOTALL)
    if thought_match:
        if len(thought_match.groups()) == 1:
            thought = thought_match.group(1).strip()
        elif len(thought_match.groups()) == 2:
            thought = thought_match.group(2).strip()
            reflection = thought_match.group(1).strip()
    assert "Action:" in text
    action_str = text.split("Action:")[-1]

    tmp_all_action = action_str.split("\n\n")
    all_action = []
    for action_str in tmp_all_action:
        if "type(content" in action_str:
            # Use a regex to extract the string inside content and escape its single quotes
            def escape_quotes(match):
                content = match.group(1)  # the value of content
                return content

            # Substitute using the regex
            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
            content = re.sub(pattern, escape_quotes, action_str)

            # Process the string
            action_str = escape_single_quotes(content)
            action_str = "type(content='" + action_str + "')"
        all_action.append(action_str)

    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
    actions = []
    for action_instance, raw_str in zip(parsed_actions, all_action):
        if action_instance == None:
            print(f"Action can't parse: {raw_str}")
            continue
        action_type = action_instance["function"]
        params = action_instance["args"]

        # import pdb; pdb.set_trace()
        action_inputs = {}
        for param_name, param in params.items():
            if param == "": continue
            param = param.lstrip()  # strip leading whitespace
            # Handle start_box / end_box parameters in the format '<bbox>x1 y1 x2 y2</bbox>'
            action_inputs[param_name.strip()] = param

            if "start_box" in param_name or "end_box" in param_name:
                ori_box = param
                # Remove parentheses and split the string by commas
                numbers = ori_box.replace("(", "").replace(")", "").split(",")

                # Convert to float and divide by the scaling factor
                float_numbers = [float(num) / factor for num in numbers]
                if len(float_numbers) == 2:
                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
                action_inputs[param_name.strip()] = str(float_numbers)

        # import pdb; pdb.set_trace()
        actions.append({
            "reflection": reflection,
            "thought": thought,
            "action_type": action_type,
            "action_inputs": action_inputs,
            "text": text
        })
    return actions
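
# Illustrative example (my annotation, not part of the original module): for the model
# output "Thought: ...\nAction: click(start_box='(100,200)')" with factor=1000, the
# parser returns one action dict whose box coordinates are normalized to [0, 1]:
#   {'reflection': None, 'thought': '...', 'action_type': 'click',
#    'action_inputs': {'start_box': '[0.1, 0.2, 0.1, 0.2]'}, 'text': '...'}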


def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True) -> str:
    '''
    Parse the model output into OSWorld actions and generate a pyautogui code string.

    Args:
        responses: a dict (or list of dicts) with the model output, structured like:
        {
            "action_type": "hotkey",
            "action_inputs": {
                "hotkey": "v ctrl",
                "start_box": None,
                "end_box": None
            }
        }
    Returns:
        The generated pyautogui code string.
    '''

    pyautogui_code = f"import pyautogui\nimport time\n"
    if isinstance(responses, dict):
        responses = [responses]
    for response_id, response in enumerate(responses):
        if "observation" in response:
            observation = response["observation"]
        else:
            observation = ""

        if "thought" in response:
            thought = response["thought"]
        else:
            thought = ""

        if response_id == 0:
            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
        else:
            pyautogui_code += f"\ntime.sleep(3)\n"

        action_dict = response
        action_type = action_dict.get("action_type")
        action_inputs = action_dict.get("action_inputs", {})

        if action_type == "hotkey":
            # Parsing hotkey action
            if "key" in action_inputs:
                hotkey = action_inputs.get("key", "")
            else:
                hotkey = action_inputs.get("hotkey", "")

            if hotkey:
                # Handle other hotkeys
                keys = hotkey.split()  # Split the keys by space
                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in keys])})"

        elif action_type == "type":
            # Parsing typing action using clipboard
            content = action_inputs.get("content", "")
            content = escape_single_quotes(content)
            if content:
                if input_swap:
                    pyautogui_code += f"\nimport pyperclip"
                    pyautogui_code += f"\npyperclip.copy('{content.strip()}')"
                    pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
                    pyautogui_code += f"\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += f"\npyautogui.press('enter')"
                else:
                    pyautogui_code += f"\npyautogui.write('{content.strip()}', interval=0.1)"
                    pyautogui_code += f"\ntime.sleep(0.5)\n"
                    if content.endswith("\n") or content.endswith("\\n"):
                        pyautogui_code += f"\npyautogui.press('enter')"


        elif action_type in ["drag", "select"]:
            # Parsing drag or select action based on start and end_boxes
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")
            if start_box and end_box:
                x1, y1, x2, y2 = eval(start_box)  # Assuming box is in [x1, y1, x2, y2]
                sx = round(float((x1 + x2) / 2) * image_width, 3)
                sy = round(float((y1 + y2) / 2) * image_height, 3)
                x1, y1, x2, y2 = eval(end_box)  # Assuming box is in [x1, y1, x2, y2]
                ex = round(float((x1 + x2) / 2) * image_width, 3)
                ey = round(float((y1 + y2) / 2) * image_height, 3)
                pyautogui_code += (
                    f"\npyautogui.moveTo({sx}, {sy})\n"
                    f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n"
                )

        elif action_type == "scroll":
            # Parsing scroll action
            start_box = action_inputs.get("start_box")
            if start_box:
                x1, y1, x2, y2 = eval(start_box)  # Assuming box is in [x1, y1, x2, y2]
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)

                # # Click the target region first, then scroll
                # pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
            else:
                x = None
                y = None
            direction = action_inputs.get("direction", "")

            if x == None:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5)"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5)"
            else:
                if "up" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})"
                elif "down" in direction.lower():
                    pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})"

        elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]:
            # Parsing mouse click actions
            start_box = action_inputs.get("start_box")
            start_box = str(start_box)
            if start_box:
                start_box = eval(start_box)
                if len(start_box) == 4:
                    x1, y1, x2, y2 = start_box  # Assuming box is in [x1, y1, x2, y2]
                elif len(start_box) == 2:
                    x1, y1 = start_box
                    x2 = x1
                    y2 = y1
                x = round(float((x1 + x2) / 2) * image_width, 3)
                y = round(float((y1 + y2) / 2) * image_height, 3)
                if action_type == "left_single" or action_type == "click":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
                elif action_type == "left_double":
                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
                elif action_type == "right_single":
                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
                elif action_type == "hover":
                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"

        elif action_type in ["finished"]:
            pyautogui_code = f"DONE"

        else:
            pyautogui_code += f"\n# Unrecognized action type: {action_type}"

    return pyautogui_code
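
# Illustrative example (my annotation, not part of the original module): for a parsed
# click action with start_box '[0.1, 0.2, 0.1, 0.2]' on a 1920x1080 screen, the box
# center (0.1, 0.2) is scaled to pixel coordinates, so the generated code ends with
#   pyautogui.click(192.0, 216.0, button='left')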


def pil_to_base64(image):
    buffer = BytesIO()
    image.save(buffer, format="PNG")  # you can switch to "JPEG" or another format
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):

    if platform == "ubuntu":
        _attributes_ns = attributes_ns_ubuntu
        _state_ns = state_ns_ubuntu
        _component_ns = component_ns_ubuntu
        _value_ns = value_ns_ubuntu
    elif platform == "windows":
        _attributes_ns = attributes_ns_windows
        _state_ns = state_ns_windows
        _component_ns = component_ns_windows
        _value_ns = value_ns_windows
    else:
        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")

    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
    linearized_accessibility_tree = [
        "tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"
    ]

    # Linearize the accessibility tree nodes into a table format
    for node in filtered_nodes:
        if node.text:
            text = (
                node.text
                if '"' not in node.text
                else '"{:}"'.format(node.text.replace('"', '""'))
            )

        elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith(
            "EditWrapper"
        ) and node.get("{{{:}}}value".format(_value_ns)):
            node_text = node.get("{{{:}}}value".format(_value_ns), "")
            text = (
                node_text
                if '"' not in node_text
                else '"{:}"'.format(node_text.replace('"', '""'))
            )
        else:
            text = '""'

        linearized_accessibility_tree.append(
            "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format(
                node.tag,
                node.get("name", ""),
                text,
                (
                    node.get("{{{:}}}class".format(_attributes_ns), "")
                    if platform == "ubuntu"
                    else node.get("{{{:}}}class".format(class_ns_windows), "")
                ),
                node.get("{{{:}}}description".format(_attributes_ns), ""),
                node.get("{{{:}}}screencoord".format(_component_ns), ""),
                node.get("{{{:}}}size".format(_component_ns), ""),
            )
        )

    return "\n".join(linearized_accessibility_tree)


def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    # enc = tiktoken.encoding_for_model("gpt-4")
    # tokens = enc.encode(linearized_accessibility_tree)
    # if len(tokens) > max_tokens:
    #     linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
    #     linearized_accessibility_tree += "[...]\n"
    return linearized_accessibility_tree


class UITARSAgent:
    def __init__(
        self,
        model: str,
        platform="ubuntu",
        max_tokens=1000,
        top_p=0.9,
        top_k=1.0,
        temperature=0.0,
        action_space="pyautogui",
        observation_type="screenshot_a11y_tree",
        # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
        max_trajectory_length=50,
        a11y_tree_max_tokens=10000,
        runtime_conf: dict = {
            "infer_mode": "qwen2vl_user",
            "prompt_style": "qwen2vl_user",
            "input_swap": True,
            "language": "Chinese",
            "max_steps": 50,
            "history_n": 5,
            "screen_height": 1080,
            "screen_width": 1920
        }
    ):
        self.model = model
        self.platform = platform
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_tokens = a11y_tree_max_tokens
        self.runtime_conf = runtime_conf
        self.vlm = OpenAI(
            base_url=os.environ['DOUBAO_API_URL'],
            api_key=os.environ['DOUBAO_API_KEY'],
        )  # should be replaced with your UI-TARS server API
        self.infer_mode = self.runtime_conf["infer_mode"]
        self.prompt_style = self.runtime_conf["prompt_style"]
        self.input_swap = self.runtime_conf["input_swap"]
        self.language = self.runtime_conf["language"]
        self.max_steps = max_trajectory_length

        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []

        self.prompt_action_space = UITARS_ACTION_SPACE
        self.customize_action_parser = parse_action_qwen2vl
        self.action_parse_res_factor = 1000
        if self.infer_mode == "qwen2vl_user":
            self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE

        self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        if self.prompt_style == "qwen2vl_user":
            self.prompt_template = UITARS_USR_PROMPT_THOUGHT

        elif self.prompt_style == "qwen2vl_no_thought":
            self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT

        if "history_n" in self.runtime_conf:
            self.history_n = self.runtime_conf["history_n"]
        else:
            self.history_n = 5

    def predict(
        self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
    ) -> List:
        """
        Predict the next action(s) based on the current observation.
        """

        # Append trajectory
        # print(len(self.observations), len(self.actions), len(self.actions))
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
            self.thoughts
        ), "The number of observations and actions should be the same."

        if len(self.observations) > self.max_trajectory_length:
            if self.max_trajectory_length == 0:
                _observations = []
                _actions = []
                _thoughts = []
            else:
                _observations = self.observations[-self.max_trajectory_length :]
                _actions = self.actions[-self.max_trajectory_length :]
                _thoughts = self.thoughts[-self.max_trajectory_length :]
        else:
            _observations = self.observations
            _actions = self.actions
            _thoughts = self.thoughts

        if last_action_after_obs is not None and self.infer_mode == "double_image":
            self.history_images.append(last_action_after_obs["screenshot"])

        self.history_images.append(obs["screenshot"])

        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = obs["screenshot"]
            try:
                linearized_accessibility_tree = (
                    linearize_accessibility_tree(
                        accessibility_tree=obs["accessibility_tree"],
                        platform=self.platform,
                    )
                    if self.observation_type == "screenshot_a11y_tree"
                    else None
                )
            except:
                linearized_accessibility_tree = None
            # logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

            if linearized_accessibility_tree:
                linearized_accessibility_tree = trim_accessibility_tree(
                    linearized_accessibility_tree, self.a11y_tree_max_tokens
                )

            if self.observation_type == "screenshot_a11y_tree":
                self.observations.append(
                    {
                        "screenshot": base64_image,
                        "accessibility_tree": linearized_accessibility_tree,
                    }
                )
            else:
                self.observations.append(
                    {"screenshot": base64_image, "accessibility_tree": None}
                )

        else:
            raise ValueError(
                "Invalid observation_type: " + self.observation_type
            )  # 1}}}

        if self.infer_mode == "qwen2vl_user":
            user_prompt = self.prompt_template.format(
                instruction=instruction,
                action_space=self.prompt_action_space,
                language=self.language
            )
        elif self.infer_mode == "qwen2vl_no_thought":
            user_prompt = self.prompt_template.format(
                instruction=instruction
            )

        if len(self.history_images) > self.history_n:
            self.history_images = self.history_images[-self.history_n:]

        max_pixels = 2116800
        min_pixels = 3136
        messages, images = [], []
        if isinstance(self.history_images, bytes):
            self.history_images = [self.history_images]
        elif isinstance(self.history_images, np.ndarray):
            self.history_images = list(self.history_images)
        elif isinstance(self.history_images, list):
            pass
        else:
            raise TypeError(f"Unidentified images type: {type(self.history_images)}")
        max_image_nums_under_32k = int(32768*0.75/max_pixels*28*28)
        if len(self.history_images) > max_image_nums_under_32k:
            num_of_images = min(5, len(self.history_images))
            max_pixels = int(32768*0.75) // num_of_images

        for turn, image in enumerate(self.history_images):
            if len(images) >= 5:
                break
            try:
                image = Image.open(BytesIO(image))
            except Exception as e:
                raise RuntimeError(f"Error opening image: {e}")

            if image.width * image.height > max_pixels:
                """
                If the image exceeds (or, in the branch below, falls short of) the pixel
                limits, compute a resize factor that scales the pixel count to within
                max_pixels. The factor is a square root, so the aspect ratio is preserved
                and the original relative coordinates can be reused without conversion.
                """
                resize_factor = math.sqrt(max_pixels / (image.width * image.height))
                width, height = int(image.width * resize_factor), int(image.height * resize_factor)
                image = image.resize((width, height))
            if image.width * image.height < min_pixels:
                resize_factor = math.sqrt(min_pixels / (image.width * image.height))
                width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
                image = image.resize((width, height))
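
            # Numeric illustration (my annotation, assumed values): a 2560x1440 screenshot has
            # 3,686,400 pixels > max_pixels (2,116,800), so resize_factor = sqrt(2116800/3686400)
            # ≈ 0.758 and the image shrinks to roughly 1939x1091 with its aspect ratio preserved.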

            if image.mode != "RGB":
                image = image.convert("RGB")

            images.append(image)

        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": user_prompt}]
            }
        ]

        image_num = 0
        if len(self.history_responses) > 0:
            for history_idx, history_response in enumerate(self.history_responses):
                # send at most history_n images to the model
                if history_idx + self.history_n > len(self.history_responses):
                    cur_image = images[image_num]
                    encoded_string = pil_to_base64(cur_image)
                    messages.append({
                        "role": "user",
                        "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
                    })
                    image_num += 1

                messages.append({
                    "role": "assistant",
                    "content": history_response
                })

            cur_image = images[image_num]
            encoded_string = pil_to_base64(cur_image)
            messages.append({
                "role": "user",
                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
            })
            image_num += 1

        else:
            cur_image = images[image_num]
            encoded_string = pil_to_base64(cur_image)
            messages.append({
                "role": "user",
                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}]
            })
            image_num += 1

        try_times = 3
        while True:
            if try_times <= 0:
                print("Reached the maximum number of retries when fetching a response from the client; returning an error flag.")
                return "client error", ["DONE"]
            try:
                response = self.vlm.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    frequency_penalty=1,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
                print("Response:")
                print(response.choices[0].message.content)

                prediction = response.choices[0].message.content
                parsed_responses = self.customize_action_parser(
                    prediction,
                    self.action_parse_res_factor,
                    self.runtime_conf["screen_height"],
                    self.runtime_conf["screen_width"]
                )
                break
            except Exception as e:
                logger.exception(f"Error when fetching response from client: {e}")
                prediction = None
                try_times -= 1

        if prediction is None:
            return "client error", ["DONE"]

        self.history_responses.append(prediction)
        self.thoughts.append(prediction)

        try:
            parsed_responses = self.customize_action_parser(
                prediction,
                self.action_parse_res_factor,
                self.runtime_conf["screen_height"],
                self.runtime_conf["screen_width"]
            )
        except Exception as e:
            print(f"Parsing action error: {prediction}, with error:\n{e}")
            return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]

        actions = []
        for parsed_response in parsed_responses:
            if "action_type" in parsed_response:

                if parsed_response["action_type"] == FINISH_WORD:
                    self.actions.append(actions)
                    return prediction, ["DONE"]

                elif parsed_response["action_type"] == WAIT_WORD:
                    self.actions.append(actions)
                    return prediction, ["WAIT"]

                elif parsed_response["action_type"] == ENV_FAIL_WORD:
                    self.actions.append(actions)
                    return prediction, ["FAIL"]

                elif parsed_response["action_type"] == CALL_USER:
                    self.actions.append(actions)
                    return prediction, ["FAIL"]

            pyautogui_code = parsing_response_to_pyautogui_code(
                parsed_response,
                self.runtime_conf["screen_height"],
                self.runtime_conf["screen_width"],
                self.input_swap
            )
            actions.append(pyautogui_code)

        self.actions.append(actions)

        if len(self.history_responses) >= self.max_trajectory_length:
            # Default to FAIL if exceeding max steps
            actions = ["FAIL"]

        return prediction, actions

    @backoff.on_exception(
        backoff.constant,
        # Add more model-specific exceptions here as needed, but do not add the generic
        # "Exception": generic exceptions must propagate so the outer loop can catch them
        # and ensure each example does not exceed the time limit.
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10,
    )
    def reset(self, runtime_logger):
        self.thoughts = []
        self.actions = []
        self.observations = []
        self.history_images = []
        self.history_responses = []
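

# Minimal end-to-end sketch (my addition, not part of the original agent): exercise the two
# parsing helpers on a hand-written model response. The sample text and the 1920x1080 screen
# size are illustrative assumptions; running this still requires the repo's mm_agents package
# because of the module-level imports above.
if __name__ == "__main__":
    sample = (
        "Thought: I should click the OK button.\n"
        "Action: click(start_box='(100,200)')"
    )
    parsed = parse_action_qwen2vl(sample, 1000, 1080, 1920)
    print(parsed)
    # With factor=1000 the box becomes [0.1, 0.2, 0.1, 0.2]; on a 1920x1080 screen the
    # generated snippet clicks at (192.0, 216.0).
    print(parsing_response_to_pyautogui_code(parsed[0], 1080, 1920, input_swap=True))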