From 53fb96298a05fa7924ff4e6f2b93b0ec66b23fa3 Mon Sep 17 00:00:00 2001 From: Dunjie Lu <127488745+ludunjie1219@users.noreply.github.com> Date: Tue, 22 Jul 2025 16:33:03 +0800 Subject: [PATCH] support_qwen25vl (#276) Co-authored-by: root --- mm_agents/qwen25vl_agent.py | 582 +++++++++++++++++++++++++++++++ mm_agents/utils/qwen_vl_utils.py | 271 ++++++++++++++ run_multienv_qwen25vl.py | 362 +++++++++++++++++++ 3 files changed, 1215 insertions(+) create mode 100644 mm_agents/qwen25vl_agent.py create mode 100644 mm_agents/utils/qwen_vl_utils.py create mode 100644 run_multienv_qwen25vl.py diff --git a/mm_agents/qwen25vl_agent.py b/mm_agents/qwen25vl_agent.py new file mode 100644 index 0000000..20d30bc --- /dev/null +++ b/mm_agents/qwen25vl_agent.py @@ -0,0 +1,582 @@ +import base64 +import json +import logging +import time +import os +from io import BytesIO +from typing import Dict, List, Tuple + +import backoff +import openai +from PIL import Image +from requests.exceptions import SSLError +from google.api_core.exceptions import ( + InvalidArgument, + ResourceExhausted, + InternalServerError, + BadRequest, +) +from mm_agents.utils.qwen_vl_utils import smart_resize + + + +logger = None + +MAX_RETRY_TIMES = 5 + +def encode_image(image_content): + return base64.b64encode(image_content).decode("utf-8") + + +def process_image(image_bytes): + """ + Process an image for Qwen VL models. + Resize the image to dimensions expected by the model. + + Args: + image_bytes: Raw image bytes + + Returns: + Base64 encoded image string of the processed image + """ + # Open image from bytes + image = Image.open(BytesIO(image_bytes)) + width, height = image.size + + # Calculate resized dimensions + resized_height, resized_width = smart_resize( + height=height, + width=width + ) + + # Resize the image + image = image.resize((resized_width, resized_height)) + + # Convert to bytes + buffer = BytesIO() + image.save(buffer, format="PNG") + processed_bytes = buffer.getvalue() + + # Return base64 encoded string + return base64.b64encode(processed_bytes).decode('utf-8') + + +class Qwen25VLAgent: + + def __init__( + self, + platform="ubuntu", + planner_model="gpt-4o", + executor_model="qwen2.5vl", + max_tokens=1500, + top_p=0.9, + temperature=0.5, + action_space="pyautogui", + observation_type="screenshot", + history_n=4, # Number of previous interactions to include in full detail + ): + self.platform = platform + self.planner_model = planner_model + self.executor_model = executor_model + assert self.executor_model is not None, "Executor model cannot be None" + self.max_tokens = max_tokens + self.top_p = top_p + self.temperature = temperature + self.action_space = action_space + self.observation_type = observation_type + self.history_n = history_n # Control how many previous interactions to include + assert action_space in ["pyautogui"], "Invalid action space" + assert observation_type in ["screenshot"], "Invalid observation type" + self.thoughts = [] + self.actions = [] + self.observations = [] + self.responses = [] # Store model responses + self.screenshots = [] # Store processed screenshots + + def predict(self, instruction: str, obs: Dict) -> List: + """ + Predict the next action(s) based on the current observation. 
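+
+        Args:
+            instruction: Natural-language task instruction for the current example.
+            obs: Observation dict; obs["screenshot"] holds the raw screenshot bytes.
+
+        Returns:
+            A tuple of (raw model response string, list of pyautogui command strings).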
+ """ + # Process the screenshot image + screenshot_bytes = obs["screenshot"] + + # Display original dimensions + image = Image.open(BytesIO(screenshot_bytes)) + width, height = image.size + print(f"Original screen resolution: {width}x{height}") + + # Process the image + processed_image = process_image(screenshot_bytes) + processed_img = Image.open(BytesIO(base64.b64decode(processed_image))) + processed_width, processed_height = processed_img.size + print(f"Processed image resolution: {processed_width}x{processed_height}") + + # Save the current screenshot to history + self.screenshots.append(processed_image) + + # Calculate history window start index + current_step = len(self.actions) + history_start_idx = max(0, current_step - self.history_n) + + # Build previous actions string - only include actions outside the history window + previous_actions = [] + for i in range(history_start_idx): + if i < len(self.actions): + previous_actions.append(f"Step {i+1}: {self.actions[i]}") + previous_actions_str = "\n".join(previous_actions) if previous_actions else "None" + + # System prompt with tool definition + tools_def = { + "type": "function", + "function": { + "name_for_human": "computer_use", + "name": "computer_use", + "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.", + "parameters": { + "properties": { + "action": { + "description": "The action to perform.", + "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag", + "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"], + "type": "string" + }, + "keys": {"description": "Required only by `action=key`.", "type": "array"}, + "text": {"description": "Required only by `action=type`.", "type": "string"}, + "coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"}, + "pixels": {"description": "The amount of scrolling.", "type": "number"}, + "time": {"description": "The seconds to wait.", "type": "number"}, + "status": { + "description": "The status of the task.", + "type": "string", + "enum": ["success", "failure"] + } + }, + "required": ["action"], + "type": "object" + }, + "args_format": "Format the arguments as a JSON object." + } + } + + system_prompt = """You are a helpful assistant + +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +""" + json.dumps(tools_def) + """ + + +For each function call, return a json object with function name and arguments within XML tags: + +{"name": , "arguments": } +""" + + # Create instruction prompt + instruction_prompt = f""" +Please generate the next move according to the UI screenshot, instruction and previous actions. 
+ +Instruction: {instruction} + +Previous actions: +{previous_actions_str}""" + + # Initialize messages with system prompt + messages = [ + { + "role": "system", + "content": [{ + "type": "text", + "text": system_prompt + }] + } + ] + + # Add history responses and images within the history window + history_len = min(self.history_n, len(self.responses)) + if history_len > 0: + # Only include the most recent history_n steps + history_responses = self.responses[-history_len:] + history_screenshots = self.screenshots[-history_len-1:-1] # Include one more for the previous screenshot + + # Add history in conversation format + for idx in range(history_len): + # Add the screenshot (user message) + if idx < len(history_screenshots): + screenshot_b64 = history_screenshots[idx] + + # If this is the first history item, include instruction prompt + if idx == 0: + messages.append({ + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{screenshot_b64}" + } + }, + { + "type": "text", + "text": instruction_prompt + } + ] + }) + else: + messages.append({ + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{screenshot_b64}" + } + } + ] + }) + + # Add the action and response (assistant message) + + messages.append({ + "role": "assistant", + "content": [ + {"type": "text", "text": history_responses[idx]} + ] + }) + + # Add the current screenshot without instruction (since we already have history) + messages.append({ + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{processed_image}" + } + } + ] + }) + else: + # If no history, just add current screenshot with instruction prompt + messages.append({ + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{processed_image}" + } + }, + { + "type": "text", + "text": instruction_prompt + } + ] + }) + + # append_text = f"""Step {current_step+1}: Thought:""" + append_text = f"""Thought:""" + messages.append({"role": "assistant", "content": [{"type": "text", "text": append_text}]}) + + # Call the LLM + response = self.call_llm( + { + "model": self.executor_model, + "messages": messages, + "max_tokens": self.max_tokens, + "top_p": self.top_p, + "temperature": self.temperature, + }, + self.executor_model, + ) + + logger.info(f"Qwen25VL Output: {response}") + + # Save response to history + self.responses.append(response) + + # Parse response and extract pyautogui code + low_level_instruction, pyautogui_code = self.parse_response( + response, + width, + height, + processed_width, + processed_height + ) + + logger.info(f"Low level instruction: {low_level_instruction}") + logger.info(f"Pyautogui code: {pyautogui_code}") + + # Add the action to history + self.actions.append(low_level_instruction) + + return response, pyautogui_code + + def parse_response(self, response: str, original_width: int = None, original_height: int = None, + processed_width: int = None, processed_height: int = None) -> Tuple[str, List[str]]: + """ + Parse LLM response and convert it to low level action and pyautogui code. 
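+
+        An illustrative response in the format this parser accepts (the JSON tool
+        call may also appear between the tool-call markers requested in the
+        system prompt):
+
+            Thought: The Files icon is visible in the dock.
+            Action: Click the Files icon.
+            {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [652, 384]}}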
+ + Args: + response: Raw response string from the model + original_width: Width of the original screenshot + original_height: Height of the original screenshot + processed_width: Width of the processed image + processed_height: Height of the processed image + + Returns: + Tuple of (low_level_instruction, list of pyautogui_commands) + """ + low_level_instruction = "" + pyautogui_code = [] + + if response is None or not response.strip(): + return low_level_instruction, pyautogui_code + + # Define function to adjust coordinates based on original and processed dimensions + def adjust_coordinates(x: float, y: float) -> Tuple[int, int]: + """ + Adjust coordinates from processed image dimensions to original image dimensions. + """ + if all([original_width, original_height, processed_width, processed_height]): + # Calculate the scale factors between original and processed images + x_scale = original_width / processed_width + y_scale = original_height / processed_height + + # Apply scaling to get coordinates in original image space + adjusted_x = int(x * x_scale) + adjusted_y = int(y * y_scale) + + return adjusted_x, adjusted_y + else: + # If any dimension is missing, return the original coordinates + return int(x), int(y) + + # Define inner function to process tool calls + def process_tool_call(json_str: str) -> None: + """Process a single tool call JSON string.""" + try: + # Parse the JSON + tool_call = json.loads(json_str) + if tool_call.get("name") == "computer_use": + # Convert computer_use actions to pyautogui commands + args = tool_call["arguments"] + action = args["action"] + + if action == "left_click": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})") + else: + pyautogui_code.append("pyautogui.click()") + + elif action == "right_click": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + pyautogui_code.append(f"pyautogui.rightClick({adj_x}, {adj_y})") + else: + pyautogui_code.append("pyautogui.rightClick()") + + elif action == "middle_click": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + pyautogui_code.append(f"pyautogui.middleClick({adj_x}, {adj_y})") + else: + pyautogui_code.append("pyautogui.middleClick()") + + elif action == "double_click": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + pyautogui_code.append(f"pyautogui.doubleClick({adj_x}, {adj_y})") + else: + pyautogui_code.append("pyautogui.doubleClick()") + + elif action == "type": + text = args.get("text", "") + pyautogui_code.append(f"pyautogui.typewrite('{text}')") + + elif action == "key": + keys = args.get("keys", []) + # Fix possible formatting issues in the keys parameter + if isinstance(keys, list): + # Clean up any formatting issues in the keys + cleaned_keys = [] + for key in keys: + # Check if the key has the "keys=[" prefix or "]" suffix + if isinstance(key, str): + # Remove "keys=[" prefix if present + if key.startswith("keys=["): + key = key[6:] + # Remove "]" suffix if present + if key.endswith("]"): + key = key[:-1] + # Handle case where string contains representation of list items + if key.startswith("['") or key.startswith("[\""): + key = key[2:] if len(key) > 2 else key + if key.endswith("']") or key.endswith("\"]"): + key = key[:-2] if len(key) > 2 else key + # Strip any extra whitespace + key = key.strip() + # Add to cleaned keys + 
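+                                    # e.g. a malformed entry such as "keys=[ctrl]" has been reduced to "ctrl" by this point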
cleaned_keys.append(key) + else: + cleaned_keys.append(key) + keys = cleaned_keys + + # Format the keys for hotkey or press command + keys_str = ", ".join([f"'{key}'" for key in keys]) + if len(keys) > 1: + pyautogui_code.append(f"pyautogui.hotkey({keys_str})") + else: + pyautogui_code.append(f"pyautogui.press({keys_str})") + + elif action == "scroll": + pixels = args.get("pixels", 0) + pyautogui_code.append(f"pyautogui.scroll({pixels})") + + elif action == "wait": + pyautogui_code.append("WAIT") # Special code for wait action + + elif action == "terminate": + pyautogui_code.append("DONE") # Special code for termination + + elif action == "mouse_move": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + pyautogui_code.append(f"pyautogui.moveTo({adj_x}, {adj_y})") + else: + pyautogui_code.append("pyautogui.moveTo(0, 0)") + + elif action == "left_click_drag": + if "coordinate" in args: + x, y = args["coordinate"] + adj_x, adj_y = adjust_coordinates(x, y) + duration = args.get("duration", 0.5) + pyautogui_code.append(f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})") + else: + pyautogui_code.append("pyautogui.dragTo(0, 0)") + except (json.JSONDecodeError, KeyError) as e: + logger.error(f"Failed to parse tool call: {e}") + + # Parse the response line by line + lines = response.split('\n') + inside_tool_call = False + current_tool_call = [] + + for line in lines: + line = line.strip() + if not line: + continue + + # Extract low-level instruction from lines starting with "Action:" or similar + if line.lower().startswith(("action:", "step", "i will", "i'll", "now i")): + if not low_level_instruction: + # Only store the first action description as low level instruction + low_level_instruction = line + continue + + # Handle lines inside tool call markers + if line.startswith(""): + inside_tool_call = True + continue + elif line.startswith(""): + if current_tool_call: + # Process the collected tool call + process_tool_call("\n".join(current_tool_call)) + current_tool_call = [] + inside_tool_call = False + continue + + if inside_tool_call: + current_tool_call.append(line) + continue + + # Try to parse individual lines as JSON + if line.startswith("{") and line.endswith("}"): + try: + json_obj = json.loads(line) + if "name" in json_obj and "arguments" in json_obj: + process_tool_call(line) + except json.JSONDecodeError: + pass + + # Process any remaining tool call content + if current_tool_call: + process_tool_call("\n".join(current_tool_call)) + + # If we still don't have a low-level instruction, generate a default one + if not low_level_instruction and len(pyautogui_code) > 0: + action_type = pyautogui_code[0].split(".", 1)[1].split("(", 1)[0] + low_level_instruction = f"Performing {action_type} action" + + return low_level_instruction, pyautogui_code + + @backoff.on_exception( + backoff.constant, + # here you should add more model exceptions as you want, + # but you are forbidden to add "Exception", that is, a common type of exception + # because we want to catch this kind of Exception in the outside to ensure + # each example won't exceed the time limit + ( + # General exceptions + SSLError, + # OpenAI exceptions + openai.RateLimitError, + openai.BadRequestError, + openai.InternalServerError, + # Google exceptions + InvalidArgument, + ResourceExhausted, + InternalServerError, + BadRequest, + # Groq exceptions + # todo: check + ), + interval=30, + max_tries=10, + ) + def call_llm(self, payload, model): + messages = payload["messages"] + 
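+        # Placeholder endpoint settings: replace these with the OpenAI-compatible
+        # endpoint that serves the Qwen2.5-VL model (e.g. a vLLM or DashScope
+        # deployment), or read them from environment variables instead of hard-coding.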
+        base_url = "your_base_url"
+        api_key = "your_api_key"
+
+        client = openai.OpenAI(
+            base_url=base_url,
+            api_key=api_key
+        )
+
+        for _ in range(MAX_RETRY_TIMES):
+            logger.info("Generating content with Qwen model: %s", model)
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    max_tokens=self.max_tokens,
+                    temperature=self.temperature,
+                    top_p=self.top_p
+                )
+                return response.choices[0].message.content
+            except Exception as e:
+                logger.error(f"Error calling Qwen model: {e}")
+                time.sleep(5)
+                continue
+        return ""
+
+    def reset(self, _logger=None):
+        global logger
+        logger = (_logger if _logger is not None else
+                  logging.getLogger("desktopenv.qwen25vl_agent"))
+
+        self.thoughts = []
+        self.action_descriptions = []
+        self.actions = []
+        self.observations = []
+        self.responses = []  # Reset responses
+        self.screenshots = []  # Reset screenshots
diff --git a/mm_agents/utils/qwen_vl_utils.py b/mm_agents/utils/qwen_vl_utils.py
new file mode 100644
index 0000000..f39088e
--- /dev/null
+++ b/mm_agents/utils/qwen_vl_utils.py
@@ -0,0 +1,271 @@
+import math
+
+
+def round_by_factor(number: int, factor: int) -> int:
+    """Return the integer closest to `number` that is divisible by `factor`."""
+    return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+    """Return the smallest integer greater than or equal to `number` that is divisible by `factor`."""
+    return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+    """Return the largest integer less than or equal to `number` that is divisible by `factor`."""
+    return math.floor(number / factor) * factor
+
+
+def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280, max_long_side=8192):
+    """Rescale the image so that it satisfies the following conditions:
+    1. Both height and width are divisible by `factor`.
+    2. The total number of pixels lies within [min_pixels, max_pixels].
+    3. The longest side does not exceed `max_long_side`.
+    4. The aspect ratio is preserved as closely as possible.
+    """
+    if height < 2 or width < 2:
+        raise ValueError(f"height:{height} or width:{width} must be at least 2")
+    elif max(height, width) / min(height, width) > 200:
+        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {height} / {width}")
+
+    if max(height, width) > max_long_side:
+        beta = max(height, width) / max_long_side
+        height, width = int(height / beta), int(width / beta)
+
+    h_bar = round_by_factor(height, factor)
+    w_bar = round_by_factor(width, factor)
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
+
+
+def update_image_size_(image_ele: dict, min_tokens=1, max_tokens=12800, merge_base=2, patch_size=14):
+    """Update the size information of `image_ele` according to `min_tokens` and `max_tokens`.
+
+    Args:
+        image_ele (dict):
+            - image_ele["image"]: str, path to the image
+            - image_ele["height"]: int, original image height
+            - image_ele["width"]: int, original image width
+
+    Returns:
+        The updated image_ele, with the following key-value pairs added:
+        dict:
+            - image_ele["resized_height"]: int, actual height fed to the model
+            - image_ele["resized_width"]: int, actual width fed to the model
+            - image_ele["seq_len"]: int, sequence length the image occupies in the model input
+    """
+    height, width = image_ele["height"], image_ele["width"]
+    pixels_per_token = patch_size * patch_size * merge_base * merge_base
+    resized_height, resized_width = smart_resize(
+        height,
+        width,
+        factor=merge_base * patch_size,
+        min_pixels=pixels_per_token * min_tokens,
+        max_pixels=pixels_per_token * max_tokens,
+        max_long_side=50000,
+    )
+    image_ele.update(
+        {
+            "resized_height":
resized_height, + "resized_width": resized_width, + "seq_len": resized_height * resized_width // pixels_per_token + 2, + } + ) + return image_ele + + +def _convert_bbox_format_from_abs_origin(bbox, image_ele: dict, *, tgt_format: str): + x1, y1, x2, y2 = bbox + if tgt_format == "abs_origin": + new_bbox = [int(x1), int(y1), int(x2), int(y2)] + elif tgt_format == "abs_resized": + new_bbox = [ + int(x1 / image_ele["width"] * image_ele["resized_width"]), + int(y1 / image_ele["height"] * image_ele["resized_height"]), + int(x2 / image_ele["width"] * image_ele["resized_width"]), + int(y2 / image_ele["height"] * image_ele["resized_height"]), + ] + elif tgt_format == "qwen-vl": + new_bbox = [ + int(x1 / image_ele["width"] * 999), + int(y1 / image_ele["height"] * 999), + int(x2 / image_ele["width"] * 999), + int(y2 / image_ele["height"] * 999), + ] + elif tgt_format == "rel": + new_bbox = [ + float(x1 / image_ele["width"]), + float(y1 / image_ele["height"]), + float(x2 / image_ele["width"]), + float(y2 / image_ele["height"]), + ] + elif tgt_format == "molmo": + new_bbox = [ + round(x1 / image_ele["width"] * 100, ndigits=1), + round(y1 / image_ele["height"] * 100, ndigits=1), + round(x2 / image_ele["width"] * 100, ndigits=1), + round(y2 / image_ele["height"] * 100, ndigits=1), + ] + else: + assert False, f"Unknown tgt_format: {tgt_format}" + return new_bbox + + +def _convert_bbox_format_to_abs_origin(bbox, image_ele: dict, *, src_format: str): + x1, y1, x2, y2 = bbox + if src_format == "abs_origin": + new_bbox = [int(x1), int(y1), int(x2), int(y2)] + elif src_format == "abs_resized": + new_bbox = [ + int(x1 / image_ele["resized_width"] * image_ele["width"]), + int(y1 / image_ele["resized_height"] * image_ele["height"]), + int(x2 / image_ele["resized_width"] * image_ele["width"]), + int(y2 / image_ele["resized_height"] * image_ele["height"]), + ] + elif src_format == "qwen-vl": + new_bbox = [ + int(x1 / 999 * image_ele["width"]), + int(y1 / 999 * image_ele["height"]), + int(x2 / 999 * image_ele["width"]), + int(y2 / 999 * image_ele["height"]), + ] + elif src_format == "rel": + new_bbox = [ + int(x1 * image_ele["width"]), + int(y1 * image_ele["height"]), + int(x2 * image_ele["width"]), + int(y2 * image_ele["height"]), + ] + elif src_format == "molmo": + new_bbox = [ + int(x1 / 100 * image_ele["width"]), + int(y1 / 100 * image_ele["height"]), + int(x2 / 100 * image_ele["width"]), + int(y2 / 100 * image_ele["height"]), + ] + else: + assert False, f"Unknown src_format: {src_format}" + return new_bbox + + +def convert_bbox_format(bbox, image_ele: dict, *, src_format: str, tgt_format: str): + bbox_abs_origin = _convert_bbox_format_to_abs_origin(bbox, image_ele, src_format=src_format) + bbox_tgt_format = _convert_bbox_format_from_abs_origin(bbox_abs_origin, image_ele, tgt_format=tgt_format) + return bbox_tgt_format + + +def _convert_point_format_from_abs_origin(point, image_ele: dict, *, tgt_format: str): + x, y = point + if tgt_format == "abs_origin": + new_point = [int(x), int(y)] + elif tgt_format == "abs_resized": + new_point = [ + int(x / image_ele["width"] * image_ele["resized_width"]), + int(y / image_ele["height"] * image_ele["resized_height"]), + ] + elif tgt_format == "qwen-vl": + new_point = [ + int(x / image_ele["width"] * 999), + int(y / image_ele["height"] * 999), + ] + elif tgt_format == "rel": + new_point = [ + float(x / image_ele["width"]), + float(y / image_ele["height"]), + ] + elif tgt_format == "molmo": + new_point = [ + round(x / image_ele["width"] * 100, ndigits=1), + round(y / 
image_ele["height"] * 100, ndigits=1), + ] + else: + assert False, f"Unknown tgt_format: {tgt_format}" + return new_point + + +def _convert_point_format_to_abs_origin(point, image_ele: dict, *, src_format: str): + x, y = point + if src_format == "abs_origin": + new_point = [int(x), int(y)] + elif src_format == "abs_resized": + new_point = [ + int(x / image_ele["resized_width"] * image_ele["width"]), + int(y / image_ele["resized_height"] * image_ele["height"]), + ] + elif src_format == "qwen-vl": + new_point = [ + int(x / 999 * image_ele["width"]), + int(y / 999 * image_ele["height"]), + ] + elif src_format == "rel": + new_point = [ + int(x * image_ele["width"]), + int(y * image_ele["height"]), + ] + elif src_format == "molmo": + new_point = [ + int(x / 100 * image_ele["width"]), + int(y / 100 * image_ele["height"]), + ] + else: + assert False, f"Unknown src_format: {src_format}" + return new_point + + +def convert_point_format(point, image_ele: dict, *, src_format: str, tgt_format: str): + point_abs_origin = _convert_point_format_to_abs_origin(point, image_ele, src_format=src_format) + point_tgt_format = _convert_point_format_from_abs_origin(point_abs_origin, image_ele, tgt_format=tgt_format) + return point_tgt_format + + +__all__ = [ + "update_image_size_", + "convert_bbox_format", + "convert_point_format", +] + + +if __name__ == "__main__": + from PIL import Image + + def draw_point(image: Image.Image, point: list): + from copy import deepcopy + + from PIL import ImageDraw + + image = deepcopy(image) + image_draw = ImageDraw.Draw(image) + image_draw.ellipse([point[0] - 5, point[1] - 5, point[0] + 5, point[1] + 5], fill="red") + return image + + # image_ele = { + # "image": "http://ofasys-multimodal-wlcb-3.oss-cn-wulanchabu.aliyuncs.com/data/datacomp1b/image/19774238/7218d7ceb39e82e0cafc389f326e218da623a8f2.jpg", + # "height": 444, + # "width": 592, + # } + image_ele = { + "image": "46d5402b2c183f996f2a13cd2016af15.png", + "height": 1080, + "width": 1920, + } + point = [0.8379917184, 0.2087912088] # rel, keyboard 'k' in the image + + # image: Image.Image = Image.open(requests.get(image_ele["image"], stream=True).raw) + image: Image.Image = Image.open(image_ele["image"]) + assert image.width == image_ele["width"] and image.height == image_ele["height"], f"{image.size=}, {image_ele=}" + resized_image = image.resize((image_ele["resized_width"], image_ele["resized_height"])) + draw_point(image, [point[0] * image.width, point[1] * image.height]).save("image_1.png") + + image_ele = update_image_size_(image_ele) + point = convert_point_format(point, image_ele, src_format="rel", tgt_format="abs_resized") + print(f"{image_ele=}\n{point=}") + + + draw_point(resized_image, point).save("image_2.png") \ No newline at end of file diff --git a/run_multienv_qwen25vl.py b/run_multienv_qwen25vl.py new file mode 100644 index 0000000..5db9dc8 --- /dev/null +++ b/run_multienv_qwen25vl.py @@ -0,0 +1,362 @@ +"""Script to run end-to-end evaluation on the benchmark. +Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py. 
+""" + +import argparse +import datetime +import json +import logging +import os +import sys +from typing import List, Dict +import math +from tqdm import tqdm +from multiprocessing import Process, Manager +import lib_run_single +from desktop_env.desktop_env import DesktopEnv +from mm_agents.qwen25vl_agent import Qwen25VLAgent + +# import wandb + + +# Logger Configs {{{ # +logger = logging.getLogger() +logger.setLevel(logging.DEBUG) + +datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S") + +file_handler = logging.FileHandler( + os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8" +) +debug_handler = logging.FileHandler( + os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8" +) +stdout_handler = logging.StreamHandler(sys.stdout) +sdebug_handler = logging.FileHandler( + os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8" +) + +file_handler.setLevel(logging.INFO) +debug_handler.setLevel(logging.DEBUG) +stdout_handler.setLevel(logging.INFO) +sdebug_handler.setLevel(logging.DEBUG) + +formatter = logging.Formatter( + fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s" +) +file_handler.setFormatter(formatter) +debug_handler.setFormatter(formatter) +stdout_handler.setFormatter(formatter) +sdebug_handler.setFormatter(formatter) + +stdout_handler.addFilter(logging.Filter("desktopenv")) +sdebug_handler.addFilter(logging.Filter("desktopenv")) + +logger.addHandler(file_handler) +logger.addHandler(debug_handler) +logger.addHandler(stdout_handler) +logger.addHandler(sdebug_handler) +# }}} Logger Configs # + +logger = logging.getLogger("desktopenv.experiment") + + +def config() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run end-to-end evaluation on the benchmark" + ) + + # environment config + parser.add_argument("--path_to_vm", type=str, default=None) + parser.add_argument( + "--headless", action="store_true", help="Run in headless machine" + ) + parser.add_argument( + "--action_space", type=str, default="pyautogui", help="Action type" + ) + parser.add_argument( + "--observation_type", + choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"], + default="screenshot", + help="Observation type", + ) + parser.add_argument("--screen_width", type=int, default=1920) + parser.add_argument("--screen_height", type=int, default=1080) + parser.add_argument("--sleep_after_execution", type=float, default=2.0) + parser.add_argument("--max_steps", type=int, default=20) + + # agent config + parser.add_argument( + "--test_config_base_dir", type=str, default="evaluation_examples" + ) + + # lm config + parser.add_argument("--planner_model", type=str, default=None) + parser.add_argument("--executor_model", type=str, default="aguvis-s1-s2-agentnet0105-mo5") + parser.add_argument("--temperature", type=float, default=0) + parser.add_argument("--top_p", type=float, default=0.9) + parser.add_argument("--max_tokens", type=int, default=1500) + parser.add_argument("--stop_token", type=str, default=None) + + # example config + parser.add_argument("--domain", type=str, default="all") + parser.add_argument( + "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json" + ) + + # logging related + parser.add_argument("--result_dir", type=str, default="./results") + parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel") + + args = parser.parse_args() + return args + 
+ +def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]: + """Distribute tasks evenly across environments.""" + # Flatten the tasks into a single list + all_tasks = [] + for domain, examples in test_all_meta.items(): + for example_id in examples: + all_tasks.append((domain, example_id)) + + # Calculate tasks per environment + tasks_per_env = math.ceil(len(all_tasks) / num_envs) + + # Distribute tasks + distributed_tasks = [] + for i in range(num_envs): + env_tasks = {} + start_idx = i * tasks_per_env + end_idx = min((i + 1) * tasks_per_env, len(all_tasks)) + + for domain, example_id in all_tasks[start_idx:end_idx]: + if domain not in env_tasks: + env_tasks[domain] = [] + env_tasks[domain].append(example_id) + + distributed_tasks.append(env_tasks) + + return distributed_tasks + + + +def run_env_tasks(env_idx: int, env: DesktopEnv, agent, env_tasks: dict, args: argparse.Namespace, shared_scores: list): + """Run tasks for a single environment.""" + logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}") + + for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"): + for example_id in tqdm(env_tasks[domain], desc="Example", leave=False): + config_file = os.path.join( + args.test_config_base_dir, f"examples/{domain}/{example_id}.json" + ) + with open(config_file, "r", encoding="utf-8") as f: + example = json.load(f) + + logger.info(f"[Env {env_idx+1}][Domain]: {domain}") + logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}") + logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}") + + example_result_dir = os.path.join( + args.result_dir, + args.action_space, + args.observation_type, + "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model), + domain, + example_id, + ) + os.makedirs(example_result_dir, exist_ok=True) + + try: + lib_run_single.run_single_example( + agent, + env, + example, + args.max_steps, + example["instruction"], + args, + example_result_dir, + shared_scores, + ) + except Exception as e: + logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}") + env.controller.end_recording( + os.path.join(example_result_dir, "recording.mp4") + ) + with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: + f.write( + json.dumps( + {"Error": f"Time limit exceeded in {domain}/{example_id}"} + ) + ) + f.write("\n") + + env.close() + + +def test(args: argparse.Namespace, test_all_meta: dict) -> None: + logger.info("Args: %s", args) + + distributed_tasks = distribute_tasks(test_all_meta, args.num_envs) + + # First, set up all environments + logger.info("Setting up all environments...") + envs = [] + agents = [] + + for env_idx in range(args.num_envs): + logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}") + + agent = Qwen25VLAgent( + planner_model=args.planner_model, + executor_model=args.executor_model, + max_tokens=args.max_tokens, + top_p=args.top_p, + temperature=args.temperature, + action_space=args.action_space, + ) + agents.append(agent) + + env = DesktopEnv( + path_to_vm=args.path_to_vm, + action_space=agent.action_space, + screen_size=(args.screen_width, args.screen_height), + headless=args.headless, + os_type="Ubuntu", + require_a11y_tree=args.observation_type + in ["a11y_tree", "screenshot_a11y_tree", "som"], + provider_name="docker" + ) + envs.append(env) + + logger.info("All environments are ready. 
Starting parallel task execution...") + + # Create a shared list for scores across processes + with Manager() as manager: + shared_scores = manager.list() + + # Create and start processes for each environment + processes = [] + for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)): + p = Process( + target=run_env_tasks, + args=(env_idx, env, agent, env_tasks, args, shared_scores) + ) + processes.append(p) + p.start() + + # Wait for all processes to complete + for p in processes: + p.join() + + # Convert shared list to regular list + scores = list(shared_scores) + + logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}") + + +def get_unfinished( + action_space, use_model, observation_type, result_dir, total_file_json +): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + + if not os.path.exists(target_dir): + return total_file_json + + finished = {} + for domain in os.listdir(target_dir): + finished[domain] = [] + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + if example_id == "onboard": + continue + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" not in os.listdir(example_path): + # empty all files under example_id + for file in os.listdir(example_path): + os.remove(os.path.join(example_path, file)) + else: + finished[domain].append(example_id) + + if not finished: + return total_file_json + + for domain, examples in finished.items(): + if domain in total_file_json: + total_file_json[domain] = [ + x for x in total_file_json[domain] if x not in examples + ] + + return total_file_json + + +def get_result(action_space, use_model, observation_type, result_dir, total_file_json): + target_dir = os.path.join(result_dir, action_space, observation_type, use_model) + if not os.path.exists(target_dir): + print("New experiment, no result yet.") + return None + + all_result = [] + + for domain in os.listdir(target_dir): + domain_path = os.path.join(target_dir, domain) + if os.path.isdir(domain_path): + for example_id in os.listdir(domain_path): + example_path = os.path.join(domain_path, example_id) + if os.path.isdir(example_path): + if "result.txt" in os.listdir(example_path): + # empty all files under example_id + try: + all_result.append( + float( + open( + os.path.join(example_path, "result.txt"), "r" + ).read() + ) + ) + except: + all_result.append(0.0) + + if not all_result: + print("New experiment, no result yet.") + return None + else: + print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%") + return all_result + + +if __name__ == "__main__": + ####### The complete version of the list of examples ####### + os.environ["TOKENIZERS_PARALLELISM"] = "false" + args = config() + + with open(args.test_all_meta_path, "r", encoding="utf-8") as f: + test_all_meta = json.load(f) + + if args.domain != "all": + test_all_meta = {args.domain: test_all_meta[args.domain]} + + exp_name = "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model) + + test_file_list = get_unfinished( + args.action_space, + exp_name, + args.observation_type, + args.result_dir, + test_all_meta, + ) + left_info = "" + for domain in test_file_list: + left_info += f"{domain}: {len(test_file_list[domain])}\n" + logger.info(f"Left tasks:\n{left_info}") + + get_result( + args.action_space, + exp_name, + args.observation_type, + args.result_dir, + test_all_meta, + ) + test(args, 
test_file_list)
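
A minimal sketch (not part of the patch) of the coordinate mapping the agent performs: a click predicted on the smart_resize'd screenshot is scaled back to the original screen, mirroring the adjust_coordinates helper inside Qwen25VLAgent.parse_response. The 1920x1080 resolution matches the run script's defaults, the predicted point is invented for illustration, and the snippet assumes it is run from the repository root so that mm_agents is importable.

from mm_agents.utils.qwen_vl_utils import smart_resize

# Original screen resolution (the run script's --screen_width/--screen_height defaults).
orig_width, orig_height = 1920, 1080

# Dimensions the screenshot is resized to before being sent to the model
# (about 1316 x 728, width x height, under smart_resize's default pixel budget).
resized_height, resized_width = smart_resize(height=orig_height, width=orig_width)

# A coordinate predicted by the model in the resized image space ...
pred_x, pred_y = 658, 364

# ... is scaled back to original-screen space before the pyautogui command is emitted.
adj_x = int(pred_x * (orig_width / resized_width))
adj_y = int(pred_y * (orig_height / resized_height))
print(f"{resized_width}x{resized_height} -> click at ({adj_x}, {adj_y}) on {orig_width}x{orig_height}")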