diff --git a/desktop_env/providers/aws/provider.py b/desktop_env/providers/aws/provider.py
index d2c034e..002a9b9 100644
--- a/desktop_env/providers/aws/provider.py
+++ b/desktop_env/providers/aws/provider.py
@@ -77,7 +77,8 @@ class AWSProvider(Provider):
             else:
                 logger.warning("No public IP address available for VNC access")
-            return private_ip_address
+            #return private_ip_address
+            return public_ip_address
         return ''  # Return an empty string if no IP address is found
     except ClientError as e:
         logger.error(f"Failed to retrieve IP address for the instance {path_to_vm}: {str(e)}")
diff --git a/lib_run_single.py b/lib_run_single.py
index 91a0163..2c21ad0 100644
--- a/lib_run_single.py
+++ b/lib_run_single.py
@@ -44,6 +44,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
                 "step_num": step_idx + 1,
                 "action_timestamp": action_timestamp,
                 "action": action,
+                "response": response,
                 "reward": reward,
                 "done": done,
                 "info": info,
diff --git a/mm_agents/uitars15_v1.py b/mm_agents/uitars15_v1.py
new file mode 100644
index 0000000..c27d95a
--- /dev/null
+++ b/mm_agents/uitars15_v1.py
@@ -0,0 +1,956 @@
+import ast
+import base64
+from openai import OpenAI
+import math
+import re
+import xml.etree.ElementTree as ET
+from io import BytesIO
+from typing import Dict, List
+import numpy as np
+from loguru import logger
+import os
+from PIL import Image
+from mm_agents.accessibility_tree_wrap.heuristic_retrieve import (
+    filter_nodes,
+)
+
+UITARS_ACTION_SPACE = """
+click(start_box='<|box_start|>(x1,y1)<|box_end|>')
+left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
+right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
+drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
+hotkey(key='')
+type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
+scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
+wait() #Sleep for 5s and take a screenshot to check for any changes.
+finished()
+"""
+
+UITARS_CALL_USR_ACTION_SPACE = """
+click(start_box='<|box_start|>(x1,y1)<|box_end|>')
+left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
+right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
+drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
+hotkey(key='')
+type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
+scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
+wait() #Sleep for 5s and take a screenshot to check for any changes.
+finished()
+call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
+"""
+
+UITARS_NORMAL_ACTION_SPACE = """
+click(start_box='<|box_start|>(x1,y1)<|box_end|>')
+left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
+right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
+drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
+hotkey(key='')
+type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
+scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
+wait() #Sleep for 5s and take a screenshot to check for any changes.
+finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
+"""
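# --- Editor's sketch (not part of the patch): what one of these action strings looks
# like once the model emits it, and how it decomposes as a Python call expression.
# The coordinates are hypothetical; parse_action() below does the same walk over the AST.
import ast

call = ast.parse("click(start_box='(481,529)')", mode="eval").body
assert isinstance(call, ast.Call)
print(call.func.id)                                      # click
print({kw.arg: kw.value.value for kw in call.keywords})  # {'start_box': '(481,529)'}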
+""" + +UITARS_USR_PROMPT_NOTHOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. +## Output Format +``` +Action: ... +``` +## Action Space +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished() +call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help. +## User Instruction +{instruction} +""" + +UITARS_USR_PROMPT_THOUGHT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space +{action_space} + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. + +## User Instruction +{instruction} +""" + +FINISH_WORD = "finished" +WAIT_WORD = "wait" +ENV_FAIL_WORD = "error_env" +CALL_USER = "call_user" + +IMAGE_FACTOR = 28 +MIN_PIXELS = 100 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +# 定义一个函数来解析每个 action +def parse_action(action_str): + try: + # 解析字符串为 AST 节点 + node = ast.parse(action_str, mode='eval') + + # 确保节点是一个表达式 + if not isinstance(node, ast.Expression): + raise ValueError("Not an expression") + + # 获取表达式的主体 + call = node.body + + # 确保主体是一个函数调用 + if not isinstance(call, ast.Call): + raise ValueError("Not a function call") + + # 获取函数名 + if isinstance(call.func, ast.Name): + func_name = call.func.id + elif isinstance(call.func, ast.Attribute): + func_name = call.func.attr + else: + func_name = None + + # 获取关键字参数 + kwargs = {} + for kw in call.keywords: + key = kw.arg + # 处理不同类型的值,这里假设都是常量 + if isinstance(kw.value, ast.Constant): + value = kw.value.value + elif isinstance(kw.value, ast.Str): # 兼容旧版本 Python + value = kw.value.s + else: + value = None + kwargs[key] = value + + return { + 'function': func_name, + 'args': kwargs + } + + except Exception as e: + print(f"Failed to parse action '{action_str}': {e}") + return None + +def escape_single_quotes(text): + # 匹配未转义的单引号(不匹配 \\') + pattern = r"(? 
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+
+def linear_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    if width * height > max_pixels:
+        """
+        If the image is above/below the pixel limits, compute a scaling factor resize_factor so that
+        the pixel count shrinks to at most max_pixels. The factor is taken as a square root so the
+        aspect ratio is preserved, which lets the original relative coordinates be reused without conversion.
+        """
+        resize_factor = math.sqrt(max_pixels / (width * height))
+        width, height = int(width * resize_factor), int(height * resize_factor)
+    if width * height < min_pixels:
+        resize_factor = math.sqrt(min_pixels / (width * height))
+        width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
+
+    return height, width
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
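# --- Editor's sketch (not part of the patch): a worked smart_resize example for a
# 1920x1080 screenshot with the defaults above. 1920*1080 = 2,073,600 pixels already
# lies inside [MIN_PIXELS, MAX_PIXELS], so both sides are only snapped to multiples
# of factor=28: ---
h_bar = max(28, round(1080 / 28) * 28)  # 39 * 28 = 1092
w_bar = max(28, round(1920 / 28) * 28)  # 69 * 28 = 1932
assert (h_bar, w_bar) == (1092, 1932)   # i.e. smart_resize(1080, 1920) == (1092, 1932)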
+
+def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
+    text = text.strip()
+    if model_type == "qwen25vl":
+        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
+
+    # Regular expressions for matching the Action string
+    if text.startswith("Thought:"):
+        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
+        thought_hint = "Thought: "
+    elif text.startswith("Reflection:"):
+        thought_pattern = r"Reflection: (.+?)Action_Summary: (.+?)(?=\s*Action:|$)"
+        thought_hint = "Reflection: "
+    elif text.startswith("Action_Summary:"):
+        thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
+        thought_hint = "Action_Summary: "
+    else:
+        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
+        thought_hint = "Thought: "
+    reflection, thought = None, None
+    thought_match = re.search(thought_pattern, text, re.DOTALL)
+    if thought_match:
+        if len(thought_match.groups()) == 1:
+            thought = thought_match.group(1).strip()
+        elif len(thought_match.groups()) == 2:
+            thought = thought_match.group(2).strip()
+            reflection = thought_match.group(1).strip()
+    assert "Action:" in text
+    action_str = text.split("Action:")[-1]
+
+    tmp_all_action = action_str.split("\n\n")
+    all_action = []
+    for action_str in tmp_all_action:
+        if "type(content" in action_str:
+            # Match the string inside content with a regex and escape its single quotes
+            def escape_quotes(match):
+                content = match.group(1)  # get the value of content
+                return content
+
+            # Perform the replacement with the regex
+            pattern = r"type\(content='(.*?)'\)"  # matches type(content='...')
+            content = re.sub(pattern, escape_quotes, action_str)
+
+            # Process the string
+            action_str = escape_single_quotes(content)
+            action_str = "type(content='" + action_str + "')"
+        all_action.append(action_str)
+
+    parsed_actions = [parse_action(action.replace("\n","\\n").lstrip()) for action in all_action]
+    actions = []
+    for action_instance, raw_str in zip(parsed_actions, all_action):
+        if action_instance == None:
+            print(f"Action can't parse: {raw_str}")
+            raise ValueError(f"Action can't parse: {raw_str}")
+        action_type = action_instance["function"]
+        params = action_instance["args"]
+
+        # import pdb; pdb.set_trace()
+        action_inputs = {}
+        for param_name, param in params.items():
+            if param == "": continue
+            param = param.lstrip()  # remove stray quotes and extra whitespace
+            # Handle the start_box / end_box parameter format 'x1 y1 x2 y2'
+            action_inputs[param_name.strip()] = param
+
+            if "start_box" in param_name or "end_box" in param_name:
+                ori_box = param
+                # Remove parentheses and split the string by commas
+                numbers = ori_box.replace("(", "").replace(")", "").split(",")
+
+                # Convert to float and scale by 1000
+                # Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates
+                if model_type == "qwen25vl":
+                    float_numbers = []
+                    for num_idx, num in enumerate(numbers):
+                        num = float(num)
+                        if (num_idx + 1) % 2 == 0:
+                            float_numbers.append(float(num/smart_resize_height))
+                        else:
+                            float_numbers.append(float(num/smart_resize_width))
+                else:
+                    float_numbers = [float(num) / factor for num in numbers]
+
+                if len(float_numbers) == 2:
+                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
+                action_inputs[param_name.strip()] = str(float_numbers)
+
+        # import pdb; pdb.set_trace()
+        actions.append({
+            "reflection": reflection,
+            "thought": thought,
+            "action_type": action_type,
+            "action_inputs": action_inputs,
+            "text": text
+        })
+    return actions
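# --- Editor's sketch (not part of the patch): for a qwen2vl-style response such as
#   "Thought: I should open the browser.\nAction: click(start_box='(481,529)')"
# with factor=1000, parse_action_to_structure_output is expected to yield roughly: ---
expected = [{
    "reflection": None,
    "thought": "I should open the browser.",
    "action_type": "click",
    "action_inputs": {"start_box": "[0.481, 0.529, 0.481, 0.529]"},
    "text": "Thought: I should open the browser.\nAction: click(start_box='(481,529)')",
}]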
hotkey = "right" + + elif hotkey == "arrowup": + hotkey = "up" + + elif hotkey == "arrowdown": + hotkey = "down" + + if hotkey: + # Handle other hotkeys + keys = hotkey.split() # Split the keys by space + convert_keys = [] + for key in keys: + if key == "space": + key = ' ' + convert_keys.append(key) + pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})" + + elif action_type == "press": + # Parsing press action + if "key" in action_inputs: + key_to_press = action_inputs.get("key", "") + else: + key_to_press = action_inputs.get("press", "") + + if hotkey == "arrowleft": + hotkey = "left" + + elif hotkey == "arrowright": + hotkey = "right" + + elif hotkey == "arrowup": + hotkey = "up" + + elif hotkey == "arrowdown": + hotkey = "down" + + elif hotkey == "space": + hotkey = " " + + if key_to_press: + # Simulate pressing a single key + pyautogui_code += f"\npyautogui.press({repr(key_to_press)})" + + elif action_type == "keyup": + key_to_up = action_inputs.get("key", "") + pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})" + + elif action_type == "keydown": + key_to_down = action_inputs.get("key", "") + pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})" + + elif action_type == "type": + # Parsing typing action using clipboard + content = action_inputs.get("content", "") + content = escape_single_quotes(content) + stripped_content = content + if content.endswith("\n") or content.endswith("\\n"): + stripped_content = stripped_content.rstrip("\\n").rstrip("\n") + if content: + if input_swap: + pyautogui_code += f"\nimport pyperclip" + pyautogui_code += f"\npyperclip.copy('{stripped_content}')" + pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + else: + pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)" + pyautogui_code += f"\ntime.sleep(0.5)\n" + if content.endswith("\n") or content.endswith("\\n"): + pyautogui_code += f"\npyautogui.press('enter')" + + + elif action_type in ["drag", "select"]: + # Parsing drag or select action based on start and end_boxes + start_box = action_inputs.get("start_box") + end_box = action_inputs.get("end_box") + if start_box and end_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + sx = round(float((x1 + x2) / 2) * image_width, 3) + sy = round(float((y1 + y2) / 2) * image_height, 3) + x1, y1, x2, y2 = eval(end_box) # Assuming box is in [x1, y1, x2, y2] + ex = round(float((x1 + x2) / 2) * image_width, 3) + ey = round(float((y1 + y2) / 2) * image_height, 3) + pyautogui_code += ( + f"\npyautogui.moveTo({sx}, {sy})\n" + f"\npyautogui.dragTo({ex}, {ey}, duration=1.0)\n" + ) + + elif action_type == "scroll": + # Parsing scroll action + start_box = action_inputs.get("start_box") + if start_box: + x1, y1, x2, y2 = eval(start_box) # Assuming box is in [x1, y1, x2, y2] + x = round(float((x1 + x2) / 2) * image_width, 3) + y = round(float((y1 + y2) / 2) * image_height, 3) + + # # 先点对应区域,再滚动 + # pyautogui_code += f"\npyautogui.click({x}, {y}, button='left')" + else: + x = None + y = None + direction = action_inputs.get("direction", "") + + if x == None: + if "up" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(5)" + elif "down" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(-5)" + else: + if "up" in direction.lower(): + pyautogui_code += f"\npyautogui.scroll(5, x={x}, y={y})" + elif "down" in 
+
+def add_box_token(input_string):
+    # Step 1: Split the string into individual actions
+    if "Action: " in input_string and "start_box=" in input_string:
+        suffix = input_string.split("Action: ")[0] + "Action: "
+        actions = input_string.split("Action: ")[1:]
+        processed_actions = []
+        for action in actions:
+            action = action.strip()
+            # Step 2: Extract coordinates (start_box or end_box) using regex
+            coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
+
+            updated_action = action  # Start with the original action
+            for coord_type, x, y in coordinates:
+                # Convert x and y to integers
+                updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
+            processed_actions.append(updated_action)
+
+        # Step 3: Reconstruct the final string
+        final_string = suffix + "\n\n".join(processed_actions)
+    else:
+        final_string = input_string
+    return final_string
+
+def pil_to_base64(image):
+    buffer = BytesIO()
+    image.save(buffer, format="PNG")  # you can switch this to "JPEG" or another format
+    return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
+
+    if platform == "ubuntu":
+        _attributes_ns = attributes_ns_ubuntu
+        _state_ns = state_ns_ubuntu
+        _component_ns = component_ns_ubuntu
+        _value_ns = value_ns_ubuntu
+    elif platform == "windows":
+        _attributes_ns = attributes_ns_windows
+        _state_ns = state_ns_windows
+        _component_ns = component_ns_windows
+        _value_ns = value_ns_windows
+    else:
+        raise ValueError("Invalid platform, must be 'ubuntu' or 'windows'")
+
+    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
+    linearized_accessibility_tree = [
+        "tag\tname\ttext\tclass\tdescription\tposition (top-left x&y)\tsize (w&h)"
+    ]
+
+    # Linearize the accessibility tree nodes into a table format
+    for node in filtered_nodes:
+        if node.text:
+            text = (
+                node.text
+                if '"' not in node.text
+                else '"{:}"'.format(node.text.replace('"', '""'))
+            )
+
+        elif node.get("{{{:}}}class".format(class_ns_windows), "").endswith(
+            "EditWrapper"
+        ) and node.get("{{{:}}}value".format(_value_ns)):
+            node_text = node.get("{{{:}}}value".format(_value_ns), "")
+            text = (
+                node_text
+                if '"' not in node_text
+                else '"{:}"'.format(node_text.replace('"', '""'))
+            )
+        else:
+            text = '""'
+
+        
linearized_accessibility_tree.append( + "{:}\t{:}\t{:}\t{:}\t{:}\t{:}\t{:}".format( + node.tag, + node.get("name", ""), + text, + ( + node.get("{{{:}}}class".format(_attributes_ns), "") + if platform == "ubuntu" + else node.get("{{{:}}}class".format(class_ns_windows), "") + ), + node.get("{{{:}}}description".format(_attributes_ns), ""), + node.get("{{{:}}}screencoord".format(_component_ns), ""), + node.get("{{{:}}}size".format(_component_ns), ""), + ) + ) + + return "\n".join(linearized_accessibility_tree) + +def trim_accessibility_tree(linearized_accessibility_tree, max_tokens): + # enc = tiktoken.encoding_for_model("gpt-4") + # tokens = enc.encode(linearized_accessibility_tree) + # if len(tokens) > max_tokens: + # linearized_accessibility_tree = enc.decode(tokens[:max_tokens]) + # linearized_accessibility_tree += "[...]\n" + return linearized_accessibility_tree + + +class UITARSAgent: + def __init__( + self, + model: str, + runtime_conf: Dict, + platform="ubuntu", + action_space="pyautogui", + observation_type="screenshot", + # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"] + max_trajectory_length=50, + a11y_tree_max_tokens=10000, + model_type="qwen25vl", + **kwargs + ): + self.model = model + self.platform = platform + self.action_space = action_space + self.observation_type = observation_type + self.max_trajectory_length = max_trajectory_length + self.a11y_tree_max_tokens = a11y_tree_max_tokens + self.model_type = model_type + self.runtime_conf = runtime_conf + self.temperature = self.runtime_conf["temperature"] + self.top_k = self.runtime_conf["top_k"] + self.top_p = self.runtime_conf["top_p"] + self.max_tokens = self.runtime_conf["max_tokens"] + self.infer_mode = self.runtime_conf["infer_mode"] + self.prompt_style = self.runtime_conf["prompt_style"] + self.input_swap = self.runtime_conf["input_swap"] + self.language = self.runtime_conf["language"] + self.max_pixels = self.runtime_conf["max_pixels"] + self.min_pixels = self.runtime_conf["min_pixels"] + self.callusr_tolerance = self.runtime_conf["callusr_tolerance"] + self.vlm = OpenAI( + base_url=os.environ['DOUBAO_API_URL'], + api_key=os.environ['DOUBAO_API_KEY'], + ) + + self.thoughts = [] + self.actions = [] + self.observations = [] + self.history_images = [] + self.history_responses = [] + + self.prompt_action_space = UITARS_ACTION_SPACE + self.action_parse_res_factor = 1000 + if self.infer_mode == "qwen2vl_user": + self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE + elif self.infer_mode == "qwen25vl_normal": + self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE + + self.prompt_template = UITARS_USR_PROMPT_THOUGHT + + if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal": + self.prompt_template = UITARS_USR_PROMPT_THOUGHT + + elif self.prompt_style == "qwen2vl_no_thought": + self.prompt_template = UITARS_USR_PROMPT_NOTHOUGHT + + + if "history_n" in self.runtime_conf: + self.history_n = self.runtime_conf["history_n"] + else: + self.history_n = 5 + + self.cur_callusr_count = 0 + + def reset(self, runtime_logger=None): + self.thoughts = [] + self.actions = [] + self.observations = [] + self.history_images = [] + self.history_responses = [] + + + def predict( + self, instruction: str, obs: Dict, last_action_after_obs: Dict = None + ) -> List: + """ + Predict the next action(s) based on the current observation. 
+        """
+
+        # Append trajectory
+        # print(len(self.observations), len(self.actions), len(self.actions))
+        assert len(self.observations) == len(self.actions) and len(self.actions) == len(
+            self.thoughts
+        ), "The number of observations and actions should be the same."
+
+        if len(self.observations) > self.max_trajectory_length:
+            if self.max_trajectory_length == 0:
+                _observations = []
+                _actions = []
+                _thoughts = []
+            else:
+                _observations = self.observations[-self.max_trajectory_length :]
+                _actions = self.actions[-self.max_trajectory_length :]
+                _thoughts = self.thoughts[-self.max_trajectory_length :]
+        else:
+            _observations = self.observations
+            _actions = self.actions
+            _thoughts = self.thoughts
+
+
+        self.history_images.append(obs["screenshot"])
+
+        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
+            base64_image = obs["screenshot"]
+            try:
+                linearized_accessibility_tree = (
+                    linearize_accessibility_tree(
+                        accessibility_tree=obs["accessibility_tree"],
+                        platform=self.platform,
+                    )
+                    if self.observation_type == "screenshot_a11y_tree"
+                    else None
+                )
+            except:
+                linearized_accessibility_tree = None
+            # logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
+
+            if linearized_accessibility_tree:
+                linearized_accessibility_tree = trim_accessibility_tree(
+                    linearized_accessibility_tree, self.a11y_tree_max_tokens
+                )
+
+            if self.observation_type == "screenshot_a11y_tree":
+                self.observations.append(
+                    {
+                        "screenshot": base64_image,
+                        "accessibility_tree": linearized_accessibility_tree,
+                    }
+                )
+            else:
+                self.observations.append(
+                    {"screenshot": base64_image, "accessibility_tree": None}
+                )
+
+        else:
+            raise ValueError(
+                "Invalid observation_type type: " + self.observation_type
+            )  # 1}}}
+
+        if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal":
+            user_prompt = self.prompt_template.format(
+                instruction=instruction,
+                action_space=self.prompt_action_space,
+                language=self.language
+            )
+        elif self.infer_mode == "qwen2vl_no_thought":
+            user_prompt = self.prompt_template.format(
+                instruction=instruction
+            )
+
+        if len(self.history_images) > self.history_n:
+            self.history_images = self.history_images[-self.history_n:]
+
+        messages, images = [], []
+        if isinstance(self.history_images, bytes):
+            self.history_images = [self.history_images]
+        elif isinstance(self.history_images, np.ndarray):
+            self.history_images = list(self.history_images)
+        elif isinstance(self.history_images, list):
+            pass
+        else:
+            raise TypeError(f"Unidentified images type: {type(self.history_images)}")
+
+        for turn, image in enumerate(self.history_images):
+            if len(images) >= self.history_n:
+                break
+            try:
+                image = Image.open(BytesIO(image))
+            except Exception as e:
+                raise RuntimeError(f"Error opening image: {e}")
+
+            if image.width * image.height > self.max_pixels:
+                """
+                If the image is above/below the pixel limits, compute a scaling factor resize_factor so that
+                the pixel count shrinks to at most max_pixels. The square-root scaling keeps the aspect
+                ratio unchanged, so the original relative coordinates can be reused without conversion.
+                """
+                resize_factor = math.sqrt(self.max_pixels / (image.width * image.height))
+                width, height = int(image.width * resize_factor), int(image.height * resize_factor)
+                image = image.resize((width, height))
+            if image.width * image.height < self.min_pixels:
+                resize_factor = math.sqrt(self.min_pixels / (image.width * image.height))
+                width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor)
+                image = image.resize((width, height))
+
+            if image.mode != "RGB":
+                image = image.convert("RGB")
+
+            images.append(image)
+
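# --- Editor's sketch (not part of the patch): the conversation built below for the
# OpenAI-compatible endpoint interleaves each prior model turn with its screenshot;
# the placeholder texts here are hypothetical: ---
messages_shape = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {"role": "user", "content": [{"type": "text", "text": "<user_prompt>"}]},
    {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,<screenshot_1>"}}]},
    {"role": "assistant", "content": [{"type": "text", "text": "<response_1, with <|box_start|>/<|box_end|> tokens re-added>"}]},
    {"role": "user", "content": [{"type": "image_url", "image_url": {"url": "data:image/png;base64,<screenshot_2>"}}]},
]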
"content": [{"type": "text", "text": "You are a helpful assistant."}] + }, + { + "role": "user", + "content": [{"type": "text", "text": user_prompt}] + } + ] + + image_num = 0 + if len(self.history_responses) > 0: + for history_idx, history_response in enumerate(self.history_responses): + # send at most history_n images to the model + if history_idx + self.history_n > len(self.history_responses): + + cur_image = images[image_num] + encoded_string = pil_to_base64(cur_image) + messages.append({ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}] + }) + image_num += 1 + + messages.append({ + "role": "assistant", + "content": [{"type": "text", "text": add_box_token(history_response)}] + }) + + cur_image = images[image_num] + encoded_string = pil_to_base64(cur_image) + messages.append({ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}] + }) + image_num += 1 + + else: + cur_image = images[image_num] + encoded_string = pil_to_base64(cur_image) + messages.append({ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_string}"}}] + }) + image_num += 1 + + try_times = 3 + origin_resized_height = images[-1].height + origin_resized_width = images[-1].width + temperature = self.temperature + top_k = self.top_k + while True: + if try_times <= 0: + print(f"Reach max retry times to fetch response from client, as error flag.") + return "client error", ["DONE"] + try: + response = self.vlm.chat.completions.create( + model=self.model, + messages=messages, + frequency_penalty=1, + max_tokens=self.max_tokens, + temperature=temperature, + top_p=self.top_p + ) + print("*" * 20) + print("Response:") + print(response.choices[0].message.content) + print("*" * 20) + prediction = response.choices[0].message.content.strip() + + except Exception as e: + logger.exception(f"Error when fetching response from client: {e}") + prediction = None + try_times -= 1 + + try: + parsed_responses = parse_action_to_structure_output( + prediction, + self.action_parse_res_factor, + origin_resized_height, + origin_resized_width, + self.model_type, + self.max_pixels, + self.min_pixels + ) + break + except Exception as e: + print(f"Error when parsing response from client: {e}") + # If fail to parse the model response, we use sampling parameters to avoid it + prediction = None + try_times -= 1 + temperature = 1 + top_k = -1 + + if prediction is None: + return "client error", ["DONE"] + + self.history_responses.append(prediction) + self.thoughts.append(prediction) + + try: + parsed_responses = parse_action_to_structure_output( + prediction, + self.action_parse_res_factor, + origin_resized_height, + origin_resized_width, + self.model_type, + self.max_pixels, + self.min_pixels + ) + except Exception as e: + print(f"Parsing action error: {prediction}, with error:\n{e}") + return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"] + + actions = [] + last_image = Image.open(BytesIO(self.history_images[-1])) + obs_image_height = last_image.height + obs_image_width = last_image.width + for parsed_response in parsed_responses: + if "action_type" in parsed_response: + + if parsed_response["action_type"] == FINISH_WORD: + self.actions.append(actions) + + return prediction, ["DONE"] + + elif parsed_response["action_type"] == WAIT_WORD: + self.actions.append(actions) + return prediction, ["WAIT"] + + elif parsed_response["action_type"] == 
+                elif parsed_response["action_type"] == ENV_FAIL_WORD:
+                    self.actions.append(actions)
+                    return prediction, ["FAIL"]
+
+                elif parsed_response["action_type"] == CALL_USER:
+                    if self.callusr_tolerance > self.cur_callusr_count:
+                        self.actions.append(actions)
+                        self.cur_callusr_count += 1
+                        return prediction, ["WAIT"]
+                    else:
+                        self.actions.append(actions)
+                        return prediction, ["FAIL"]
+
+                pyautogui_code = parsing_response_to_pyautogui_code(
+                    parsed_response,
+                    obs_image_height,
+                    obs_image_width,
+                    self.input_swap
+                )
+                actions.append(pyautogui_code)
+
+        self.actions.append(actions)
+
+        if len(self.history_responses) >= self.max_trajectory_length:
+            # Default to FAIL if exceed max steps
+            actions = ["FAIL"]
+
+        return prediction, actions
diff --git a/mm_agents/uitars15_agent.py b/mm_agents/uitars15_v2.py
similarity index 99%
rename from mm_agents/uitars15_agent.py
rename to mm_agents/uitars15_v2.py
index 34280cd..199cef4 100644
--- a/mm_agents/uitars15_agent.py
+++ b/mm_agents/uitars15_v2.py
@@ -613,7 +613,7 @@ class UITarsAgent:
         self,
         # Model settings
         model: str,
-
+        model_type: str,
         # Generation settings
         max_tokens: int,
         top_p: Optional[float],
@@ -672,7 +672,7 @@ class UITarsAgent:
 
         self.system_prompt = COMPUTER_USE_NO_THINKING
         self.action_parse_res_factor = 1000
-        self.model_type = "doubao"
+        self.model_type = model_type
         self.history_n = 5
         self.top_p = top_p
         self.temperature = temperature
diff --git a/mm_agents/uitars_agent.py b/mm_agents/uitars_agent.py
index 245a5c3..a36c628 100644
--- a/mm_agents/uitars_agent.py
+++ b/mm_agents/uitars_agent.py
@@ -6,7 +6,7 @@ import re
 import xml.etree.ElementTree as ET
 from io import BytesIO
 from typing import Dict, List
-
+import os
 import backoff
 import numpy as np
 from PIL import Image
@@ -28,22 +28,16 @@ from mm_agents.prompts import (
     UITARS_CALL_USR_ACTION_SPACE,
     UITARS_USR_PROMPT_NOTHOUGHT,
     UITARS_USR_PROMPT_THOUGHT,
-    UITARS_NORMAL_ACTION_SPACE
 )
 
-logger = logging.getLogger("desktopenv.agent")
+from loguru import logger
 
 FINISH_WORD = "finished"
 WAIT_WORD = "wait"
 ENV_FAIL_WORD = "error_env"
 CALL_USER = "call_user"
 
-IMAGE_FACTOR = 28
-MIN_PIXELS = 100 * 28 * 28
-MAX_PIXELS = 16384 * 28 * 28
-MAX_RATIO = 200
-
 pure_text_settings = ["a11y_tree"]
 
 attributes_ns_ubuntu = "https://accessibility.windows.example.org/ns/attributes"
@@ -109,68 +103,8 @@ def escape_single_quotes(text):
     pattern = r"(?<!\\)'"
     return re.sub(pattern, r"\\'", text)
 
-def round_by_factor(number: int, factor: int) -> int:
-    """Returns the closest integer to 'number' that is divisible by 'factor'."""
-    return round(number / factor) * factor
-
-
-def ceil_by_factor(number: int, factor: int) -> int:
-    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-    return math.ceil(number / factor) * factor
-
-
-def floor_by_factor(number: int, factor: int) -> int:
-    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-    return math.floor(number / factor) * factor
-
-def linear_resize(
-    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[int, int]:
-    if width * height > max_pixels:
-        """
-        If the image is above/below the pixel limits, compute a scaling factor resize_factor so that
-        the pixel count shrinks to at most max_pixels. The square-root scaling keeps the aspect
-        ratio unchanged, so the original relative coordinates can be reused without conversion.
-        """
-        resize_factor = math.sqrt(max_pixels / (width * height))
-        width, height = int(width * resize_factor), int(height * resize_factor)
-    if width * height < min_pixels:
-        resize_factor = math.sqrt(min_pixels / (width * height))
-        width, height = math.ceil(width * resize_factor), math.ceil(height * resize_factor)
-
-    return height, width
-
-def smart_resize(
-    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[int, int]:
-    """
-    Rescales the image so that the following conditions are met:
-
-    1. Both dimensions (height and width) are divisible by 'factor'.
-
-    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-
-    3. The aspect ratio of the image is maintained as closely as possible.
-    """
-    if max(height, width) / min(height, width) > MAX_RATIO:
-        raise ValueError(
-            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
-        )
-    h_bar = max(factor, round_by_factor(height, factor))
-    w_bar = max(factor, round_by_factor(width, factor))
-    if h_bar * w_bar > max_pixels:
-        beta = math.sqrt((height * width) / max_pixels)
-        h_bar = floor_by_factor(height / beta, factor)
-        w_bar = floor_by_factor(width / beta, factor)
-    elif h_bar * w_bar < min_pixels:
-        beta = math.sqrt(min_pixels / (height * width))
-        h_bar = ceil_by_factor(height * beta, factor)
-        w_bar = ceil_by_factor(width * beta, factor)
-    return h_bar, w_bar
-
-def parse_action_to_structure_output(text, factor, origin_resized_height, origin_resized_width, model_type, max_pixels=16384*28*28, min_pixels=100*28*28):
+def parse_action_qwen2vl(text, factor, image_height, image_width):
     text = text.strip()
-    if model_type == "qwen25vl":
-        smart_resize_height, smart_resize_width = smart_resize(origin_resized_height, origin_resized_width, factor=IMAGE_FACTOR, min_pixels=min_pixels, max_pixels=max_pixels)
-
     # Regular expressions for matching the Action string
     if text.startswith("Thought:"):
         thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
@@ -182,8 +116,10 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
         thought_pattern = r"Action_Summary: (.+?)(?=\s*Action:|$)"
         thought_hint = "Action_Summary: "
     else:
-        thought_pattern = r"Thought: (.+?)(?=\s*Action:|$)"
-        thought_hint = "Thought: "
+        # Fix: when there is no explicit "Thought:" marker, extract everything before "Action:" as the thought
+        thought_pattern = r"(.+?)(?=\s*Action:|$)"
+        thought_hint = ""
+
     reflection, thought = None, None
     thought_match = re.search(thought_pattern, text, re.DOTALL)
     if thought_match:
@@ -218,7 +154,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin
     for action_instance, raw_str in zip(parsed_actions, 
all_action): if action_instance == None: print(f"Action can't parse: {raw_str}") - raise ValueError(f"Action can't parse: {raw_str}") + continue action_type = action_instance["function"] params = action_instance["args"] @@ -236,18 +172,7 @@ def parse_action_to_structure_output(text, factor, origin_resized_height, origin numbers = ori_box.replace("(", "").replace(")", "").split(",") # Convert to float and scale by 1000 - # Qwen2.5vl output absolute coordinates, qwen2vl output relative coordinates - if model_type == "qwen25vl": - float_numbers = [] - for num_idx, num in enumerate(numbers): - num = float(num) - if (num_idx + 1) % 2 == 0: - float_numbers.append(float(num/smart_resize_height)) - else: - float_numbers.append(float(num/smart_resize_width)) - else: - float_numbers = [float(num) / factor for num in numbers] - + float_numbers = [float(num) / factor for num in numbers] if len(float_numbers) == 2: float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]] action_inputs[param_name.strip()] = str(float_numbers) @@ -296,7 +221,7 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width if response_id == 0: pyautogui_code += f"'''\nObservation:\n{observation}\n\nThought:\n{thought}\n'''\n" else: - pyautogui_code += f"\ntime.sleep(1)\n" + pyautogui_code += f"\ntime.sleep(3)\n" action_dict = response action_type = action_dict.get("action_type") @@ -309,79 +234,25 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width else: hotkey = action_inputs.get("hotkey", "") - if hotkey == "arrowleft": - hotkey = "left" - - elif hotkey == "arrowright": - hotkey = "right" - - elif hotkey == "arrowup": - hotkey = "up" - - elif hotkey == "arrowdown": - hotkey = "down" - if hotkey: # Handle other hotkeys keys = hotkey.split() # Split the keys by space - convert_keys = [] - for key in keys: - if key == "space": - key = ' ' - convert_keys.append(key) - pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in convert_keys])})" + pyautogui_code += f"\npyautogui.hotkey({', '.join([repr(k) for k in keys])})" - elif action_type == "press": - # Parsing press action - if "key" in action_inputs: - key_to_press = action_inputs.get("key", "") - else: - key_to_press = action_inputs.get("press", "") - - if hotkey == "arrowleft": - hotkey = "left" - - elif hotkey == "arrowright": - hotkey = "right" - - elif hotkey == "arrowup": - hotkey = "up" - - elif hotkey == "arrowdown": - hotkey = "down" - - elif hotkey == "space": - hotkey = " " - - if key_to_press: - # Simulate pressing a single key - pyautogui_code += f"\npyautogui.press({repr(key_to_press)})" - - elif action_type == "keyup": - key_to_up = action_inputs.get("key", "") - pyautogui_code += f"\npyautogui.keyUp({repr(key_to_up)})" - - elif action_type == "keydown": - key_to_down = action_inputs.get("key", "") - pyautogui_code += f"\npyautogui.keyDown({repr(key_to_down)})" - elif action_type == "type": # Parsing typing action using clipboard content = action_inputs.get("content", "") content = escape_single_quotes(content) - stripped_content = content - if content.endswith("\n") or content.endswith("\\n"): - stripped_content = stripped_content.rstrip("\\n").rstrip("\n") if content: if input_swap: pyautogui_code += f"\nimport pyperclip" - pyautogui_code += f"\npyperclip.copy('{stripped_content}')" + pyautogui_code += f"\npyperclip.copy('{content.strip()}')" pyautogui_code += f"\npyautogui.hotkey('ctrl', 'v')" pyautogui_code += f"\ntime.sleep(0.5)\n" if content.endswith("\n") 
or content.endswith("\\n"):
                         pyautogui_code += f"\npyautogui.press('enter')"
                 else:
-                    pyautogui_code += f"\npyautogui.write('{stripped_content}', interval=0.1)"
+                    pyautogui_code += f"\npyautogui.write('{content.strip()}', interval=0.1)"
                     pyautogui_code += f"\ntime.sleep(0.5)\n"
                     if content.endswith("\n") or content.endswith("\\n"):
                         pyautogui_code += f"\npyautogui.press('enter')"
@@ -460,29 +331,6 @@ def parsing_response_to_pyautogui_code(responses, image_height: int, image_width
 
     return pyautogui_code
 
-def add_box_token(input_string):
-    # Step 1: Split the string into individual actions
-    if "Action: " in input_string and "start_box=" in input_string:
-        suffix = input_string.split("Action: ")[0] + "Action: "
-        actions = input_string.split("Action: ")[1:]
-        processed_actions = []
-        for action in actions:
-            action = action.strip()
-            # Step 2: Extract coordinates (start_box or end_box) using regex
-            coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
-
-            updated_action = action  # Start with the original action
-            for coord_type, x, y in coordinates:
-                # Convert x and y to integers
-                updated_action = updated_action.replace(f"{coord_type}='({x},{y})'", f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'")
-            processed_actions.append(updated_action)
-
-        # Step 5: Reconstruct the final string
-        final_string = suffix + "\n\n".join(processed_actions)
-    else:
-        final_string = input_string
-    return final_string
-
 def pil_to_base64(image):
     buffer = BytesIO()
     image.save(buffer, format="PNG")  # you can switch this to "JPEG" or another format
@@ -558,51 +406,48 @@ def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
 class UITARSAgent:
     def __init__(
         self,
+        model: str,
         platform="ubuntu",
+        max_tokens=1000,
+        top_p=0.9,
+        top_k=1.0,
+        temperature=0.0,
         action_space="pyautogui",
-        observation_type="screenshot",
+        observation_type="screenshot_a11y_tree",
         # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
         max_trajectory_length=50,
         a11y_tree_max_tokens=10000,
-        model_type="qwen25vl",
         runtime_conf: dict = {
-            "infer_mode": "qwen25vl_normal",
-            "prompt_style": "qwen25vl_normal",
+            "infer_mode": "qwen2vl_user",
+            "prompt_style": "qwen2vl_user",
             "input_swap": True,
             "language": "Chinese",
+            "max_steps": 50,
             "history_n": 5,
-            "max_pixels": 16384*28*28,
-            "min_pixels": 100*28*28,
-            "callusr_tolerance": 3,
-            "temperature": 0.0,
-            "top_k": -1,
-            "top_p": 0.9,
-            "max_tokens": 500
-
+            "screen_height": 1080,
+            "screen_width": 1920
         }
     ):
+        self.model = model
         self.platform = platform
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.top_k = top_k
+        self.temperature = temperature
         self.action_space = action_space
        self.observation_type = observation_type
         self.max_trajectory_length = max_trajectory_length
         self.a11y_tree_max_tokens = a11y_tree_max_tokens
-        self.model_type = model_type
         self.runtime_conf = runtime_conf
         self.vlm = OpenAI(
-            base_url="http://127.0.0.1:8000/v1",
-            api_key="empty",
+            base_url=os.environ['DOUBAO_API_URL'],
+            api_key=os.environ['DOUBAO_API_KEY'],
         )  # should replace with your UI-TARS server api
-        self.temperature = self.runtime_conf["temperature"]
-        self.top_k = self.runtime_conf["top_k"]
-        self.top_p = self.runtime_conf["top_p"]
-        self.max_tokens = self.runtime_conf["max_tokens"]
         self.infer_mode = self.runtime_conf["infer_mode"]
         self.prompt_style = self.runtime_conf["prompt_style"]
         self.input_swap = self.runtime_conf["input_swap"]
         self.language = self.runtime_conf["language"]
-        self.max_pixels = self.runtime_conf["max_pixels"]
-        self.min_pixels = 
self.runtime_conf["min_pixels"] - self.callusr_tolerance = self.runtime_conf["callusr_tolerance"] + self.max_steps = max_trajectory_length self.thoughts = [] self.actions = [] @@ -611,15 +456,14 @@ class UITARSAgent: self.history_responses = [] self.prompt_action_space = UITARS_ACTION_SPACE + self.customize_action_parser = parse_action_qwen2vl self.action_parse_res_factor = 1000 if self.infer_mode == "qwen2vl_user": self.prompt_action_space = UITARS_CALL_USR_ACTION_SPACE - elif self.infer_mode == "qwen25vl_normal": - self.prompt_action_space = UITARS_NORMAL_ACTION_SPACE self.prompt_template = UITARS_USR_PROMPT_THOUGHT - if self.prompt_style == "qwen2vl_user" or self.prompt_style == "qwen25vl_normal": + if self.prompt_style == "qwen2vl_user": self.prompt_template = UITARS_USR_PROMPT_THOUGHT elif self.prompt_style == "qwen2vl_no_thought": @@ -630,8 +474,6 @@ class UITARSAgent: self.history_n = self.runtime_conf["history_n"] else: self.history_n = 5 - - self.cur_callusr_count = 0 def predict( self, instruction: str, obs: Dict, last_action_after_obs: Dict = None @@ -660,18 +502,9 @@ class UITARSAgent: _actions = self.actions _thoughts = self.thoughts - for previous_obs, previous_action, previous_thought in zip( - _observations, _actions, _thoughts - ): - # {{{1 - if self.observation_type == "screenshot_a11y_tree": - _screenshot = previous_obs["screenshot"] - _linearized_accessibility_tree = previous_obs["accessibility_tree"] - - else: - raise ValueError( - "Invalid observation_type type: " + self.observation_type - ) # 1}}} + + if last_action_after_obs is not None and self.infer_mode == "double_image": + self.history_images.append(last_action_after_obs["screenshot"]) self.history_images.append(obs["screenshot"]) @@ -712,7 +545,7 @@ class UITARSAgent: "Invalid observation_type type: " + self.observation_type ) # 1}}} - if self.infer_mode == "qwen2vl_user" or self.infer_mode == "qwen25vl_normal": + if self.infer_mode == "qwen2vl_user": user_prompt = self.prompt_template.format( instruction=instruction, action_space=self.prompt_action_space, @@ -726,6 +559,8 @@ class UITARSAgent: if len(self.history_images) > self.history_n: self.history_images = self.history_images[-self.history_n:] + max_pixels = 2116800 + min_pixels = 3136 messages, images = [], [] if isinstance(self.history_images, bytes): self.history_images = [self.history_images] @@ -735,24 +570,28 @@ class UITARSAgent: pass else: raise TypeError(f"Unidentified images type: {type(self.history_images)}") + max_image_nums_under_32k = int(32768*0.75/max_pixels*28*28) + if len(self.history_images) > max_image_nums_under_32k: + num_of_images = min(5, len(self.history_images)) + max_pixels = int(32768*0.75) // num_of_images for turn, image in enumerate(self.history_images): - if len(images) >= self.history_n: + if len(images) >= 5: break try: image = Image.open(BytesIO(image)) except Exception as e: raise RuntimeError(f"Error opening image: {e}") - if image.width * image.height > self.max_pixels: + if image.width * image.height > max_pixels: """ 如果图片超过/低于像素限制,则计算一个缩放因子resize_factor,使图片的像素数缩小到等于或小于max_pixels。这个缩放因子是通过开平方根计算的,确保纵横比保持不变,这样原始的相对坐标可以不经转换直接复用 """ - resize_factor = math.sqrt(self.max_pixels / (image.width * image.height)) + resize_factor = math.sqrt(max_pixels / (image.width * image.height)) width, height = int(image.width * resize_factor), int(image.height * resize_factor) image = image.resize((width, height)) - if image.width * image.height < self.min_pixels: - resize_factor = math.sqrt(self.min_pixels / (image.width * image.height)) 
+ if image.width * image.height < min_pixels: + resize_factor = math.sqrt(min_pixels / (image.width * image.height)) width, height = math.ceil(image.width * resize_factor), math.ceil(image.height * resize_factor) image = image.resize((width, height)) @@ -788,7 +627,7 @@ class UITARSAgent: messages.append({ "role": "assistant", - "content": [add_box_token(history_response)] + "content": history_response }) cur_image = images[image_num] @@ -809,79 +648,59 @@ class UITARSAgent: image_num += 1 try_times = 3 - origin_resized_height = images[-1].height - origin_resized_width = images[-1].width - temperature = self.temperature - top_k = self.top_k while True: if try_times <= 0: print(f"Reach max retry times to fetch response from client, as error flag.") - return "client error", ["DONE"], [] + return "client error", ["DONE"] try: + response = self.vlm.chat.completions.create( - model="ui-tars", + model=self.model, messages=messages, frequency_penalty=1, max_tokens=self.max_tokens, - temperature=temperature, + temperature=self.temperature, top_p=self.top_p ) - # print(response.choices[0].message.content) - prediction = response.choices[0].message.content.strip() - except Exception as e: - print(f"Error when fetching response from client, with response: {response}") - prediction = None - try_times -= 1 - - try: - parsed_responses = parse_action_to_structure_output( + print("Response:") + print(response.choices[0].message.content) + + prediction = response.choices[0].message.content + parsed_responses = self.customize_action_parser( prediction, self.action_parse_res_factor, - origin_resized_height, - origin_resized_width, - self.model_type, - self.max_pixels, - self.min_pixels + self.runtime_conf["screen_height"], + self.runtime_conf["screen_width"] ) break except Exception as e: - print(f"Error when parsing response from client, with response: {response}") - # If fail to parse the model response, we use sampling parameters to avoid it + logger.exception(f"Error when fetching response from client, with response: {e}") prediction = None try_times -= 1 - temperature = 1 - top_k = -1 if prediction is None: return "client error", ["DONE"] - + self.history_responses.append(prediction) self.thoughts.append(prediction) try: - parsed_responses = parse_action_to_structure_output( + parsed_responses = self.customize_action_parser( prediction, self.action_parse_res_factor, - origin_resized_height, - origin_resized_width, - self.model_type, - self.max_pixels, - self.min_pixels + self.runtime_conf["screen_height"], + self.runtime_conf["screen_width"] ) except Exception as e: print(f"Parsing action error: {prediction}, with error:\n{e}") return f"Parsing action error: {prediction}, with error:\n{e}", ["DONE"] actions = [] - last_image = Image.open(BytesIO(self.history_images[-1])) - obs_image_height = last_image.height - obs_image_width = last_image.width for parsed_response in parsed_responses: if "action_type" in parsed_response: if parsed_response["action_type"] == FINISH_WORD: self.actions.append(actions) - return prediction, ["DONE"] elif parsed_response["action_type"] == WAIT_WORD: @@ -893,18 +712,13 @@ class UITARSAgent: return prediction, ["FAIL"] elif parsed_response["action_type"] == CALL_USER: - if self.callusr_tolerance > self.cur_callusr_count: - self.actions.append(actions) - self.cur_callusr_count += 1 - return prediction, ["WAIT"] - else: - self.actions.append(actions) - return prediction, ["FAIL"] + self.actions.append(actions) + return prediction, ["FAIL"] pyautogui_code = 
parsing_response_to_pyautogui_code(
                     parsed_response,
-                    obs_image_height,
-                    obs_image_width,
+                    self.runtime_conf["screen_height"],
+                    self.runtime_conf["screen_width"],
                     self.input_swap
                 )
                 actions.append(pyautogui_code)
@@ -917,7 +731,6 @@ class UITARSAgent:
 
         return prediction, actions
 
-
     @backoff.on_exception(
         backoff.constant,
         # here you should add more model exceptions as you want,
@@ -947,4 +760,4 @@ class UITARSAgent:
         self.actions = []
         self.observations = []
         self.history_images = []
-        self.history_responses = []
+        self.history_responses = []
\ No newline at end of file
diff --git a/run_multienv_uitars.py b/run_multienv_uitars.py
new file mode 100644
index 0000000..1c95ee6
--- /dev/null
+++ b/run_multienv_uitars.py
@@ -0,0 +1,539 @@
+from __future__ import annotations
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+import signal
+import time
+from typing import List
+from multiprocessing import Process, Manager, Queue
+from multiprocessing import current_process
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
+from mm_agents.uitars_agent import UITARSAgent
+
+
+# Global variables for signal handling
+active_environments = []
+processes = []
+is_terminating = False
+
+# load the environment variables from .env file
+if os.path.exists(".env"):
+    from dotenv import load_dotenv
+    load_dotenv()
+
+# Logger Configs {{{ #
+def config() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run end-to-end evaluation on the benchmark"
+    )
+
+    # environment config
+    parser.add_argument("--path_to_vm", type=str, default=None)
+    parser.add_argument(
+        "--headless", action="store_true", help="Run in headless machine"
+    )
+    parser.add_argument(
+        "--action_space", type=str, default="pyautogui", help="Action type"
+    )
+    parser.add_argument(
+        "--observation_type",
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+        default="screenshot",
+        help="Observation type",
+    )
+    parser.add_argument("--sleep_after_execution", type=float, default=3.0)
+    parser.add_argument("--max_steps", type=int, default=15)
+
+    # evaluation config
+    parser.add_argument(
+        "--test_config_base_dir", type=str, default="evaluation_examples"
+    )
+
+    # lm config
+    parser.add_argument("--model", type=str, default="uitars-72b-dpo", help="Model name")
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_p", type=float, default=0.9)
+    parser.add_argument("--max_tokens", type=int, default=1500)
+    parser.add_argument("--stop_token", type=str, default=None)
+
+    parser.add_argument("--max_trajectory_length", type=int, default=3, help="The max number of trajectory steps.")
+
+    # example config
+    parser.add_argument("--domain", type=str, default="all")
+    parser.add_argument(
+        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
+    )
+
+    # logging related
+    parser.add_argument("--result_dir", type=str, default="./results")
+    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")
+    parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
+                        default='INFO', help="Set the logging level")
+    # aws config
+    parser.add_argument(
+        "--region", type=str, default="us-east-1", help="AWS region for the VM"
+    )
+    parser.add_argument(
+        "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name"
+    )
+    parser.add_argument(
+        
"--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) + args = parser.parse_args() + return args + +args = config() # Get command line arguments first + +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. 
+
+def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list):
+    active_environments = []
+    env = None
+    try:
+        # One DesktopEnv (one VM) per worker process; the AMI is keyed by
+        # screen size, with the 1920x1080 image as the fallback
+        from desktop_env.providers.aws.manager import IMAGE_ID_MAP
+        REGION = args.region
+        screen_size = (args.screen_width, args.screen_height)
+        ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)])
+        env = DesktopEnv(
+            path_to_vm=args.path_to_vm,
+            action_space=args.action_space,
+            provider_name=args.provider_name,
+            region=REGION,
+            snapshot_name=ami_id,
+            screen_size=screen_size,
+            headless=args.headless,
+            os_type="Ubuntu",
+            require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"],
+            enable_proxy=True,
+            client_password=args.client_password
+        )
+        active_environments.append(env)
+        args.max_trajectory_length = args.max_steps
+        agent = UITARSAgent(
+            model=args.model,
+            max_tokens=args.max_tokens,
+            top_p=args.top_p,
+            temperature=args.temperature,
+            action_space=args.action_space,
+            observation_type=args.observation_type,
+            max_trajectory_length=args.max_trajectory_length,
+        )
+
+        logger.info(f"Process {current_process().name} started.")
+        while True:
+            try:
+                item = task_queue.get(timeout=5)
+            except Exception:
+                # manager queues raise queue.Empty on timeout; an empty queue
+                # means all tasks are claimed, so this worker can exit
+                break
+            domain, example_id = item
+            try:
+                config_file = os.path.join(
+                    args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
+                )
+                with open(config_file, "r", encoding="utf-8") as f:
+                    example = json.load(f)
+                logger.info(f"[{current_process().name}][Domain]: {domain}")
+                logger.info(f"[{current_process().name}][Example ID]: {example_id}")
+                logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}")
+                example_result_dir = os.path.join(
+                    args.result_dir,
+                    args.action_space,
+                    args.observation_type,
+                    args.model,
+                    domain,
+                    example_id,
+                )
+                os.makedirs(example_result_dir, exist_ok=True)
+                try:
+                    lib_run_single.run_single_example(
+                        agent,
+                        env,
+                        example,
+                        args.max_steps,
+                        example["instruction"],
+                        args,
+                        example_result_dir,
+                        shared_scores,
+                    )
+                except Exception as e:
+                    import traceback
+                    logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}")
+                    logger.error(traceback.format_exc())
+                    try:
+                        env.controller.end_recording(
+                            os.path.join(example_result_dir, "recording.mp4")
+                        )
+                    except Exception as rec_e:
+                        logger.error(f"Failed to end recording: {rec_e}")
+                    with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
+                        f.write(
+                            json.dumps(
+                                {"Error": f"{domain}/{example_id} - {e}"}
+                            )
+                        )
+                        f.write("\n")
+            except Exception as e:
+                logger.error(f"Task-level error in {current_process().name}: {e}")
+                import traceback
+                logger.error(traceback.format_exc())
+    except Exception as e:
+        logger.error(f"Process-level error in {current_process().name}: {e}")
+        import traceback
+        logger.error(traceback.format_exc())
+    finally:
+        logger.info(f"{current_process().name} cleaning up environment...")
+        try:
+            if env:
+                env.close()
+                logger.info(f"{current_process().name} environment closed successfully")
+        except Exception as e:
+            logger.error(f"{current_process().name} error during environment cleanup: {e}")
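run_env_tasks is the consumer half of a standard multiprocessing work queue: every worker owns one DesktopEnv (one VM) and pulls (domain, example_id) tasks until the queue stays empty for the 5-second timeout. A minimal, self-contained sketch of that pattern (illustrative only; the worker above catches a broad Exception where this catches queue.Empty):

    from multiprocessing import Manager, Process
    from queue import Empty

    def worker(q, results):
        while True:
            try:
                task = q.get(timeout=5)  # same 5 s idle timeout as run_env_tasks
            except Empty:
                break                    # queue drained: the worker exits
            results.append(task * 2)     # stand-in for running one benchmark example

    if __name__ == "__main__":
        with Manager() as m:
            q, results = m.Queue(), m.list()
            for i in range(10):
                q.put(i)
            workers = [Process(target=worker, args=(q, results)) for _ in range(4)]
            for p in workers:
                p.start()
            for p in workers:
                p.join()
            print(sorted(results))       # [0, 2, 4, ..., 18]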
+
+
+def signal_handler(signum, frame):
+    """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments."""
+    global is_terminating, active_environments, processes
+
+    # Avoid duplicate handling
+    if is_terminating:
+        return
+
+    is_terminating = True
+    logger.info(f"Received signal {signum}. Gracefully shutting down...")
+
+    # Close all registered environments in the main process
+    for env in active_environments:
+        try:
+            logger.info("Closing environment...")
+            env.close()
+            logger.info("Environment closed successfully")
+        except Exception as e:
+            logger.error(f"Error closing environment: {e}")
+
+    # Send termination signal to all child processes first
+    for p in processes:
+        if p.is_alive():
+            try:
+                logger.info(f"Sending termination signal to process {p.name}...")
+                p.terminate()
+            except Exception as e:
+                logger.error(f"Error sending termination signal to process: {e}")
+
+    # Allow a short time for processes to handle their own cleanup
+    time.sleep(1)
+
+    # Forcefully terminate any processes that didn't exit
+    for p in processes:
+        if p.is_alive():
+            try:
+                logger.info(f"Forcefully terminating process {p.name}...")
+                os.kill(p.pid, signal.SIGKILL)
+            except Exception as e:
+                logger.error(f"Error forcefully terminating process: {e}")
+
+    logger.info("Shutdown complete. Exiting.")
+    sys.exit(0)
+
+
+def test(args: argparse.Namespace, test_all_meta: dict) -> None:
+    global processes
+    logger.info("Args: %s", args)
+    all_tasks = distribute_tasks(test_all_meta)
+    logger.info(f"Total tasks: {len(all_tasks)}")
+    with Manager() as manager:
+        shared_scores = manager.list()
+        task_queue = manager.Queue()
+        for item in all_tasks:
+            task_queue.put(item)
+        num_envs = args.num_envs
+        processes = []
+        for i in range(num_envs):
+            p = Process(
+                target=run_env_tasks,
+                args=(task_queue, args, shared_scores),
+                name=f"EnvProcess-{i+1}"
+            )
+            p.daemon = True
+            p.start()
+            processes.append(p)
+            logger.info(f"Started process {p.name} with PID {p.pid}")
+        try:
+            # Supervisor loop: restart dead workers until the queue is drained
+            while True:
+                alive_count = 0
+                for idx, p in enumerate(processes):
+                    if not p.is_alive():
+                        logger.warning(f"Process {p.name} died, restarting...")
+                        new_p = Process(
+                            target=run_env_tasks,
+                            args=(task_queue, args, shared_scores),
+                            name=f"EnvProcess-Restart-{idx+1}"
+                        )
+                        new_p.daemon = True
+                        new_p.start()
+                        processes[idx] = new_p
+                        logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}")
+                    else:
+                        alive_count += 1
+                if task_queue.empty():
+                    logger.info("All tasks finished.")
+                    break
+                if alive_count == 0:
+                    logger.error("All processes died, exiting.")
+                    break
+                time.sleep(5)
+            for p in processes:
+                p.join()
+        except KeyboardInterrupt:
+            logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...")
+            raise
+        except Exception as e:
+            logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True)
+            for p in processes:
+                if p.is_alive():
+                    try:
+                        logger.info(f"Terminating process {p.name} due to error...")
+                        p.terminate()
+                    except Exception as term_e:
+                        logger.error(f"Error terminating process {p.name}: {term_e}")
+            raise
+        scores = list(shared_scores)
+        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")
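The two helpers below make reruns resumable: get_unfinished drops every example whose result directory already contains result.txt and wipes partially-written directories so they run again, while get_result reports the running success rate. A worked micro-example of the pruning contract (domains and example ids hypothetical):

    # results/.../chrome/abc/result.txt exists  -> "abc" is considered finished
    # results/.../chrome/def/ has no result.txt -> partial files wiped, "def" reruns
    total = {"chrome": ["abc", "def"], "gimp": ["xyz"]}
    finished = {"chrome": ["abc"]}
    remaining = {
        d: [x for x in xs if x not in finished.get(d, [])]
        for d, xs in total.items()
    }
    assert remaining == {"chrome": ["def"], "gimp": ["xyz"]}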
+
+
+def get_unfinished(
+    action_space, use_model, observation_type, result_dir, total_file_json
+):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+
+    if not os.path.exists(target_dir):
+        return total_file_json
+
+    finished = {}
+    for domain in os.listdir(target_dir):
+        finished[domain] = []
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                if example_id == "onboard":
+                    continue
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" not in os.listdir(example_path):
+                        # no result yet: wipe the partial files so the example reruns
+                        for file in os.listdir(example_path):
+                            os.remove(os.path.join(example_path, file))
+                    else:
+                        finished[domain].append(example_id)
+
+    if not finished:
+        return total_file_json
+
+    for domain, examples in finished.items():
+        if domain in total_file_json:
+            total_file_json[domain] = [
+                x for x in total_file_json[domain] if x not in examples
+            ]
+
+    return total_file_json
+
+
+def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
+    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
+    if not os.path.exists(target_dir):
+        print("New experiment, no result yet.")
+        return None
+
+    all_result = []
+
+    for domain in os.listdir(target_dir):
+        domain_path = os.path.join(target_dir, domain)
+        if os.path.isdir(domain_path):
+            for example_id in os.listdir(domain_path):
+                example_path = os.path.join(domain_path, example_id)
+                if os.path.isdir(example_path):
+                    if "result.txt" in os.listdir(example_path):
+                        # read the recorded score; unreadable results count as 0
+                        try:
+                            with open(
+                                os.path.join(example_path, "result.txt"), "r"
+                            ) as f:
+                                all_result.append(float(f.read()))
+                        except (ValueError, OSError):
+                            all_result.append(0.0)
+
+    if not all_result:
+        print("New experiment, no result yet.")
+        return None
+    else:
+        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
+        return all_result
+
+
+if __name__ == "__main__":
+    ####### The complete version of the list of examples #######
+    os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+    # Register signal handlers for graceful termination
+    signal.signal(signal.SIGINT, signal_handler)  # Handle Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Handle termination signal
+
+    try:
+        args = config()
+
+        # save args to json in result_dir/action_space/observation_type/model/args.json
+        path_to_args = os.path.join(
+            args.result_dir,
+            args.action_space,
+            args.observation_type,
+            args.model,
+            "args.json",
+        )
+        os.makedirs(os.path.dirname(path_to_args), exist_ok=True)
+        with open(path_to_args, "w", encoding="utf-8") as f:
+            json.dump(vars(args), f, indent=4)
+
+        with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
+            test_all_meta = json.load(f)
+
+        if args.domain != "all":
+            test_all_meta = {args.domain: test_all_meta[args.domain]}
+
+        test_file_list = get_unfinished(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        left_info = ""
+        for domain in test_file_list:
+            left_info += f"{domain}: {len(test_file_list[domain])}\n"
+        logger.info(f"Left tasks:\n{left_info}")
+
+        get_result(
+            args.action_space,
+            args.model,
+            args.observation_type,
+            args.result_dir,
+            test_all_meta,
+        )
+        test(args, test_file_list)
+    except KeyboardInterrupt:
+        logger.info("Main process received KeyboardInterrupt.")
+        # Signal handler will take care of cleanup
+    except Exception as e:
+        logger.error(f"Unexpected error in main process: {e}", exc_info=True)
+        # Also trigger cleanup for unhandled exceptions
+        signal_handler(signal.SIGTERM, None)
+    finally:
+        # Final cleanup in case any environments or processes remain
+        logger.info("Main process final cleanup...")
+        for env in active_environments:
+            if env is not None:
+                try:
+                    logger.info("Closing environment in final cleanup...")
+                    env.close()
+                    logger.info("Environment closed successfully in final cleanup")
+                except Exception as e:
+                    logger.error(f"Error during final environment cleanup: {e}")
+
+        # First try gentle termination
+        for p in processes:
+            if p is not None and p.is_alive():
+                try:
+                    logger.info(f"Terminating process {p.name}...")
+                    p.terminate()
+                except Exception as e:
+                    logger.error(f"Error terminating process: {e}")
+
+        # Wait a moment for processes to terminate
+        time.sleep(1)
+
+        # Then force kill if needed
+        for p in processes:
+            if p is not None and p.is_alive():
+                try:
+                    logger.info(f"Force killing process {p.name}...")
+                    os.kill(p.pid, signal.SIGKILL)
+                    logger.info(f"Process {p.name} force killed")
+                except Exception as e:
+                    logger.error(f"Error force killing process: {e}")
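That completes run_multienv_uitars.py. For reference, the screen-size-to-AMI fallback used in run_env_tasks behaves like this sketch (AMI ids hypothetical):

    # Shape of the IMAGE_ID_MAP lookup; unsupported sizes fall back to 1920x1080.
    IMAGE_ID_MAP = {"us-east-1": {(1920, 1080): "ami-aaa", (2560, 1440): "ami-bbb"}}
    screen_size = (1280, 720)
    ami_id = IMAGE_ID_MAP["us-east-1"].get(
        screen_size, IMAGE_ID_MAP["us-east-1"][(1920, 1080)]
    )
    assert ami_id == "ami-aaa"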
diff --git a/run_multienv_uitars15_v1.py b/run_multienv_uitars15_v1.py
new file mode 100644
index 0000000..77a538d
--- /dev/null
+++ b/run_multienv_uitars15_v1.py
@@ -0,0 +1,581 @@
+from __future__ import annotations
+import argparse
+import datetime
+import json
+import logging
+import os
+import sys
+import signal
+import time
+from typing import List
+from multiprocessing import Process, Manager, Queue  # Queue is used in run_env_tasks' signature
+from multiprocessing import current_process
+import lib_run_single
+from desktop_env.desktop_env import DesktopEnv
+from mm_agents.uitars15_v1 import UITARSAgent
+
+# Global variables for signal handling
+active_environments = []
+processes = []
+is_terminating = False
+
+# load the environment variables from .env file
+if os.path.exists(".env"):
+    from dotenv import load_dotenv
+    load_dotenv()
+
+# Logger Configs {{{ #
+def config() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Run end-to-end evaluation on the benchmark"
+    )
+
+    # environment config
+    parser.add_argument("--path_to_vm", type=str, default=None)
+    parser.add_argument(
+        "--headless", action="store_true", help="Run in headless machine"
+    )
+    parser.add_argument(
+        "--action_space", type=str, default="pyautogui", help="Action type"
+    )
+    parser.add_argument(
+        "--observation_type",
+        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
+        default="screenshot",
+        help="Observation type",
+    )
+    parser.add_argument("--sleep_after_execution", type=float, default=3.0)
+    parser.add_argument("--max_steps", type=int, default=15)
+
+    # evaluation config
+    parser.add_argument(
+        "--test_config_base_dir", type=str, default="evaluation_examples"
+    )
+
+    # lm config
+    parser.add_argument("--model", type=str, default="uitars15-7b")
+    parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen25vl", "qwen2vl"])
+    parser.add_argument("--infer_mode", 
type=str, default="qwen25vl_normal", choices=["qwen25vl_normal", "qwen2vl_user"]) + parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal") + parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content") + parser.add_argument("--language", type=str, default="Chinese") + parser.add_argument("--max_pixels", type=float, default=16384*28*28) + parser.add_argument("--min_pixels", type=float, default=100*28*28) + parser.add_argument("--temperature", type=float, default=1.0) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--top_k", type=int, default=-1) + parser.add_argument("--history_n", type=int, default=5) + parser.add_argument("--callusr_tolerance", type=int, default=3) + parser.add_argument("--max_tokens", type=int, default=500) + parser.add_argument("--stop_token", type=str, default=None) + + parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.") + parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.") + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + parser.add_argument("--log_level", type=str, choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], + default='INFO', help="Set the logging level") + # aws config + parser.add_argument( + "--region", type=str, default="us-east-1", help="AWS region for the VM" + ) + parser.add_argument( + "--provider_name", type=str, default="aws", choices=["aws", "virtualbox", "vmware", "docker", "azure"], help="Provider name" + ) + parser.add_argument( + "--client_password", type=str, default="", help="Client password" + ) + parser.add_argument( + "--screen_width", type=int, default=1920, help="Screen width" + ) + parser.add_argument( + "--screen_height", type=int, default=1080, help="Screen height" + ) + args = parser.parse_args() + return args + +args = config() # Get command line arguments first + +logger = logging.getLogger() +log_level = getattr(logging, args.log_level.upper()) +logger.setLevel(log_level) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(log_level) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def distribute_tasks(test_all_meta: dict) -> List[tuple]: + all_tasks = [] + for domain, examples in 
test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + return all_tasks + + +def process_signal_handler(signum, frame, env_idx): + """Signal handler for child processes to gracefully shut down their environments.""" + logger.info(f"Process {env_idx + 1} received signal {signum}. Shutting down...") + + # Get the active_environments from the caller's frame + local_vars = frame.f_locals + active_environments = local_vars.get('active_environments', []) + + # Close environment in the current process context + for env in active_environments: + if env is not None: + try: + logger.info(f"Process {env_idx + 1} closing environment...") + env.close() + logger.info(f"Process {env_idx + 1} environment closed successfully") + except Exception as e: + logger.error(f"Process {env_idx + 1} error closing environment: {e}") + + logger.info(f"Process {env_idx + 1} shutdown complete. Exiting.") + sys.exit(0) + +def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: list): + active_environments = [] + env = None + try: + from desktop_env.providers.aws.manager import IMAGE_ID_MAP + REGION = args.region + screen_size = (args.screen_width, args.screen_height) + ami_id = IMAGE_ID_MAP[REGION].get(screen_size, IMAGE_ID_MAP[REGION][(1920, 1080)]) + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=args.action_space, + provider_name=args.provider_name, + region=REGION, + snapshot_name=ami_id, + screen_size=screen_size, + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type in ["a11y_tree", "screenshot_a11y_tree", "som"], + enable_proxy=True, + client_password=args.client_password + ) + active_environments.append(env) + args.max_trajectory_length = args.max_steps + if args.infer_mode == "qwen25vl_normal": + runtime_conf: dict = { + "infer_mode": "qwen25vl_normal", + "prompt_style": "qwen25vl_normal", + "input_swap": True, + "language": "Chinese", + "history_n": 5, + "max_pixels": 16384*28*28, + "min_pixels": 100*28*28, + "callusr_tolerance": 3, + "temperature": 0.0, + "top_k": -1, + "top_p": 0.9, + "max_tokens": 1000 + + } + elif args.infer_mode == "qwen2vl_user": + runtime_conf: dict = { + "infer_mode": "qwen2vl_user", + "prompt_style": "qwen2vl_user", + "input_swap": True, + "language": "Chinese", + "history_n": 5, + "max_pixels": 2116800, + "min_pixels": 3136, + "callusr_tolerance": 3, + "temperature": 0.0, + "top_k": -1, + "top_p": 0.9, + "max_tokens": 1000 + } + else: + raise ValueError(f"Unknown infer_mode: {args.infer_mode}") + + agent = UITARSAgent( + model=args.model, + action_space=args.action_space, + observation_type=args.observation_type, + max_trajectory_length=args.max_trajectory_length, + model_type=args.model_type, + runtime_conf = runtime_conf + ) + + logger.info(f"Process {current_process().name} started.") + while True: + try: + item = task_queue.get(timeout=5) + except Exception: + break + domain, example_id = item + try: + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + logger.info(f"[{current_process().name}][Domain]: {domain}") + logger.info(f"[{current_process().name}][Example ID]: {example_id}") + logger.info(f"[{current_process().name}][Instruction]: {example['instruction']}") + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + domain, + example_id, + ) + os.makedirs(example_result_dir, 
exist_ok=True) + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + import traceback + logger.error(f"Exception in {current_process().name} {domain}/{example_id}: {e}") + logger.error(traceback.format_exc()) + try: + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + except Exception as rec_e: + logger.error(f"Failed to end recording: {rec_e}") + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"{domain}/{example_id} - {e}"} + ) + ) + f.write("\n") + except Exception as e: + logger.error(f"Task-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + except Exception as e: + logger.error(f"Process-level error in {current_process().name}: {e}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"{current_process().name} cleaning up environment...") + try: + if env: + env.close() + logger.info(f"{current_process().name} environment closed successfully") + except Exception as e: + logger.error(f"{current_process().name} error during environment cleanup: {e}") + + + +def signal_handler(signum, frame): + """Handle termination signals (SIGINT, SIGTERM) to gracefully shutdown environments.""" + global is_terminating, active_environments, processes + + # Avoid duplicate handling + if is_terminating: + return + + is_terminating = True + logger.info(f"Received signal {signum}. Gracefully shutting down...") + + # Close all registered environments in the main process + for env in active_environments: + try: + logger.info(f"Closing environment...") + env.close() + logger.info(f"Environment closed successfully") + except Exception as e: + logger.error(f"Error closing environment: {e}") + + # Send termination signal to all child processes first + for p in processes: + if p.is_alive(): + try: + logger.info(f"Sending termination signal to process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error sending termination signal to process: {e}") + + # Allow a short time for processes to handle their own cleanup + time.sleep(1) + + # Forcefully terminate any processes that didn't exit + for p in processes: + if p.is_alive(): + try: + logger.info(f"Forcefully terminating process {p.name}...") + import signal as sig + os.kill(p.pid, sig.SIGKILL) + except Exception as e: + logger.error(f"Error forcefully terminating process: {e}") + + logger.info("Shutdown complete. 
Exiting.") + sys.exit(0) + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + global processes + logger.info("Args: %s", args) + all_tasks = distribute_tasks(test_all_meta) + logger.info(f"Total tasks: {len(all_tasks)}") + with Manager() as manager: + shared_scores = manager.list() + task_queue = manager.Queue() + for item in all_tasks: + task_queue.put(item) + num_envs = args.num_envs + processes = [] + for i in range(num_envs): + p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-{i+1}" + ) + p.daemon = True + p.start() + processes.append(p) + logger.info(f"Started process {p.name} with PID {p.pid}") + try: + while True: + alive_count = 0 + for idx, p in enumerate(processes): + if not p.is_alive(): + logger.warning(f"Process {p.name} died, restarting...") + new_p = Process( + target=run_env_tasks, + args=(task_queue, args, shared_scores), + name=f"EnvProcess-Restart-{idx+1}" + ) + new_p.daemon = True + new_p.start() + processes[idx] = new_p + logger.info(f"Restarted process {new_p.name} with PID {new_p.pid}") + else: + alive_count += 1 + if task_queue.empty(): + logger.info("All tasks finished.") + break + if alive_count == 0: + logger.error("All processes died, exiting.") + break + time.sleep(5) + for p in processes: + p.join() + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt. Initiating graceful shutdown...") + raise + except Exception as e: + logger.error(f"Unexpected error while waiting for processes: {e}", exc_info=True) + for p in processes: + if p.is_alive(): + try: + logger.info(f"Terminating process {p.name} due to error...") + p.terminate() + except Exception as term_e: + logger.error(f"Error terminating process {p.name}: {term_e}") + raise + scores = list(shared_scores) + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all 
files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + # Register signal handlers for graceful termination + signal.signal(signal.SIGINT, signal_handler) # Handle Ctrl+C + signal.signal(signal.SIGTERM, signal_handler) # Handle termination signal + + try: + args = config() + + # save args to json in result_dir/action_space/observation_type/model/args.json + path_to_args = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + args.model, + "args.json", + ) + os.makedirs(os.path.dirname(path_to_args), exist_ok=True) + with open(path_to_args, "w", encoding="utf-8") as f: + json.dump(vars(args), f, indent=4) + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + test_file_list = get_unfinished( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + args.model, + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, test_file_list) + except KeyboardInterrupt: + logger.info("Main process received KeyboardInterrupt.") + # Signal handler will take care of cleanup + except Exception as e: + logger.error(f"Unexpected error in main process: {e}", exc_info=True) + # Also trigger cleanup for unhandled exceptions + signal_handler(signal.SIGTERM, None) + finally: + # Final cleanup in case any environments or processes remain + logger.info("Main process final cleanup...") + for env in active_environments: + if env is not None: + try: + logger.info(f"Closing environment in final cleanup...") + env.close() + logger.info(f"Environment closed successfully in final cleanup") + except Exception as e: + logger.error(f"Error during final environment cleanup: {e}") + + # First try gentle termination + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Terminating process {p.name}...") + p.terminate() + except Exception as e: + logger.error(f"Error terminating process: {e}") + + # Wait a moment for processes to terminate + time.sleep(1) + + # Then force kill if needed + for p in processes: + if p is not None and p.is_alive(): + try: + logger.info(f"Force killing process {p.name}...") + os.kill(p.pid, signal.SIGKILL) + logger.info(f"Process {p.name} force killed") + except Exception as e: + logger.error(f"Error force killing process: {e}") diff --git a/run_multienv_uitars15.py b/run_multienv_uitars15_v2.py similarity index 97% rename from run_multienv_uitars15.py rename to run_multienv_uitars15_v2.py index 5d25fd8..aab3d1a 100644 --- a/run_multienv_uitars15.py +++ b/run_multienv_uitars15_v2.py @@ -8,31 +8,13 @@ import sys import signal import time from typing import List, Dict -import math -from tqdm import tqdm from multiprocessing import Process, Manager from multiprocessing import current_process import lib_run_single from 
desktop_env.desktop_env import DesktopEnv -from mm_agents.uitars15_agent import UITarsAgent - -import shutil +from mm_agents.uitars15_v2 import UITarsAgent import os -# def clear_cache(): -# cache_path = "cache" - -# try: -# if os.path.exists(cache_path): -# logger.info(f"Deleting cache directory: {cache_path}") -# shutil.rmtree(cache_path) -# logger.info(f"Cache directory deleted successfully") -# else: -# logger.info(f"Cache directory {cache_path} does not exist") -# except Exception as e: -# logger.error(f"Error deleting cache directory: {e}") - -# clear_cache() # Global variables for signal handling active_environments = [] @@ -74,12 +56,12 @@ def config() -> argparse.Namespace: # lm config parser.add_argument("--model", type=str, default="doubao-1-5-thinking-vision-pro-250428") + parser.add_argument("--model_type", type=str, default="doubao", choices=["doubao", "qwen25"]) parser.add_argument("--temperature", type=float, default=0) parser.add_argument("--top_p", type=float, default=None) parser.add_argument("--max_tokens", type=int, default=3000) parser.add_argument("--use_thinking", action="store_true", default=False) - # OpenCUAagent config parser.add_argument("--max_trajectory_length", type=int, default=None, help="The max number of trajectory steps.") parser.add_argument("--max_image_history_length", type=int, default=5, help="The max number of images in the history.") parser.add_argument("--language", type=str, default="Chinese", help="Language for the agent.") @@ -204,6 +186,7 @@ def run_env_tasks(task_queue: Queue, args: argparse.Namespace, shared_scores: li active_environments.append(env) agent = UITarsAgent( model=args.model, + model_type=args.model_type, max_tokens=args.max_tokens, top_p=args.top_p, temperature=args.temperature, diff --git a/run_uitars.py b/run_uitars.py deleted file mode 100644 index 3b6ea84..0000000 --- a/run_uitars.py +++ /dev/null @@ -1,362 +0,0 @@ -"""Script to run end-to-end evaluation on the benchmark. -Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. 
-""" - -import argparse -import datetime -import json -import logging -import os -import sys - -from tqdm import tqdm - -import lib_run_single -from desktop_env.desktop_env import DesktopEnv -from mm_agents.uitars_agent import UITARSAgent - -# import wandb - - -# Logger Configs {{{ # -logger = logging.getLogger() -logger.setLevel(logging.DEBUG) - -datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") - -file_handler = logging.FileHandler( - os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" -) -debug_handler = logging.FileHandler( - os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" -) -stdout_handler = logging.StreamHandler(sys.stdout) -sdebug_handler = logging.FileHandler( - os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" -) - -file_handler.setLevel(logging.INFO) -debug_handler.setLevel(logging.DEBUG) -stdout_handler.setLevel(logging.INFO) -sdebug_handler.setLevel(logging.DEBUG) - -formatter = logging.Formatter( - fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" -) -file_handler.setFormatter(formatter) -debug_handler.setFormatter(formatter) -stdout_handler.setFormatter(formatter) -sdebug_handler.setFormatter(formatter) - -stdout_handler.addFilter(logging.Filter("desktopenv")) -sdebug_handler.addFilter(logging.Filter("desktopenv")) - -logger.addHandler(file_handler) -logger.addHandler(debug_handler) -logger.addHandler(stdout_handler) -logger.addHandler(sdebug_handler) -# }}} Logger Configs # - -logger = logging.getLogger("desktopenv.experiment") - - -def config() -> argparse.Namespace: - parser = argparse.ArgumentParser( - description="Run end-to-end evaluation on the benchmark" - ) - - # environment config - parser.add_argument("--path_to_vm", type=str, default=None) - parser.add_argument( - "--headless", action="store_true", help="Run in headless machine" - ) - parser.add_argument( - "--action_space", type=str, default="pyautogui", help="Action type" - ) - parser.add_argument( - "--observation_type", - choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], - default="a11y_tree", - help="Observation type", - ) - parser.add_argument("--screen_width", type=int, default=1920) - parser.add_argument("--screen_height", type=int, default=1080) - parser.add_argument("--sleep_after_execution", type=float, default=0.0) - parser.add_argument("--max_steps", type=int, default=15) - - # agent config - parser.add_argument("--max_trajectory_length", type=int, default=3) - parser.add_argument( - "--test_config_base_dir", type=str, default="evaluation_examples" - ) - - # lm config - parser.add_argument("--model", type=str, default="uitars") - parser.add_argument("--model_type", type=str, default="qwen25vl") - parser.add_argument("--infer_mode", type=str, default="qwen25vl_normal") - parser.add_argument("--prompt_style", type=str, default="qwen25vl_normal") - parser.add_argument("--input_swap", action="store_true", help="Use copy and paste to type content") - parser.add_argument("--language", type=str, default="Chinese") - parser.add_argument("--max_pixels", type=float, default=16384*28*28) - parser.add_argument("--min_pixels", type=float, default=100*28*28) - parser.add_argument("--temperature", type=float, default=1.0) - parser.add_argument("--top_p", type=float, default=0.9) - parser.add_argument("--top_k", type=int, default=-1) - parser.add_argument("--history_n", type=int, default=5) - 
parser.add_argument("--callusr_tolerance", type=int, default=3) - parser.add_argument("--max_tokens", type=int, default=500) - parser.add_argument("--stop_token", type=str, default=None) - - # example config - parser.add_argument("--domain", type=str, default="all") - parser.add_argument( - "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" - ) - - # logging related - parser.add_argument("--result_dir", type=str, default="./results") - args = parser.parse_args() - - return args - - -def test(args: argparse.Namespace, test_all_meta: dict) -> None: - scores = [] - max_steps = args.max_steps - - # log args - logger.info("Args: %s", args) - # set wandb project - cfg_args = { - "path_to_vm": args.path_to_vm, - "headless": args.headless, - "action_space": args.action_space, - "observation_type": args.observation_type, - "screen_width": args.screen_width, - "screen_height": args.screen_height, - "sleep_after_execution": args.sleep_after_execution, - "max_steps": args.max_steps, - "max_trajectory_length": args.max_trajectory_length, - "model": args.model, - "model_type": args.model_type, - "infer_mode": args.infer_mode, - "prompt_style": args.prompt_style, - "input_swap": args.input_swap, - "language": args.language, - "history_n": args.history_n, - "max_pixels": args.max_pixels, - "min_pixels": args.min_pixels, - "callusr_tolerance": args.callusr_tolerance, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "max_tokens": args.max_tokens, - "stop_token": args.stop_token, - "result_dir": args.result_dir, - } - - agent = UITARSAgent( - action_space=args.action_space, - observation_type=args.observation_type, - max_trajectory_length=args.max_trajectory_length, - model_type=args.model_type, - runtime_conf = { - "infer_mode": args.infer_mode, - "prompt_style": args.prompt_style, - "input_swap": args.input_swap, - "language": args.language, - "history_n": args.history_n, - "max_pixels": args.max_pixels, - "min_pixels": args.min_pixels, - "callusr_tolerance": args.callusr_tolerance, - "temperature": args.temperature, - "top_p": args.top_p, - "top_k": args.top_k, - "max_tokens": args.max_tokens - } - ) - - env = DesktopEnv( - path_to_vm=args.path_to_vm, - action_space=agent.action_space, - screen_size=(args.screen_width, args.screen_height), - headless=args.headless, - os_type = "Ubuntu", - require_a11y_tree=args.observation_type - in ["a11y_tree", "screenshot_a11y_tree", "som"], - ) - - for domain in tqdm(test_all_meta, desc="Domain"): - for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): - config_file = os.path.join( - args.test_config_base_dir, f"examples/{domain}/{example_id}.json" - ) - with open(config_file, "r", encoding="utf-8") as f: - example = json.load(f) - - logger.info(f"[Domain]: {domain}") - logger.info(f"[Example ID]: {example_id}") - - instruction = example["instruction"] - - logger.info(f"[Instruction]: {instruction}") - # wandb each example config settings - cfg_args["instruction"] = instruction - cfg_args["start_time"] = datetime.datetime.now().strftime( - "%Y:%m:%d-%H:%M:%S" - ) - # run.config.update(cfg_args) - - example_result_dir = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - domain, - example_id, - ) - os.makedirs(example_result_dir, exist_ok=True) - # example start running - try: - lib_run_single.run_single_example( - agent, - env, - example, - max_steps, - instruction, - args, - example_result_dir, - scores, - ) - except Exception as e: - 
logger.error(f"Exception in {domain}/{example_id}: {e}") - env.controller.end_recording( - os.path.join(example_result_dir, "recording.mp4") - ) - with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: - f.write( - json.dumps( - {"Error": f"Time limit exceeded in {domain}/{example_id}"} - ) - ) - f.write("\n") - - env.close() - logger.info(f"Average score: {sum(scores) / len(scores)}") - - -def get_unfinished( - action_space, use_model, observation_type, result_dir, total_file_json -): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) - - if not os.path.exists(target_dir): - return total_file_json - - finished = {} - for domain in os.listdir(target_dir): - finished[domain] = [] - domain_path = os.path.join(target_dir, domain) - if os.path.isdir(domain_path): - for example_id in os.listdir(domain_path): - if example_id == "onboard": - continue - example_path = os.path.join(domain_path, example_id) - if os.path.isdir(example_path): - if "result.txt" not in os.listdir(example_path): - # empty all files under example_id - for file in os.listdir(example_path): - os.remove(os.path.join(example_path, file)) - else: - finished[domain].append(example_id) - - if not finished: - return total_file_json - - for domain, examples in finished.items(): - if domain in total_file_json: - total_file_json[domain] = [ - x for x in total_file_json[domain] if x not in examples - ] - - return total_file_json - - -def get_result(action_space, use_model, observation_type, result_dir, total_file_json): - target_dir = os.path.join(result_dir, action_space, observation_type, use_model) - if not os.path.exists(target_dir): - print("New experiment, no result yet.") - return None - - all_result = [] - - for domain in os.listdir(target_dir): - domain_path = os.path.join(target_dir, domain) - if os.path.isdir(domain_path): - for example_id in os.listdir(domain_path): - example_path = os.path.join(domain_path, example_id) - if os.path.isdir(example_path): - if "result.txt" in os.listdir(example_path): - # empty all files under example_id - try: - all_result.append( - float( - open( - os.path.join(example_path, "result.txt"), "r" - ).read() - ) - ) - except: - all_result.append(0.0) - - if not all_result: - print("New experiment, no result yet.") - return None - else: - print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") - return all_result - - -if __name__ == "__main__": - ####### The complete version of the list of examples ####### - os.environ["TOKENIZERS_PARALLELISM"] = "false" - args = config() - - # save args to json in result_dir/action_space/observation_type/model/args.json - path_to_args = os.path.join( - args.result_dir, - args.action_space, - args.observation_type, - args.model, - "args.json", - ) - os.makedirs(os.path.dirname(path_to_args), exist_ok=True) - with open(path_to_args, "w", encoding="utf-8") as f: - json.dump(vars(args), f, indent=4) - - with open(args.test_all_meta_path, "r", encoding="utf-8") as f: - test_all_meta = json.load(f) - - if args.domain != "all": - test_all_meta = {args.domain: test_all_meta[args.domain]} - - test_file_list = get_unfinished( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - left_info = "" - for domain in test_file_list: - left_info += f"{domain}: {len(test_file_list[domain])}\n" - logger.info(f"Left tasks:\n{left_info}") - - get_result( - args.action_space, - args.model, - args.observation_type, - args.result_dir, - test_all_meta, - ) - test(args, 
test_file_list)
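Taken together, the patch retires the sequential run_uitars.py in favor of the queue-based multi-environment runners, which share one two-phase shutdown path: terminate() first, a short grace period, then SIGKILL for anything still alive. Distilled as a sketch (illustrative only, and Linux-oriented, since SIGKILL is POSIX-only and these runners target Ubuntu VMs):

    import os
    import signal
    import sys
    import time

    processes = []
    is_terminating = False

    def shutdown(signum, frame):
        global is_terminating
        if is_terminating:              # ignore re-entrant signals
            return
        is_terminating = True
        for p in processes:             # polite first: SIGTERM via terminate()
            if p.is_alive():
                p.terminate()
        time.sleep(1)                   # short grace period for cleanup
        for p in processes:             # then SIGKILL anything still alive
            if p.is_alive():
                os.kill(p.pid, signal.SIGKILL)
        sys.exit(0)

    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)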