diff --git a/mm_agents/uitars15_agent.py b/mm_agents/uitars15_agent.py
new file mode 100644
index 0000000..d8213b7
--- /dev/null
+++ b/mm_agents/uitars15_agent.py
@@ -0,0 +1,849 @@
+
+import ast
+import base64
+import logging
+import math
+import os
+import re
+
+import requests
+from typing import Optional, Dict, List, Tuple
+from loguru import logger
+
+FINISH_WORD = "finished"
+WAIT_WORD = "wait"
+ENV_FAIL_WORD = "error_env"
+CALL_USER = "call_user"
+
+IMAGE_FACTOR = 28
+MIN_PIXELS = 100 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+def convert_point_to_coordinates(text, is_answer=False):
+    # Match the two numbers inside a <point>x y</point> tag
+    pattern = r"<point>(\d+)\s+(\d+)</point>"
+
+    def replace_match(match):
+        x1, y1 = map(int, match.groups())
+        x = (x1 + x1) // 2  # truncating integer division
+        y = (y1 + y1) // 2  # truncating integer division
+        if is_answer:
+            return f"({x},{y})"  # return only the (x, y) form
+        return f"({x},{y})"  # return the tagged form
+
+    # Strip [EOS] and replace the <point> coordinates
+    text = re.sub(r"\[EOS\]", "", text)
+    return re.sub(pattern, replace_match, text).strip()
+
+# Parse a single action string into its function name and keyword arguments
+def parse_action(action_str):
+    try:
+        # Parse the string into an AST node
+        node = ast.parse(action_str, mode='eval')
+
+        # Make sure the node is an expression
+        if not isinstance(node, ast.Expression):
+            raise ValueError("Not an expression")
+
+        # Get the body of the expression
+        call = node.body
+
+        # Make sure the body is a function call
+        if not isinstance(call, ast.Call):
+            raise ValueError("Not a function call")
+
+        # Get the function name
+        if isinstance(call.func, ast.Name):
+            func_name = call.func.id
+        elif isinstance(call.func, ast.Attribute):
+            func_name = call.func.attr
+        else:
+            func_name = None
+
+        # Collect the keyword arguments
+        kwargs = {}
+        for kw in call.keywords:
+            key = kw.arg
+            # Handle the different value types; constants are assumed here
+            if isinstance(kw.value, ast.Constant):
+                value = kw.value.value
+            elif isinstance(kw.value, ast.Str):  # compatibility with older Python versions
+                value = kw.value.s
+            else:
+                value = None
+            kwargs[key] = value
+
+        return {
+            'function': func_name,
+            'args': kwargs
+        }
+
+    except Exception as e:
+        print(f"Failed to parse action '{action_str}': {e}")
+        return None
+
+def escape_single_quotes(text):
+    # Match unescaped single quotes (does not match \')
+    pattern = r"(?<!\\)'"
+    return re.sub(pattern, r"\\'", text)
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+def linear_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    if width * height > max_pixels:
+        # If the image exceeds the pixel limit, compute a resize_factor that shrinks the
+        # pixel count to at most max_pixels. The factor is a square root, so the aspect
+        # ratio is preserved and the original relative coordinates can be reused without
+        # conversion.
+        resize_factor = math.sqrt(max_pixels / (width * height))
+        width, height = int(width * resize_factor), int(height * resize_factor)
+    if width * height < min_pixels:
+        resize_factor = math.sqrt(min_pixels / (width * height))
+        width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
+
+    return height, width
+
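+# For reference: with the defaults above, smart_resize (defined below) rounds both sides
+# to the nearest multiple of IMAGE_FACTOR while keeping the pixel count inside
+# [MIN_PIXELS, MAX_PIXELS], e.g. smart_resize(1080, 1920) == (1092, 1932); linear_resize
+# leaves a 1080x1920 image untouched because it already fits the pixel budget.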
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type="qwen25vl", max_pixels=16384*28*28, min_pixels=100*28*28):
+    text = text.strip()
+
+    if "<point>" in text:
+        text = convert_point_to_coordinates(text)
+    if "start_point=" in text:
+        text = text.replace("start_point=", "start_box=")
+    if "end_point=" in text:
+        text = text.replace("end_point=", "end_box=")
+    if "point=" in text:
+        text = text.replace("point=", "start_box=")
+
+    if model_type == "qwen25vl":
+        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
+
+    # Regex patterns that extract the Thought/Reflection part preceding the Action string
+    if text.startswith("Thought:"):
+        thought_pattern = r"Thought: (.+?)(?=\s*Action: |$)"
+        thought_hint = "Thought: "
+    elif text.startswith("Reflection:"):
+        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action: |$)"
+        thought_hint = "Reflection: "
+    elif text.startswith("Action_Summary:"):
+        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action: |$)"
+        thought_hint = "Action_Summary: "
+    else:
+        thought_pattern = r"Thought: (.+?)(?=\s*Action: |$)"
+        thought_hint = "Thought: "
+    reflection, thought = None, None
+    thought_match = re.search(thought_pattern, text, re.DOTALL)
+    if thought_match:
+        if len(thought_match.groups()) == 1:
+            thought = thought_match.group(1).strip()
+        elif len(thought_match.groups()) == 2:
+            thought = thought_match.group(2).strip()
+            reflection = thought_match.group(1).strip()
+    assert "Action:" in text
+    action_str = text.split("Action: ")[-1]
+
+    tmp_all_action = action_str.split("')\n\n")
+    all_action = []
+    for action_str in tmp_all_action:
+        if "type(content" in action_str:
+            # Extract the string inside content=... and escape its single quotes
+            def escape_quotes(match):
+                content = match.group(1)  # the value of content
+                return content
+
+            # Replace using the regex
+            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
+            content = re.sub(pattern, escape_quotes, action_str)
+
+            # Escape the string
+            action_str = escape_single_quotes(content)
+            action_str = "type(content='" + action_str + "')"
+        all_action.append(action_str)
+
+    parsed_actions = [parse_action(action.replace("\n", "\\n").lstrip()) for action in all_action]
+    actions = []
+    for action_instance, raw_str in zip(parsed_actions, all_action):
+        if action_instance is None:
+            print(f"Action can't parse: {raw_str}")
+            raise ValueError(f"Action can't parse: {raw_str}")
+        action_type = action_instance["function"]
+        params = action_instance["args"]
+
+        # import pdb; pdb.set_trace()
+        action_inputs = {}
+        for param_name, param in params.items():
+            if param == "": continue
+            param = param.lstrip()  # strip leading whitespace
+            # Handle the start_box / end_box parameter format 'x1 y1 x2 y2'
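+            # Note: for model_type == "qwen25vl" the box values arrive as absolute pixel
+            # coordinates in the smart-resized image and are normalized to [0, 1] by
+            # dividing by the resized width/height; for other model types they are divided
+            # by `factor` (1000 in this agent). A two-value point is duplicated into a
+            # four-value box.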
+            action_inputs[param_name.strip()] = param
+
+            if "start_box" in param_name or "end_box" in param_name:
+                ori_box = param
+                # Remove parentheses and split the string by commas
+                numbers = ori_box.replace("(", "").replace(")", "").split(",")
+
+                # Convert to float and normalize
+                # Qwen2.5-VL outputs absolute coordinates; Qwen2-VL outputs relative coordinates
+                if model_type == "qwen25vl":
+                    float_numbers = []
+                    for num_idx, num in enumerate(numbers):
+                        num = float(num)
+                        if (num_idx + 1) % 2 == 0:
+                            float_numbers.append(float(num / smart_resize_height))
+                        else:
+                            float_numbers.append(float(num / smart_resize_width))
+                else:
+                    float_numbers = [float(num) / factor for num in numbers]
+
+                if len(float_numbers) == 2:
+                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
+                action_inputs[param_name.strip()] = str(float_numbers)
+
+        # import pdb; pdb.set_trace()
+        actions.append({
+            "reflection": reflection,
+            "thought": thought,
+            "action_type": action_type,
+            "action_inputs": action_inputs,
+            "text": text
+        })
+    return actions
+
+def parsing_response_to_pyautogui_code(responses, image_height: int, image_width: int, input_swap: bool = True, platform: str = "Ubuntu") -> str:
+    '''
+    Convert the model output into an OSWorld action, i.e. a pyautogui code string.
+
+    Args:
+        responses: a dict (or list of dicts) with the model output, structured like:
+        {
+            "action_type": "hotkey",
+            "action_inputs": {
+                "hotkey": "v ctrl",
+                "start_box": None,
+                "end_box": None
+            }
+        }
+    Returns:
+        The generated pyautogui code string.
+    '''
+
+    pyautogui_code = f"import pyautogui\nimport time\n"
+    if isinstance(responses, dict):
+        responses = [responses]
+    for response_id, response in enumerate(responses):
+        if "observation" in response:
+            observation = response["observation"]
+        else:
+            observation = ""
+
+        if "thought" in response:
+            thought = response["thought"]
+        else:
+            thought = ""
+
+        if response_id == 0:
+            pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n"
+        else:
+            pyautogui_code += f"\ntime.sleep(1)\n"
+
+        action_dict = response
+        action_type = action_dict.get("action_type")
+        action_inputs = action_dict.get("action_inputs", {})
+
+        if action_type == "hotkey":
+            # Parsing hotkey action
+            if "key" in action_inputs:
+                hotkey = action_inputs.get("key", "")
+            else:
+                hotkey = action_inputs.get("hotkey", "")
+
+            if hotkey == "arrowleft":
+                hotkey = "left"
+
+            elif hotkey == "arrowright":
+                hotkey = "right"
+
+            elif hotkey == "arrowup":
+                hotkey = "up"
+
+            elif hotkey == "arrowdown":
+                hotkey = "down"
+
+            if hotkey:
+                # Handle other hotkeys
+                keys = hotkey.split()  # Split the keys by space
+                convert_keys = []
+                for key in keys:
+                    if key == "space":
+                        key = ' '
+                    convert_keys.append(key)
+                pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})"
+
+        elif action_type in ["press", "keydown"]:
+            # Parsing press action
+            if "key" in action_inputs:
+                key_to_press = action_inputs.get("key", "")
+            else:
+                key_to_press = action_inputs.get("press", "")
+
+            if key_to_press == "arrowleft":
+                key_to_press = "left"
+
+            elif key_to_press == "arrowright":
+                key_to_press = "right"
+
+            elif key_to_press == "arrowup":
+                key_to_press = "up"
+
+            elif key_to_press == "arrowdown":
+                key_to_press = "down"
+
+            elif key_to_press == "space":
+                key_to_press = " "
+
+            if key_to_press:
+                # Simulate pressing a single key
+                pyautogui_code += f"\npyautogui.keyDown({repr(key_to_press)})"
+
+        elif action_type in ["release", "keyup"]:
+            # Parsing release action
+            if "key" in action_inputs:
+                key_to_press = action_inputs.get("key", "")
+            else:
+                key_to_press = action_inputs.get("press", "")
action_inputs.get("press", "") + + if key_to_press == "arrowleft": + key_to_press = "left" + + elif key_to_press == "arrowright": + key_to_press = "right" + + elif key_to_press == "arrowup": + key_to_press = "up" + + elif key_to_press == "arrowdown": + key_to_press = "down" + + elif key_to_press == "space": + key_to_press = " " + + if key_to_press: + # Simulate pressing a single key + pyautogui_code += f"\npyautogui.keyUp({repr(key_to_press)})" + + elif action_type == "type": + # Parsing typing action using clipboard + content = action_inputs.get("content", "") + content = escape_single_quotes(content) + stripped_content = content + if content.endswith("\n") or content.endswith("\\n"): + stripped_content = stripped_content.rstrip("\\n").rstrip("\n") + if content: + if input_swap: + pyautogui_code += f"\nimport pyperclip" + pyautogui_code += f"\npyperclip.copy('{stripped_content}')" + pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + else: + pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + + + elif action_type in ["drag", "select"]: + # Parsing drag or select action based on start and end_boxes + start_box = action_inputs.get("start_box") + end_box = action_inputs.get("end_box") + if start_box and end_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + sx = round(float((x1 + x2) / 2) * image_width, 3) + sy = round(float((y1 + y2) / 2) * image_height, 3) + x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2] + ex = round(float((x1 + x2) / 2) * image_width, 3) + ey = round(float((y1 + y2) / 2) * image_height, 3) + pyautogui_code += ( + f"\npyautogui.moveTo({sx}, {sy})\n" + f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n" + ) + + elif action_type == "scroll": + # Parsing scroll action + start_box = action_inputs.get("start_box") + if start_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + + # # 先点对应区域,再滚动 + # pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')" + else: + x = None + y = None + direction = action_inputs.get("direction", "") + + if x == None: + if "up" in direction.lower(): + if platform.lower() == "ubuntu": + pyautogui_code += f"\npyautogui.scroll(-5)" + elif platform.lower() == "windows": + pyautogui_code += f"\npyautogui.scroll(-50)" + elif "down" in direction.lower(): + if platform.lower() == "ubuntu": + pyautogui_code += f"\npyautogui.scroll(5)" + elif platform.lower() == "windows": + pyautogui_code += f"\npyautogui.scroll(50)" + else: + if "up" in direction.lower(): + if platform.lower() == "ubuntu": + pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" + elif platform.lower() == "windows": + pyautogui_code += f"\npyautogui.scroll(50, x={x}, y={y})" + elif "down" in direction.lower(): + if platform.lower() == "ubuntu": + pyautogui_code += f"\npyautogui.scroll(-5, x={x}, y={y})" + elif platform.lower() == "windows": + pyautogui_code += f"\npyautogui.scroll(-50, x={x}, y={y})" + + elif action_type in ["click", "left_single", "left_double", "right_single", "hover"]: + # Parsing mouse click actions + start_box = action_inputs.get("start_box") + start_box = 
+            if start_box:
+                start_box = eval(start_box)
+                if len(start_box) == 4:
+                    x1, y1, x2, y2 = start_box  # Assuming box is in [x1, y1, x2, y2]
+                elif len(start_box) == 2:
+                    x1, y1 = start_box
+                    x2 = x1
+                    y2 = y1
+                x = round(float((x1 + x2) / 2) * image_width, 3)
+                y = round(float((y1 + y2) / 2) * image_height, 3)
+                if action_type == "left_single" or action_type == "click":
+                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')"
+                elif action_type == "left_double":
+                    pyautogui_code += f"\npyautogui.doubleClick({x}, {y}, button='left')"
+                elif action_type == "right_single":
+                    pyautogui_code += f"\npyautogui.click({x}, {y}, button='right')"
+                elif action_type == "hover":
+                    pyautogui_code += f"\npyautogui.moveTo({x}, {y})"
+
+        elif action_type in ["finished"]:
+            pyautogui_code = f"DONE"
+
+        else:
+            pyautogui_code += f"\n# Unrecognized action type: {action_type}"
+
+    return pyautogui_code
+
+def add_box_token(input_string):
+    # Step 1: Split the string into individual actions
+    if "Action: " in input_string and "start_box=" in input_string:
+        suffix = input_string.split("Action: ")[0] + "Action: "
+        actions = input_string.split("Action: ")[1:]
+        processed_actions = []
+        for action in actions:
+            action = action.strip()
+            # Step 2: Extract coordinates (start_box or end_box) using regex
+            coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
+
+            updated_action = action  # Start with the original action
+            for coord_type, x, y in coordinates:
+                # Wrap each coordinate pair in box tokens
+                updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
+            processed_actions.append(updated_action)
+
+        # Step 3: Reconstruct the final string
+        final_string = suffix + "\n\n".join(processed_actions)
+    else:
+        final_string = input_string
+    return final_string
+
+COMPUTER_USE_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
+
+## Output Format
+You should first think about the reasoning process in your mind and then provide the user with the answer.
+The reasoning process is enclosed within <think> </think> tags.
+After the </think> tag, you should place the final answer, which concludes your summarized thought and your action.
+
+For example,
+```
+<think>detailed reasoning content here</think>
+Thought: a small plan and finally summarize your next action (with its target element) in one sentence
+Action: ...
+```
+
+## Action Space
+
+click(point='<point>x1 y1</point>')
+left_double(point='<point>x1 y1</point>')
+right_single(point='<point>x1 y1</point>')
+drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
+hotkey(key='ctrl c') # Split keys with a space and use lowercase. Also, do not use more than 3 keys in one hotkey action.
+type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
+scroll(point='<point>x1 y1</point>', direction='down or up or right or left') # Show more information on the `direction` side.
+wait() # Sleep for 5s and take a screenshot to check for any changes.
+finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
+
+## Output Example
+<think>Now that...</think>
+Thought: Let's click ...
+Action: click(point='<point>100 200</point>')
+
+## Note
+- Use {language} in `Thought` part.
+- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
+- If you have executed several same actions (like repeatedly clicking the same point) but the screen keeps no change, please try to execute a modified action when necessary.
+
+## User Instruction
+{instruction}
+"""
+
+MOBILE_USE_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
+## Output Format
+```
+Thought: ...
+Action: ...
+```
+## Action Space
+
+click(point='<point>x1 y1</point>')
+long_press(point='<point>x1 y1</point>')
+type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
+scroll(point='<point>x1 y1</point>', direction='down or up or right or left')
+open_app(app_name=\'\')
+drag(start_point='<point>x1 y1</point>', end_point='<point>x2 y2</point>')
+press_home()
+press_back()
+finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
+
+
+## Note
+- Use {language} in `Thought` part.
+- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
+
+## User Instruction
+{instruction}
+"""
+
+GROUNDING_DOUBAO = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. \n\n## Output Format\n\nAction: ...\n\n\n## Action Space\nclick(point='<point>x1 y1</point>')\n\n## User Instruction
+{instruction}"""
+
+
+
+class UITarsAgent:
+    """
+    UI-TARS agent based on the Seed1.5-VL model.
+    Integrates the UI-TARS-1.5 implementation from the GUI folder with the mm_agents architecture.
+    """
+
+    def __init__(
+        self,
+        # Model settings
+        model: str,
+
+        # Generation settings
+        max_tokens: int,
+        top_p: Optional[float],
+        temperature: float,
+
+        # History settings
+        max_trajectory_length: Optional[int],
+        max_image_history_length: Optional[int],  # UI-TARS uses a history-5 logic
+
+        # Prompt settings
+        screenshot_pyautogui_prompt: str = "uitars_v1",
+
+        # Parse settings
+        which_parsed_actions: str = "all",
+
+        # External info
+        max_steps: int = 100,
+
+        # UI-TARS specific settings
+        use_thinking: bool = True,
+        language: str = "Chinese",
+    ):
+        """
+        Initialize the UI-TARS agent.
+
+        Args:
+            model: Model name, defaults to doubao-1-5-thinking-vision-pro-250428
+            max_tokens: Maximum tokens to generate
+            top_p: Top-p sampling parameter
+            temperature: Temperature for sampling
+            max_trajectory_length: Maximum trajectory history length
+            max_image_history_length: Maximum image history length (UI-TARS uses 5)
+            screenshot_pyautogui_prompt: Prompt version
+            which_parsed_actions: Which actions to parse
+            max_steps: Maximum steps for the agent
+            use_thinking: Whether to use thinking mode
+            language: Language for responses
+        """
+
+        self.model = model
+        self.max_trajectory_length = max_trajectory_length
+        self.logger = logger
+        self.language = language
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+        self.history_images = []
+        self.history_responses = []
+
+        self.system_prompt = COMPUTER_USE_DOUBAO
+
+        self.action_parse_res_factor = 1000
+        self.model_type = "doubao"
+        self.history_n = 5
+        self.top_p = top_p
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.platform = "ubuntu"
+
+    def reset(self, _logger=None):
+        global logger
+        logger = _logger if _logger is not None else logging.getLogger("desktopenv.agent")
+
+        self.thoughts = []
+        self.actions = []
+        self.observations = []
+        self.history_images = []
+        self.history_responses = []
+
+    def pretty_print_messages(self, messages):
+        """Pretty print messages while hiding base64 encoded images."""
+        def format_message(msg):
+            if not isinstance(msg, dict):
+                return str(msg)
+
+            formatted = {}
+            for key, value in msg.items():
+                if key == "content":
+                    if isinstance(value, list):
+                        formatted_content = []
+                        for item in value:
+                            if isinstance(item, dict) and "type" in item:
+                                if item["type"] == "image_url" and "image_url" in item:
+                                    # Replace base64 image with a placeholder
+                                    formatted_content.append({
+                                        "type": "image_url",
+                                        "image_url": {"url": "[BASE64_IMAGE_DATA]"}
+                                    })
+                                else:
+                                    formatted_content.append(item)
+                            else:
+                                formatted_content.append(item)
+                        formatted[key] = formatted_content
+                    else:
+                        formatted[key] = value
+                else:
+                    formatted[key] = value
+            return formatted
+
+        if isinstance(messages, list):
+            return [format_message(msg) for msg in messages]
+        return format_message(messages)
+
+
+    def inference_with_thinking(self, messages):
+        api_key = os.environ['DOUBAO_API_KEY']
+        api_url = os.environ['DOUBAO_API_URL']
+        headers = {
+            'Authorization': f'Bearer {api_key}',
+            'Content-Type': 'application/json'
+        }
+        data = {
+            "model": self.model,
+            "messages": messages,
+            "thinking": {"type": "enabled"},
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "temperature": self.temperature,
+        }
+
+        response = requests.post(api_url, headers=headers, json=data)
+
+        if response.status_code == 200:
+            print(response.json()["choices"][0])
+            return response.json()["choices"][0]["message"]["content"]
+        else:
+            return {
+                "error": f"Request failed with status code {response.status_code}",
+                "details": response.text
+            }
+
+    def predict(self, task_instruction: str, obs: dict) -> Tuple[str, List]:
+        """Predict the next action based on the current observation."""
+
+        self.task_instruction = task_instruction
+
+        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
+            self.thoughts
+        ), "The numbers of observations, actions and thoughts should be the same."
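+
+        # The message list built below starts with the system prompt plus the task
+        # instruction; every previous model response is then replayed as an assistant
+        # message, but only the most recent `history_n` screenshots (including the
+        # current one) are attached as user image messages.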
+
+        # Convert binary screenshot to base64 if needed
+        screenshot = obs["screenshot"]
+        if isinstance(screenshot, bytes):
+            screenshot = base64.b64encode(screenshot).decode('utf-8')
+
+        self.history_images.append(screenshot)
+
+        self.observations.append(
+            {"screenshot": screenshot, "accessibility_tree": None}
+        )
+
+        if len(self.history_images) > self.history_n:
+            self.history_images = self.history_images[-self.history_n:]
+
+        images = self.history_images
+
+        messages = [
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": self.system_prompt.format(
+                    instruction=task_instruction,
+                    language=self.language
+                )}]
+            }
+        ]
+
+        image_num = 0
+        if len(self.history_responses) > 0:
+            for history_idx, history_response in enumerate(self.history_responses):
+                # send at most history_n images to the model
+                if history_idx + self.history_n > len(self.history_responses):
+                    messages.append({
+                        "role": "user",
+                        "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}]
+                    })
+                    image_num += 1
+
+                messages.append({
+                    "role": "assistant",
+                    "content": history_response
+                })
+            messages.append({
+                "role": "user",
+                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}]
+            })
+            image_num += 1
+        else:
+            messages.append({
+                "role": "user",
+                "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{images[image_num]}"}}]
+            })
+            image_num += 1
+
+        try_times = 3
+        origin_resized_height = 1080
+        origin_resized_width = 1920
+        prediction = None
+        while True:
+            if try_times <= 0:
+                self.logger.error("Reached the max number of retries when fetching a response from the client; returning FAIL.")
+                return prediction, ["FAIL"]
+            try:
+                logger.info(f"Messages: {self.pretty_print_messages(messages[-1])}")
+                prediction = self.inference_with_thinking(messages)
+
+            except Exception as e:
+                self.logger.error(f"Error when fetching response from client, with error:\n{e}")
+                prediction = None
+                try_times -= 1
+                continue
+
+            try:
+                parsed_dict = parse_action_to_structure_output(prediction, self.action_parse_res_factor, origin_resized_height, origin_resized_width, self.model_type)
+                parsed_pyautogui_code = parsing_response_to_pyautogui_code(parsed_dict, origin_resized_height, origin_resized_width, platform=self.platform)
+                break
+            except Exception as e:
+                self.logger.error(f"Error when parsing response from client, with error:\n{e}")
+                prediction = None
+                try_times -= 1
+
+        self.history_responses.append(prediction)
+
+        try:
+            parsed_dict = parse_action_to_structure_output(prediction, self.action_parse_res_factor, origin_resized_height, origin_resized_width, self.model_type)
+            parsed_pyautogui_code = parsing_response_to_pyautogui_code(parsed_dict, origin_resized_height, origin_resized_width, platform=self.platform)
+
+        except Exception as e:
+            self.logger.error(f"Parsing action error: {prediction}, with error:\n{e}")
+            return prediction, ["FAIL"]
+
+        thoughts = ""
+        for parsed_response in parsed_dict:
+            if "thought" in parsed_response and parsed_response["thought"]:
+                thoughts += parsed_response["thought"]
+        if thoughts:
+            self.thoughts.append(thoughts)
+        for parsed_response in parsed_dict:
+            if "action_type" in parsed_response:
+                if parsed_response["action_type"] == FINISH_WORD:
+                    self.actions.append(["DONE"])
+                    return prediction, ["DONE"]
+
+                elif parsed_response["action_type"] == WAIT_WORD:
+                    self.actions.append(["WAIT"])
+                    return prediction, ["WAIT"]
+
+                elif parsed_response["action_type"] == ENV_FAIL_WORD:
+                    self.actions.append(["FAIL"])
+                    return prediction, ["FAIL"]
+
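+        # Default case: the parsed action is regular GUI automation, so hand the generated
+        # pyautogui code string back to the environment for execution.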
return prediction, ["FAIL"] + + + self.actions.append([parsed_pyautogui_code]) + + + return prediction, [parsed_pyautogui_code] + \ No newline at end of file diff --git a/run_multienv_uitars15.py b/run_multienv_uitars15.py new file mode 100644 index 0000000..0884dc2 --- /dev/null +++ b/run_multienv_uitars15.py @@ -0,0 +1,531 @@ +from __future__ import annotations +import argparse +import datetime +import json +import logging +import os +import sys +import signal +import time +from typing import List, Dict +import math +from tqdm import tqdm +from multiprocessing import Process, Manager +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.uitars15_agent import UITarsAgent + +# Global variables for signal handling +active_environments = [] +processes = [] +is_terminating = False + +# load the environment variables from .env file +if os.path.exists(".env"): + from dotenv import load_dotenv + load_dotenv() + +# Logger Configs {{{ # +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--sleep_after_execution", type=float, default=0) + parser.add_argument("--max_steps", type=int, default=15) + + # evaluation config + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config + parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428") + parser.add_argument("--temperature", type=float, default=0) + parser.add_argument("--top_p", type=float, default=None) + parser.add_argument("--max_tokens", type=int, default=3000) + + # OpenCUAagent config + parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.") + parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.") + parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.") + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) 
+ args = parser.parse_args() + return args + +args = config() # Get command line arguments first + +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: + """Distribute tasks evenly across environments.""" + # Flatten the tasks into a single list + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + + # Calculate tasks per environment + tasks_per_env = math.ceil(len(all_tasks) / num_envs) + + # Distribute tasks + distributed_tasks = [] + for i in range(num_envs): + env_tasks = {} + start_idx = i * tasks_per_env + end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) + + for domain, example_id in all_tasks[start_idx:end_idx]: + if domain not in env_tasks: + env_tasks[domain] = [] + env_tasks[domain].append(example_id) + + distributed_tasks.append(env_tasks) + + return distributed_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. 
Exiting.") + sys.exit(0) + + +def run_env_tasks(env_idx: int, env_tasks: dict, args: argparse.Namespace, shared_scores: list): + """Run tasks for a single environment.""" + # Each process has its own list of active environments + active_environments = [] + env = None + + # Setup signal handlers for this process too + signal.signal(signal.SIGINT, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) + signal.signal(signal.SIGTERM, lambda signum, frame: process_signal_handler(signum, frame, env_idx)) + + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) + agent = UITarsAgent( + model=args.model, + max_tokens=args.max_tokens, + top_p=args.top_p, + temperature=args.temperature, + + max_trajectory_length=args.max_trajectory_length, + max_image_history_length=args.max_image_history_length, + use_thinking=True, + language=args.language, + ) + + logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") + + try: + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + finally: + # This ensures the environment is closed even if there's an exception + logger.info(f"Process {env_idx + 1} cleaning up environment...") + try: + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error during environment cleanup: {e}") + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes + + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + 
logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal + os.kill(p.pid, signal.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + + distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) + + logger.info("All environments are ready. Starting parallel task execution...") + + # Create a shared list for scores across processes + with Manager() as manager: + shared_scores = manager.list() + + # Create and start processes for each environment + processes = [] + for env_idx, env_tasks in enumerate(distributed_tasks): + p = Process( + target=run_env_tasks, + args=(env_idx, env_tasks, args, shared_scores) + ) + processes.append(p) + p.start() + logger.info(f"Started process {p.name} with PID {p.pid}") + + try: + # Wait for all processes to complete + for p in processes: + p.join() + logger.info(f"Process {p.name} completed") + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. 
Initiating graceful shutdown...") + # Let the signal handler do the cleanup + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + # Ensure cleanup happens + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise + + # Convert shared list to regular list + scores = list(shared_scores) + + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + args.model, + args.observation_type, + args.result_dir, 
+ test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}")