Dev/uitars 15 (#178)
* debug uitars1.0, add uitars1.5 * update pyautogui parser * modify function name * update parser
This commit is contained in:
@@ -28,6 +28,7 @@ from mm_agents.prompts import (
|
|||||||
UITARS_CALL_USR_ACTION_SPACE,
|
UITARS_CALL_USR_ACTION_SPACE,
|
||||||
UITARS_USR_PROMPT_NOTHOUGHT,
|
UITARS_USR_PROMPT_NOTHOUGHT,
|
||||||
UITARS_USR_PROMPT_THOUGHT,
|
UITARS_USR_PROMPT_THOUGHT,
|
||||||
|
UITARS_NORMAL_ACTION_SPACE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -38,6 +39,11 @@ WAIT_WORD = "wait"
|
|||||||
ENV_FAIL_WORD = "error_env"
|
ENV_FAIL_WORD = "error_env"
|
||||||
CALL_USER = "call_user"
|
CALL_USER = "call_user"
|
||||||
|
|
||||||
|
IMAGE_FACTOR = 28
|
||||||
|
MIN_PIXELS = 100 * 28 * 28
|
||||||
|
MAX_PIXELS = 16384 * 28 * 28
|
||||||
|
MAX_RATIO = 200
|
||||||
|
|
||||||
pure_text_settings = ["a11y_tree"]
|
pure_text_settings = ["a11y_tree"]
|
||||||
|
|
||||||
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
|
||||||
@@ -103,8 +109,68 @@ def escape_single_quotes(text):
|
|||||||
pattern = r"(?<!\\)'"
|
pattern = r"(?<!\\)'"
|
||||||
return re.sub(pattern, r"\\'", text)
|
return re.sub(pattern, r"\\'", text)
|
||||||
|
|
||||||
def parse_action_qwen2vl(text, factor, image_height, image_width):
|
def round_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||||
|
return round(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.ceil(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def floor_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.floor(number / factor) * factor
|
||||||
|
|
||||||
|
def linear_resize(
|
||||||
|
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
if width * height > max_pixels:
|
||||||
|
"""
|
||||||
|
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
||||||
|
"""
|
||||||
|
resize_factor = math.sqrt(max_pixels / (width * height))
|
||||||
|
width, height = int(width * resize_factor), int(height * resize_factor)
|
||||||
|
if width * height < min_pixels:
|
||||||
|
resize_factor = math.sqrt(min_pixels / (width * height))
|
||||||
|
width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
|
||||||
|
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
def smart_resize(
|
||||||
|
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Rescales the image so that the following conditions are met:
|
||||||
|
|
||||||
|
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||||
|
|
||||||
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||||
|
|
||||||
|
3. The aspect ratio of the image is maintained as closely as possible.
|
||||||
|
"""
|
||||||
|
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||||
|
raise ValueError(
|
||||||
|
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||||
|
)
|
||||||
|
h_bar = max(factor, round_by_factor(height, factor))
|
||||||
|
w_bar = max(factor, round_by_factor(width, factor))
|
||||||
|
if h_bar * w_bar > max_pixels:
|
||||||
|
beta = math.sqrt((height * width) / max_pixels)
|
||||||
|
h_bar = floor_by_factor(height / beta, factor)
|
||||||
|
w_bar = floor_by_factor(width / beta, factor)
|
||||||
|
elif h_bar * w_bar < min_pixels:
|
||||||
|
beta = math.sqrt(min_pixels / (height * width))
|
||||||
|
h_bar = ceil_by_factor(height * beta, factor)
|
||||||
|
w_bar = ceil_by_factor(width * beta, factor)
|
||||||
|
return h_bar, w_bar
|
||||||
|
|
||||||
|
def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
|
||||||
text = text.strip()
|
text = text.strip()
|
||||||
|
if model_type == "qwen25vl":
|
||||||
|
smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
|
||||||
# 正则表达式匹配 Action 字符串
|
# 正则表达式匹配 Action 字符串
|
||||||
if text.startswith("Thought:"):
|
if text.startswith("Thought:"):
|
||||||
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
|
thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
|
||||||
@@ -152,7 +218,7 @@ def parse_action_qwen2vl(text, factor, image_height, image_width):
|
|||||||
for action_instance, raw_str in zip(parsed_actions, all_action):
|
for action_instance, raw_str in zip(parsed_actions, all_action):
|
||||||
if action_instance == None:
|
if action_instance == None:
|
||||||
print(f"Action can't parse: {raw_str}")
|
print(f"Action can't parse: {raw_str}")
|
||||||
continue
|
raise ValueError(f"Action can't parse: {raw_str}")
|
||||||
action_type = action_instance["function"]
|
action_type = action_instance["function"]
|
||||||
params = action_instance["args"]
|
params = action_instance["args"]
|
||||||
|
|
||||||
@@ -170,7 +236,18 @@ def parse_action_qwen2vl(text, factor, image_height, image_width):
|
|||||||
numbers = ori_box.replace("(", "").replace(")", "").split(",")
|
numbers = ori_box.replace("(", "").replace(")", "").split(",")
|
||||||
|
|
||||||
# Convert to float and scale by 1000
|
# Convert to float and scale by 1000
|
||||||
float_numbers = [float(num) / factor for num in numbers]
|
# Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
|
||||||
|
if model_type == "qwen25vl":
|
||||||
|
float_numbers = []
|
||||||
|
for num_idx, num in enumerate(numbers):
|
||||||
|
num = float(num)
|
||||||
|
if (num_idx + 1) % 2 == 0:
|
||||||
|
float_numbers.append(float(num/smart_resize_height))
|
||||||
|
else:
|
||||||
|
float_numbers.append(float(num/smart_resize_width))
|
||||||
|
else:
|
||||||
|
float_numbers = [float(num) / factor for num in numbers]
|
||||||
|
|
||||||
if len(float_numbers) == 2:
|
if len(float_numbers) == 2:
|
||||||
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
|
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
|
||||||
action_inputs[param_name.strip()] = str(float_numbers)
|
action_inputs[param_name.strip()] = str(float_numbers)
|
||||||
@@ -219,7 +296,7 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||||||
if response_id == 0:
|
if response_id == 0:
|
||||||
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
|
pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
|
||||||
else:
|
else:
|
||||||
pyautogui_code += f"\ntime.sleep(3)\n"
|
pyautogui_code += f"\ntime.sleep(1)\n"
|
||||||
|
|
||||||
action_dict = response
|
action_dict = response
|
||||||
action_type = action_dict.get("action_type")
|
action_type = action_dict.get("action_type")
|
||||||
@@ -232,25 +309,79 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||||||
else:
|
else:
|
||||||
hotkey = action_inputs.get("hotkey", "")
|
hotkey = action_inputs.get("hotkey", "")
|
||||||
|
|
||||||
|
if hotkey == "arrowleft":
|
||||||
|
hotkey = "left"
|
||||||
|
|
||||||
|
elif hotkey == "arrowright":
|
||||||
|
hotkey = "right"
|
||||||
|
|
||||||
|
elif hotkey == "arrowup":
|
||||||
|
hotkey = "up"
|
||||||
|
|
||||||
|
elif hotkey == "arrowdown":
|
||||||
|
hotkey = "down"
|
||||||
|
|
||||||
if hotkey:
|
if hotkey:
|
||||||
# Handle other hotkeys
|
# Handle other hotkeys
|
||||||
keys = hotkey.split() # Split the keys by space
|
keys = hotkey.split() # Split the keys by space
|
||||||
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in keys])})"
|
convert_keys = []
|
||||||
|
for key in keys:
|
||||||
|
if key == "space":
|
||||||
|
key = ' '
|
||||||
|
convert_keys.append(key)
|
||||||
|
pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
|
||||||
|
|
||||||
|
elif action_type == "press":
|
||||||
|
# Parsing press action
|
||||||
|
if "key" in action_inputs:
|
||||||
|
key_to_press = action_inputs.get("key", "")
|
||||||
|
else:
|
||||||
|
key_to_press = action_inputs.get("press", "")
|
||||||
|
|
||||||
|
if hotkey == "arrowleft":
|
||||||
|
hotkey = "left"
|
||||||
|
|
||||||
|
elif hotkey == "arrowright":
|
||||||
|
hotkey = "right"
|
||||||
|
|
||||||
|
elif hotkey == "arrowup":
|
||||||
|
hotkey = "up"
|
||||||
|
|
||||||
|
elif hotkey == "arrowdown":
|
||||||
|
hotkey = "down"
|
||||||
|
|
||||||
|
elif hotkey == "space":
|
||||||
|
hotkey = " "
|
||||||
|
|
||||||
|
if key_to_press:
|
||||||
|
# Simulate pressing a single key
|
||||||
|
pyautogui_code += f"\npyautogui.press({repr(key_to_press)})"
|
||||||
|
|
||||||
|
elif action_type == "keyup":
|
||||||
|
key_to_up = action_inputs.get("key", "")
|
||||||
|
pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})"
|
||||||
|
|
||||||
|
elif action_type == "keydown":
|
||||||
|
key_to_down = action_inputs.get("key", "")
|
||||||
|
pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})"
|
||||||
|
|
||||||
elif action_type == "type":
|
elif action_type == "type":
|
||||||
# Parsing typing action using clipboard
|
# Parsing typing action using clipboard
|
||||||
content = action_inputs.get("content", "")
|
content = action_inputs.get("content", "")
|
||||||
content = escape_single_quotes(content)
|
content = escape_single_quotes(content)
|
||||||
|
stripped_content = content
|
||||||
|
if content.endswith("\n") or content.endswith("\\n"):
|
||||||
|
stripped_content = stripped_content.rstrip("\\n").rstrip("\n")
|
||||||
if content:
|
if content:
|
||||||
if input_swap:
|
if input_swap:
|
||||||
pyautogui_code += f"\nimport pyperclip"
|
pyautogui_code += f"\nimport pyperclip"
|
||||||
pyautogui_code += f"\npyperclip.copy('{content.strip()}')"
|
pyautogui_code += f"\npyperclip.copy('{stripped_content}')"
|
||||||
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
|
pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')"
|
||||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||||
if content.endswith("\n") or content.endswith("\\n"):
|
if content.endswith("\n") or content.endswith("\\n"):
|
||||||
pyautogui_code += f"\npyautogui.press('enter')"
|
pyautogui_code += f"\npyautogui.press('enter')"
|
||||||
else:
|
else:
|
||||||
pyautogui_code += f"\npyautogui.write('{content.strip()}', interval=0.1)"
|
pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
|
||||||
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
pyautogui_code += f"\ntime.sleep(0.5)\n"
|
||||||
if content.endswith("\n") or content.endswith("\\n"):
|
if content.endswith("\n") or content.endswith("\\n"):
|
||||||
pyautogui_code += f"\npyautogui.press('enter')"
|
pyautogui_code += f"\npyautogui.press('enter')"
|
||||||
@@ -329,6 +460,29 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
|
|||||||
|
|
||||||
return pyautogui_code
|
return pyautogui_code
|
||||||
|
|
||||||
|
def add_box_token(input_string):
|
||||||
|
# Step 1: Split the string into individual actions
|
||||||
|
if "Action: " in input_string and "start_box=" in input_string:
|
||||||
|
suffix = input_string.split("Action: ")[0] + "Action: "
|
||||||
|
actions = input_string.split("Action: ")[1:]
|
||||||
|
processed_actions = []
|
||||||
|
for action in actions:
|
||||||
|
action = action.strip()
|
||||||
|
# Step 2: Extract coordinates (start_box or end_box) using regex
|
||||||
|
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
|
||||||
|
|
||||||
|
updated_action = action # Start with the original action
|
||||||
|
for coord_type, x, y in coordinates:
|
||||||
|
# Convert x and y to integers
|
||||||
|
updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
|
||||||
|
processed_actions.append(updated_action)
|
||||||
|
|
||||||
|
# Step 5: Reconstruct the final string
|
||||||
|
final_string = suffix + "\n\n".join(processed_actions)
|
||||||
|
else:
|
||||||
|
final_string = input_string
|
||||||
|
return final_string
|
||||||
|
|
||||||
def pil_to_base64(image):
|
def pil_to_base64(image):
|
||||||
buffer = BytesIO()
|
buffer = BytesIO()
|
||||||
image.save(buffer, format="PNG") # 你可以改成 "JPEG" 等格式
|
image.save(buffer, format="PNG") # 你可以改成 "JPEG" 等格式
|
||||||
@@ -405,45 +559,50 @@ class UITARSAgent:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
platform="ubuntu",
|
platform="ubuntu",
|
||||||
max_tokens=1000,
|
|
||||||
top_p=0.9,
|
|
||||||
top_k=1.0,
|
|
||||||
temperature=0.0,
|
|
||||||
action_space="pyautogui",
|
action_space="pyautogui",
|
||||||
observation_type="screenshot_a11y_tree",
|
observation_type="screenshot",
|
||||||
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
# observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
|
||||||
max_trajectory_length=50,
|
max_trajectory_length=50,
|
||||||
a11y_tree_max_tokens=10000,
|
a11y_tree_max_tokens=10000,
|
||||||
|
model_type="qwen25vl",
|
||||||
runtime_conf: dict = {
|
runtime_conf: dict = {
|
||||||
"infer_mode": "qwen2vl_user",
|
"infer_mode": "qwen25vl_normal",
|
||||||
"prompt_style": "qwen2vl_user",
|
"prompt_style": "qwen25vl_normal",
|
||||||
"input_swap": True,
|
"input_swap": True,
|
||||||
"language": "Chinese",
|
"language": "Chinese",
|
||||||
"max_steps": 50,
|
|
||||||
"history_n": 5,
|
"history_n": 5,
|
||||||
"screen_height": 1080,
|
"max_pixels": 16384*28*28,
|
||||||
"screen_width": 1920
|
"min_pixels": 100*28*28,
|
||||||
|
"callusr_tolerance": 3,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": -1,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"max_tokens": 500
|
||||||
|
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
self.platform = platform
|
self.platform = platform
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.top_p = top_p
|
|
||||||
self.top_k = top_k
|
|
||||||
self.temperature = temperature
|
|
||||||
self.action_space = action_space
|
self.action_space = action_space
|
||||||
self.observation_type = observation_type
|
self.observation_type = observation_type
|
||||||
self.max_trajectory_length = max_trajectory_length
|
self.max_trajectory_length = max_trajectory_length
|
||||||
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
self.a11y_tree_max_tokens = a11y_tree_max_tokens
|
||||||
|
self.model_type = model_type
|
||||||
self.runtime_conf = runtime_conf
|
self.runtime_conf = runtime_conf
|
||||||
self.vlm = OpenAI(
|
self.vlm = OpenAI(
|
||||||
base_url="http://127.0.0.1:8000/v1",
|
base_url="http://127.0.0.1:8000/v1",
|
||||||
api_key="empty",
|
api_key="empty",
|
||||||
) # should replace with your UI-TARS server api
|
) # should replace with your UI-TARS server api
|
||||||
|
self.temperature = self.runtime_conf["temperature"]
|
||||||
|
self.top_k = self.runtime_conf["top_k"]
|
||||||
|
self.top_p = self.runtime_conf["top_p"]
|
||||||
|
self.max_tokens = self.runtime_conf["max_tokens"]
|
||||||
self.infer_mode = self.runtime_conf["infer_mode"]
|
self.infer_mode = self.runtime_conf["infer_mode"]
|
||||||
self.prompt_style = self.runtime_conf["prompt_style"]
|
self.prompt_style = self.runtime_conf["prompt_style"]
|
||||||
self.input_swap = self.runtime_conf["input_swap"]
|
self.input_swap = self.runtime_conf["input_swap"]
|
||||||
self.language = self.runtime_conf["language"]
|
self.language = self.runtime_conf["language"]
|
||||||
self.max_steps = self.runtime_conf["max_steps"]
|
self.max_pixels = self.runtime_conf["max_pixels"]
|
||||||
|
self.min_pixels = self.runtime_conf["min_pixels"]
|
||||||
|
self.callusr_tolerance = self.runtime_conf["callusr_tolerance"]
|
||||||
|
|
||||||
self.thoughts = []
|
self.thoughts = []
|
||||||
self.actions = []
|
self.actions = []
|
||||||
@@ -452,14 +611,15 @@ class UITARSAgent:
|
|||||||
self.history_responses = []
|
self.history_responses = []
|
||||||
|
|
||||||
self.prompt_action_space = UITARS_ACTION_SPACE
|
self.prompt_action_space = UITARS_ACTION_SPACE
|
||||||
self.customize_action_parser = parse_action_qwen2vl
|
|
||||||
self.action_parse_res_factor = 1000
|
self.action_parse_res_factor = 1000
|
||||||
if self.infer_mode == "qwen2vl_user":
|
if self.infer_mode == "qwen2vl_user":
|
||||||
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
|
self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE
|
||||||
|
elif self.infer_mode == "qwen25vl_normal":
|
||||||
|
self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE
|
||||||
|
|
||||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||||
|
|
||||||
if self.prompt_style == "qwen2vl_user":
|
if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal":
|
||||||
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
self.prompt_template = UITARS_USR_PROMPT_THOUGHT
|
||||||
|
|
||||||
elif self.prompt_style == "qwen2vl_no_thought":
|
elif self.prompt_style == "qwen2vl_no_thought":
|
||||||
@@ -470,6 +630,8 @@ class UITARSAgent:
|
|||||||
self.history_n = self.runtime_conf["history_n"]
|
self.history_n = self.runtime_conf["history_n"]
|
||||||
else:
|
else:
|
||||||
self.history_n = 5
|
self.history_n = 5
|
||||||
|
|
||||||
|
self.cur_callusr_count = 0
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
|
self, instruction: str, obs: Dict, last_action_after_obs: Dict = None
|
||||||
@@ -511,9 +673,6 @@ class UITARSAgent:
|
|||||||
"Invalid observation_type type: " + self.observation_type
|
"Invalid observation_type type: " + self.observation_type
|
||||||
) # 1}}}
|
) # 1}}}
|
||||||
|
|
||||||
if last_action_after_obs is not None and self.infer_mode == "double_image":
|
|
||||||
self.history_images.append(last_action_after_obs["screenshot"])
|
|
||||||
|
|
||||||
self.history_images.append(obs["screenshot"])
|
self.history_images.append(obs["screenshot"])
|
||||||
|
|
||||||
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
|
||||||
@@ -553,7 +712,7 @@ class UITARSAgent:
|
|||||||
"Invalid observation_type type: " + self.observation_type
|
"Invalid observation_type type: " + self.observation_type
|
||||||
) # 1}}}
|
) # 1}}}
|
||||||
|
|
||||||
if self.infer_mode == "qwen2vl_user":
|
if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
|
||||||
user_prompt = self.prompt_template.format(
|
user_prompt = self.prompt_template.format(
|
||||||
instruction=instruction,
|
instruction=instruction,
|
||||||
action_space=self.prompt_action_space,
|
action_space=self.prompt_action_space,
|
||||||
@@ -567,8 +726,6 @@ class UITARSAgent:
|
|||||||
if len(self.history_images) > self.history_n:
|
if len(self.history_images) > self.history_n:
|
||||||
self.history_images = self.history_images[-self.history_n:]
|
self.history_images = self.history_images[-self.history_n:]
|
||||||
|
|
||||||
max_pixels = 1350 * 28 * 28
|
|
||||||
min_pixels = 100 * 28 * 28
|
|
||||||
messages, images = [], []
|
messages, images = [], []
|
||||||
if isinstance(self.history_images, bytes):
|
if isinstance(self.history_images, bytes):
|
||||||
self.history_images = [self.history_images]
|
self.history_images = [self.history_images]
|
||||||
@@ -578,28 +735,24 @@ class UITARSAgent:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
|
raise TypeError(f"Unidentified images type: {type(self.history_images)}")
|
||||||
max_image_nums_under_32k = int(32768*0.75/max_pixels*28*28)
|
|
||||||
if len(self.history_images) > max_image_nums_under_32k:
|
|
||||||
num_of_images = min(5, len(self.history_images))
|
|
||||||
max_pixels = int(32768*0.75) // num_of_images
|
|
||||||
|
|
||||||
for turn, image in enumerate(self.history_images):
|
for turn, image in enumerate(self.history_images):
|
||||||
if len(images) >= 5:
|
if len(images) >= self.history_n:
|
||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
image = Image.open(BytesIO(image))
|
image = Image.open(BytesIO(image))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error opening image: {e}")
|
raise RuntimeError(f"Error opening image: {e}")
|
||||||
|
|
||||||
if image.width * image.height > max_pixels:
|
if image.width * image.height > self.max_pixels:
|
||||||
"""
|
"""
|
||||||
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用
|
||||||
"""
|
"""
|
||||||
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
|
resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
|
||||||
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
width, height = int(image.width * resize_factor), int(image.height * resize_factor)
|
||||||
image = image.resize((width, height))
|
image = image.resize((width, height))
|
||||||
if image.width * image.height < min_pixels:
|
if image.width * image.height < self.min_pixels:
|
||||||
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
|
resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
|
||||||
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
|
width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
|
||||||
image = image.resize((width, height))
|
image = image.resize((width, height))
|
||||||
|
|
||||||
@@ -635,7 +788,7 @@ class UITARSAgent:
|
|||||||
|
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": [history_response]
|
"content": [add_box_token(history_response)]
|
||||||
})
|
})
|
||||||
|
|
||||||
cur_image = images[image_num]
|
cur_image = images[image_num]
|
||||||
@@ -656,56 +809,75 @@ class UITARSAgent:
|
|||||||
image_num += 1
|
image_num += 1
|
||||||
|
|
||||||
try_times = 3
|
try_times = 3
|
||||||
|
origin_resized_height = images[-1].height
|
||||||
|
origin_resized_width = images[-1].width
|
||||||
|
temperature = self.temperature
|
||||||
|
top_k = self.top_k
|
||||||
while True:
|
while True:
|
||||||
if try_times <= 0:
|
if try_times <= 0:
|
||||||
print(f"Reach max retry times to fetch response from client, as error flag.")
|
print(f"Reach max retry times to fetch response from client, as error flag.")
|
||||||
return "client error", ["DONE"], []
|
return "client error", ["DONE"], []
|
||||||
try:
|
try:
|
||||||
|
|
||||||
response = self.vlm.chat.completions.create(
|
response = self.vlm.chat.completions.create(
|
||||||
model="ui-tars",
|
model="ui-tars",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
frequency_penalty=1,
|
frequency_penalty=1,
|
||||||
max_tokens=self.max_tokens,
|
max_tokens=self.max_tokens,
|
||||||
temperature=self.temperature,
|
temperature=temperature,
|
||||||
top_k=self.top_k,
|
top_k=top_k,
|
||||||
top_p=self.top_p
|
top_p=self.top_p
|
||||||
)
|
)
|
||||||
# print(response.choices[0].message.content)
|
# print(response.choices[0].message.content)
|
||||||
prediction = response.choices[0].message.content.strip()
|
prediction = response.choices[0].message.content.strip()
|
||||||
|
|
||||||
prediction = response[0]["prediction"].strip()
|
prediction = response[0]["prediction"].strip()
|
||||||
parsed_responses = self.customize_action_parser(
|
|
||||||
prediction,
|
|
||||||
self.action_parse_res_factor,
|
|
||||||
self.runtime_conf["screen_height"],
|
|
||||||
self.runtime_conf["screen_width"]
|
|
||||||
)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error when fetching response from client, with response: {response}")
|
print(f"Error when fetching response from client, with response: {response}")
|
||||||
prediction = None
|
prediction = None
|
||||||
try_times -= 1
|
try_times -= 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed_responses = parse_action_to_structure_output(
|
||||||
|
prediction,
|
||||||
|
self.action_parse_res_factor,
|
||||||
|
origin_resized_height,
|
||||||
|
origin_resized_width,
|
||||||
|
self.model_type,
|
||||||
|
self.max_pixels,
|
||||||
|
self.min_pixels
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error when parsing response from client, with response: {response}")
|
||||||
|
# If fail to parse the model response, we use sampling parameters to avoid it
|
||||||
|
prediction = None
|
||||||
|
try_times -= 1
|
||||||
|
temperature = 1
|
||||||
|
top_k = -1
|
||||||
|
|
||||||
if prediction is None:
|
if prediction is None:
|
||||||
return "client error", ["DONE"]
|
return "client error", ["DONE"]
|
||||||
|
|
||||||
|
|
||||||
self.history_responses.append(prediction)
|
self.history_responses.append(prediction)
|
||||||
self.thoughts.append(prediction)
|
self.thoughts.append(prediction)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
parsed_responses = self.customize_action_parser(
|
parsed_responses = parse_action_to_structure_output(
|
||||||
prediction,
|
prediction,
|
||||||
self.action_parse_res_factor,
|
self.action_parse_res_factor,
|
||||||
self.runtime_conf["screen_height"],
|
origin_resized_height,
|
||||||
self.runtime_conf["screen_width"]
|
origin_resized_width,
|
||||||
|
self.model_type,
|
||||||
|
self.max_pixels,
|
||||||
|
self.min_pixels
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
print(f"Parsing action error: {prediction}, with error:\n{e}")
|
||||||
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
|
return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"]
|
||||||
|
|
||||||
actions = []
|
actions = []
|
||||||
|
last_image = Image.open(BytesIO(self.history_images[-1]))
|
||||||
|
obs_image_height = last_image.height
|
||||||
|
obs_image_width = last_image.width
|
||||||
for parsed_response in parsed_responses:
|
for parsed_response in parsed_responses:
|
||||||
if "action_type" in parsed_response:
|
if "action_type" in parsed_response:
|
||||||
|
|
||||||
@@ -723,14 +895,18 @@ class UITARSAgent:
|
|||||||
return prediction, ["FAIL"]
|
return prediction, ["FAIL"]
|
||||||
|
|
||||||
elif parsed_response["action_type"] == CALL_USER:
|
elif parsed_response["action_type"] == CALL_USER:
|
||||||
self.actions.append(actions)
|
if self.callusr_tolerance > self.cur_callusr_count:
|
||||||
return prediction, ["FAIL"]
|
self.actions.append(actions)
|
||||||
|
self.cur_callusr_count += 1
|
||||||
|
return prediction, ["WAIT"]
|
||||||
|
else:
|
||||||
|
self.actions.append(actions)
|
||||||
|
return prediction, ["FAIL"]
|
||||||
|
|
||||||
pyautogui_code = parsing_response_to_pyautogui_code(
|
pyautogui_code = parsing_response_to_pyautogui_code(
|
||||||
parsed_response,
|
parsed_response,
|
||||||
self.runtime_conf["screen_height"],
|
obs_image_height,
|
||||||
self.runtime_conf["screen_width"],
|
obs_image_width,
|
||||||
self.input_swap
|
self.input_swap
|
||||||
)
|
)
|
||||||
actions.append(pyautogui_code)
|
actions.append(pyautogui_code)
|
||||||
|
|||||||
@@ -91,10 +91,20 @@ def config() -> argparse.Namespace:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# lm config
|
# lm config
|
||||||
parser.add_argument("--model", type=str, default="gpt-4o")
|
parser.add_argument("--model", type=str, default="uitars")
|
||||||
|
parser.add_argument("--model_type", type=str, default="qwen25vl")
|
||||||
|
parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal")
|
||||||
|
parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal")
|
||||||
|
parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content")
|
||||||
|
parser.add_argument("--language", type=str, default="Chinese")
|
||||||
|
parser.add_argument("--max_pixels", type=float, default=16384*28*28)
|
||||||
|
parser.add_argument("--min_pixels", type=float, default=100*28*28)
|
||||||
parser.add_argument("--temperature", type=float, default=1.0)
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
parser.add_argument("--top_p", type=float, default=0.9)
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
parser.add_argument("--max_tokens", type=int, default=1500)
|
parser.add_argument("--top_k", type=int, default=-1)
|
||||||
|
parser.add_argument("--history_n", type=int, default=5)
|
||||||
|
parser.add_argument("--callusr_tolerance", type=int, default=3)
|
||||||
|
parser.add_argument("--max_tokens", type=int, default=500)
|
||||||
parser.add_argument("--stop_token", type=str, default=None)
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
|
||||||
# example config
|
# example config
|
||||||
@@ -128,8 +138,18 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
"max_steps": args.max_steps,
|
"max_steps": args.max_steps,
|
||||||
"max_trajectory_length": args.max_trajectory_length,
|
"max_trajectory_length": args.max_trajectory_length,
|
||||||
"model": args.model,
|
"model": args.model,
|
||||||
|
"model_type": args.model_type,
|
||||||
|
"infer_mode": args.infer_mode,
|
||||||
|
"prompt_style": args.prompt_style,
|
||||||
|
"input_swap": args.input_swap,
|
||||||
|
"language": args.language,
|
||||||
|
"history_n": args.history_n,
|
||||||
|
"max_pixels": args.max_pixels,
|
||||||
|
"min_pixels": args.min_pixels,
|
||||||
|
"callusr_tolerance": args.callusr_tolerance,
|
||||||
"temperature": args.temperature,
|
"temperature": args.temperature,
|
||||||
"top_p": args.top_p,
|
"top_p": args.top_p,
|
||||||
|
"top_k": args.top_k,
|
||||||
"max_tokens": args.max_tokens,
|
"max_tokens": args.max_tokens,
|
||||||
"stop_token": args.stop_token,
|
"stop_token": args.stop_token,
|
||||||
"result_dir": args.result_dir,
|
"result_dir": args.result_dir,
|
||||||
@@ -137,12 +157,24 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
|
|||||||
|
|
||||||
agent = UITARSAgent(
|
agent = UITARSAgent(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
max_tokens=args.max_tokens,
|
|
||||||
top_p=args.top_p,
|
|
||||||
temperature=args.temperature,
|
|
||||||
action_space=args.action_space,
|
action_space=args.action_space,
|
||||||
observation_type=args.observation_type,
|
observation_type=args.observation_type,
|
||||||
max_trajectory_length=args.max_trajectory_length,
|
max_trajectory_length=args.max_trajectory_length,
|
||||||
|
model_type=args.model_type,
|
||||||
|
runtime_conf = {
|
||||||
|
"infer_mode": args.infer_mode,
|
||||||
|
"prompt_style": args.prompt_style,
|
||||||
|
"input_swap": args.input_swap,
|
||||||
|
"language": args.language,
|
||||||
|
"history_n": args.history_n,
|
||||||
|
"max_pixels": args.max_pixels,
|
||||||
|
"min_pixels": args.min_pixels,
|
||||||
|
"callusr_tolerance": args.callusr_tolerance,
|
||||||
|
"temperature": args.temperature,
|
||||||
|
"top_p": args.top_p,
|
||||||
|
"top_k": args.top_k,
|
||||||
|
"max_tokens": args.max_tokens
|
||||||
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
env = DesktopEnv(
|
env = DesktopEnv(
|
||||||
|
|||||||
Reference in New Issue
Block a user