mm_agents/qwen25vl_agent.py (new file, 582 lines)
@@ -0,0 +1,582 @@
import base64
import json
import logging
import time
import os
from io import BytesIO
from typing import Dict, List, Tuple

import backoff
import openai
from PIL import Image
from requests.exceptions import SSLError
from google.api_core.exceptions import (
    InvalidArgument,
    ResourceExhausted,
    InternalServerError,
    BadRequest,
)

from mm_agents.utils.qwen_vl_utils import smart_resize


logger = None

MAX_RETRY_TIMES = 5


def encode_image(image_content):
    return base64.b64encode(image_content).decode("utf-8")


def process_image(image_bytes):
    """
    Process an image for Qwen VL models.
    Resize the image to the dimensions expected by the model.

    Args:
        image_bytes: Raw image bytes

    Returns:
        Base64-encoded string of the processed image
    """
    # Open image from bytes
    image = Image.open(BytesIO(image_bytes))
    width, height = image.size

    # Calculate resized dimensions
    resized_height, resized_width = smart_resize(
        height=height,
        width=width,
    )

    # Resize the image
    image = image.resize((resized_width, resized_height))

    # Convert to bytes
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    processed_bytes = buffer.getvalue()

    # Return base64 encoded string
    return base64.b64encode(processed_bytes).decode("utf-8")
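
# Usage sketch (illustrative only; the file name "screenshot.png" is hypothetical):
#
#     with open("screenshot.png", "rb") as f:
#         b64 = process_image(f.read())
#     data_url = f"data:image/png;base64,{b64}"  # ready to embed in an image_url message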


class Qwen25VLAgent:

    def __init__(
        self,
        platform="ubuntu",
        planner_model="gpt-4o",
        executor_model="qwen2.5vl",
        max_tokens=1500,
        top_p=0.9,
        temperature=0.5,
        action_space="pyautogui",
        observation_type="screenshot",
        history_n=4,  # Number of previous interactions to include in full detail
    ):
        self.platform = platform
        self.planner_model = planner_model
        self.executor_model = executor_model
        assert self.executor_model is not None, "Executor model cannot be None"
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.history_n = history_n  # Controls how many previous interactions to include
        assert action_space in ["pyautogui"], "Invalid action space"
        assert observation_type in ["screenshot"], "Invalid observation type"
        self.thoughts = []
        self.actions = []
        self.action_descriptions = []  # Also cleared in reset(); initialized here for consistency
        self.observations = []
        self.responses = []  # Store model responses
        self.screenshots = []  # Store processed screenshots

    def predict(self, instruction: str, obs: Dict) -> Tuple[str, List[str]]:
        """
        Predict the next action(s) based on the current observation.
        """
        # Process the screenshot image
        screenshot_bytes = obs["screenshot"]

        # Display original dimensions
        image = Image.open(BytesIO(screenshot_bytes))
        width, height = image.size
        print(f"Original screen resolution: {width}x{height}")

        # Process the image
        processed_image = process_image(screenshot_bytes)
        processed_img = Image.open(BytesIO(base64.b64decode(processed_image)))
        processed_width, processed_height = processed_img.size
        print(f"Processed image resolution: {processed_width}x{processed_height}")

        # Save the current screenshot to history
        self.screenshots.append(processed_image)

        # Calculate history window start index
        current_step = len(self.actions)
        history_start_idx = max(0, current_step - self.history_n)

        # Build previous-actions string: only actions that fall outside the history window
        previous_actions = []
        for i in range(history_start_idx):
            if i < len(self.actions):
                previous_actions.append(f"Step {i+1}: {self.actions[i]}")
        previous_actions_str = "\n".join(previous_actions) if previous_actions else "None"

        # System prompt with tool definition
        tools_def = {
            "type": "function",
            "function": {
                "name_for_human": "computer_use",
                "name": "computer_use",
                "description": "Use a mouse and keyboard to interact with a computer, and take screenshots.",
                "parameters": {
                    "properties": {
                        "action": {
                            "description": "The action to perform.",
                            "enum": ["key", "type", "mouse_move", "left_click", "left_click_drag",
                                     "right_click", "middle_click", "double_click", "scroll", "wait", "terminate"],
                            "type": "string"
                        },
                        "keys": {"description": "Required only by `action=key`.", "type": "array"},
                        "text": {"description": "Required only by `action=type`.", "type": "string"},
                        "coordinate": {"description": "The x,y coordinates for mouse actions.", "type": "array"},
                        "pixels": {"description": "The amount of scrolling.", "type": "number"},
                        "time": {"description": "The seconds to wait.", "type": "number"},
                        "status": {
                            "description": "The status of the task.",
                            "type": "string",
                            "enum": ["success", "failure"]
                        }
                    },
                    "required": ["action"],
                    "type": "object"
                },
                "args_format": "Format the arguments as a JSON object."
            }
        }

        system_prompt = """You are a helpful assistant

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
""" + json.dumps(tools_def) + """
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""
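
        # For reference, a well-formed model reply under this prompt looks like
        # (illustrative example, including hypothetical coordinates):
        #
        #   Action: Click the Firefox icon in the dock.
        #   <tool_call>
        #   {"name": "computer_use", "arguments": {"action": "left_click", "coordinate": [120, 707]}}
        #   </tool_call>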

        # Create instruction prompt
        instruction_prompt = f"""
Please generate the next move according to the UI screenshot, instruction and previous actions.

Instruction: {instruction}

Previous actions:
{previous_actions_str}"""

        # Initialize messages with system prompt
        messages = [
            {
                "role": "system",
                "content": [{
                    "type": "text",
                    "text": system_prompt
                }]
            }
        ]

        # Add history responses and images within the history window
        history_len = min(self.history_n, len(self.responses))
        if history_len > 0:
            # Only include the most recent history_n steps
            history_responses = self.responses[-history_len:]
            # Screenshots preceding the current one, aligned with history_responses
            history_screenshots = self.screenshots[-history_len - 1:-1]

            # Add history in conversation format
            for idx in range(history_len):
                # Add the screenshot (user message)
                if idx < len(history_screenshots):
                    screenshot_b64 = history_screenshots[idx]

                    # If this is the first history item, include the instruction prompt
                    if idx == 0:
                        messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/png;base64,{screenshot_b64}"
                                    }
                                },
                                {
                                    "type": "text",
                                    "text": instruction_prompt
                                }
                            ]
                        })
                    else:
                        messages.append({
                            "role": "user",
                            "content": [
                                {
                                    "type": "image_url",
                                    "image_url": {
                                        "url": f"data:image/png;base64,{screenshot_b64}"
                                    }
                                }
                            ]
                        })

                # Add the model's response at that step (assistant message)
                messages.append({
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": history_responses[idx]}
                    ]
                })

            # Add the current screenshot without the instruction (it is already in the history)
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{processed_image}"
                        }
                    }
                ]
            })
        else:
            # If no history, just add the current screenshot with the instruction prompt
            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{processed_image}"
                        }
                    },
                    {
                        "type": "text",
                        "text": instruction_prompt
                    }
                ]
            })

        # append_text = f"""Step {current_step+1}: Thought:"""
        append_text = "Thought:"
        messages.append({"role": "assistant", "content": [{"type": "text", "text": append_text}]})
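
        # The resulting message list has this shape (sketch):
        #   [system]
        #   [user: screenshot + instruction] [assistant: response] ... up to history_n turns ...
        #   [user: current screenshot]
        #   [assistant: "Thought:"]  <- prefix for the model to continue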

        # Call the LLM
        response = self.call_llm(
            {
                "model": self.executor_model,
                "messages": messages,
                "max_tokens": self.max_tokens,
                "top_p": self.top_p,
                "temperature": self.temperature,
            },
            self.executor_model,
        )

        logger.info(f"Qwen25VL Output: {response}")

        # Save response to history
        self.responses.append(response)

        # Parse response and extract pyautogui code
        low_level_instruction, pyautogui_code = self.parse_response(
            response,
            width,
            height,
            processed_width,
            processed_height,
        )

        logger.info(f"Low level instruction: {low_level_instruction}")
        logger.info(f"Pyautogui code: {pyautogui_code}")

        # Add the action to history
        self.actions.append(low_level_instruction)

        return response, pyautogui_code

    def parse_response(self, response: str, original_width: int = None, original_height: int = None,
                       processed_width: int = None, processed_height: int = None) -> Tuple[str, List[str]]:
        """
        Parse the LLM response and convert it into a low-level action and pyautogui code.

        Args:
            response: Raw response string from the model
            original_width: Width of the original screenshot
            original_height: Height of the original screenshot
            processed_width: Width of the processed image
            processed_height: Height of the processed image

        Returns:
            Tuple of (low_level_instruction, list of pyautogui_commands)
        """
        low_level_instruction = ""
        pyautogui_code = []

        if response is None or not response.strip():
            return low_level_instruction, pyautogui_code

        # Define function to adjust coordinates based on original and processed dimensions
        def adjust_coordinates(x: float, y: float) -> Tuple[int, int]:
            """
            Adjust coordinates from processed-image space to original-image space.
            """
            if all([original_width, original_height, processed_width, processed_height]):
                # Calculate the scale factors between original and processed images
                x_scale = original_width / processed_width
                y_scale = original_height / processed_height

                # Apply scaling to get coordinates in original image space
                adjusted_x = int(x * x_scale)
                adjusted_y = int(y * y_scale)

                return adjusted_x, adjusted_y
            else:
                # If any dimension is missing, return the original coordinates
                return int(x), int(y)
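
        # Worked example (assuming smart_resize defaults): a 1920x1080 screen is
        # resized to 1316x728, so a model click at (658, 364) in processed space
        # maps back to (658 * 1920/1316, 364 * 1080/728) = (960, 540) on screen.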

        # Define inner function to process tool calls
        def process_tool_call(json_str: str) -> None:
            """Process a single tool call JSON string."""
            try:
                # Parse the JSON
                tool_call = json.loads(json_str)
                if tool_call.get("name") == "computer_use":
                    # Convert computer_use actions to pyautogui commands
                    args = tool_call["arguments"]
                    action = args["action"]

                    if action == "left_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.click({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.click()")

                    elif action == "right_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.rightClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.rightClick()")

                    elif action == "middle_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.middleClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.middleClick()")

                    elif action == "double_click":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.doubleClick({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.doubleClick()")

                    elif action == "type":
                        text = args.get("text", "")
                        # Use repr() so quotes and backslashes in the text stay valid Python
                        pyautogui_code.append(f"pyautogui.typewrite({text!r})")

                    elif action == "key":
                        keys = args.get("keys", [])
                        # Fix possible formatting issues in the keys parameter
                        if isinstance(keys, list):
                            # Clean up any formatting issues in the keys
                            cleaned_keys = []
                            for key in keys:
                                if isinstance(key, str):
                                    # Remove a "keys=[" prefix if present
                                    if key.startswith("keys=["):
                                        key = key[6:]
                                    # Remove a "]" suffix if present
                                    if key.endswith("]"):
                                        key = key[:-1]
                                    # Handle strings that contain list-item representations
                                    if key.startswith("['") or key.startswith("[\""):
                                        key = key[2:] if len(key) > 2 else key
                                    if key.endswith("']") or key.endswith("\"]"):
                                        key = key[:-2] if len(key) > 2 else key
                                    # Strip any extra whitespace
                                    key = key.strip()
                                    cleaned_keys.append(key)
                                else:
                                    cleaned_keys.append(key)
                            keys = cleaned_keys

                        # Format the keys for a hotkey or press command
                        keys_str = ", ".join([f"'{key}'" for key in keys])
                        if len(keys) > 1:
                            pyautogui_code.append(f"pyautogui.hotkey({keys_str})")
                        else:
                            pyautogui_code.append(f"pyautogui.press({keys_str})")

                    elif action == "scroll":
                        pixels = args.get("pixels", 0)
                        pyautogui_code.append(f"pyautogui.scroll({pixels})")

                    elif action == "wait":
                        pyautogui_code.append("WAIT")  # Special code for wait action

                    elif action == "terminate":
                        pyautogui_code.append("DONE")  # Special code for termination

                    elif action == "mouse_move":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            pyautogui_code.append(f"pyautogui.moveTo({adj_x}, {adj_y})")
                        else:
                            pyautogui_code.append("pyautogui.moveTo(0, 0)")

                    elif action == "left_click_drag":
                        if "coordinate" in args:
                            x, y = args["coordinate"]
                            adj_x, adj_y = adjust_coordinates(x, y)
                            duration = args.get("duration", 0.5)
                            pyautogui_code.append(f"pyautogui.dragTo({adj_x}, {adj_y}, duration={duration})")
                        else:
                            pyautogui_code.append("pyautogui.dragTo(0, 0)")
            except (json.JSONDecodeError, KeyError) as e:
                logger.error(f"Failed to parse tool call: {e}")
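
        # Example: the tool call
        #   {"name": "computer_use", "arguments": {"action": "key", "keys": ["ctrl", "s"]}}
        # yields pyautogui_code == ["pyautogui.hotkey('ctrl', 's')"].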

        # Parse the response line by line
        lines = response.split('\n')
        inside_tool_call = False
        current_tool_call = []

        for line in lines:
            line = line.strip()
            if not line:
                continue

            # Extract the low-level instruction from lines starting with "Action:" or similar
            if line.lower().startswith(("action:", "step", "i will", "i'll", "now i")):
                if not low_level_instruction:
                    # Only store the first action description as the low-level instruction
                    low_level_instruction = line
                continue

            # Handle lines inside tool call markers
            if line.startswith("<tool_call>"):
                inside_tool_call = True
                continue
            elif line.startswith("</tool_call>"):
                if current_tool_call:
                    # Process the collected tool call
                    process_tool_call("\n".join(current_tool_call))
                    current_tool_call = []
                inside_tool_call = False
                continue

            if inside_tool_call:
                current_tool_call.append(line)
                continue

            # Try to parse individual lines as JSON
            if line.startswith("{") and line.endswith("}"):
                try:
                    json_obj = json.loads(line)
                    if "name" in json_obj and "arguments" in json_obj:
                        process_tool_call(line)
                except json.JSONDecodeError:
                    pass

        # Process any remaining tool call content
        if current_tool_call:
            process_tool_call("\n".join(current_tool_call))

        # If we still don't have a low-level instruction, generate a default one
        if not low_level_instruction and len(pyautogui_code) > 0:
            first_command = pyautogui_code[0]
            if first_command.startswith("pyautogui."):
                action_type = first_command.split(".", 1)[1].split("(", 1)[0]
            else:
                action_type = first_command  # e.g. the special codes WAIT / DONE
            low_level_instruction = f"Performing {action_type} action"

        return low_level_instruction, pyautogui_code

    @backoff.on_exception(
        backoff.constant,
        # Add more model exceptions here as needed, but never add the bare
        # "Exception" type: generic exceptions must propagate so the outer
        # loop can enforce the per-example time limit.
        (
            # General exceptions
            SSLError,
            # OpenAI exceptions
            openai.RateLimitError,
            openai.BadRequestError,
            openai.InternalServerError,
            # Google exceptions
            InvalidArgument,
            ResourceExhausted,
            InternalServerError,
            BadRequest,
            # Groq exceptions
            # todo: check
        ),
        interval=30,
        max_tries=10,
    )
    def call_llm(self, payload, model):
        messages = payload["messages"]
        base_url = "your_base_url"  # placeholder: point this at your OpenAI-compatible endpoint
        api_key = "your_api_key"    # placeholder: supply your API key

        client = openai.OpenAI(
            base_url=base_url,
            api_key=api_key
        )

        for _ in range(MAX_RETRY_TIMES):
            logger.info("Generating content with Qwen model: %s", model)
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_tokens=self.max_tokens,
                    temperature=self.temperature,
                    top_p=self.top_p
                )
                return response.choices[0].message.content
            except Exception as e:
                # Note: this broad catch means the backoff decorator above
                # rarely fires; retries are handled by this loop instead.
                logger.error(f"Error calling Qwen model: {e}")
                time.sleep(5)
                continue
        return ""
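
    # A minimal sketch of wiring the endpoint through the environment instead of
    # editing the placeholders above (the variable names QWEN_BASE_URL and
    # QWEN_API_KEY are assumptions, not part of this repo):
    #
    #     base_url = os.environ.get("QWEN_BASE_URL", "http://localhost:8000/v1")
    #     api_key = os.environ.get("QWEN_API_KEY", "empty")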

    def reset(self, _logger=None):
        global logger
        logger = (_logger if _logger is not None else
                  logging.getLogger("desktopenv.qwen25vl_agent"))

        self.thoughts = []
        self.action_descriptions = []
        self.actions = []
        self.observations = []
        self.responses = []  # Reset responses
        self.screenshots = []  # Reset screenshots

mm_agents/utils/qwen_vl_utils.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import math


def round_by_factor(number: int, factor: int) -> int:
    """Return the integer closest to `number` that is divisible by `factor`."""
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest integer >= `number` that is divisible by `factor`."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest integer <= `number` that is divisible by `factor`."""
    return math.floor(number / factor) * factor
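
# For example, with factor=28: round_by_factor(1080, 28) == 1092,
# ceil_by_factor(1080, 28) == 1092, floor_by_factor(1080, 28) == 1064.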


def smart_resize(height, width, factor=28, min_pixels=56 * 56, max_pixels=14 * 14 * 4 * 1280, max_long_side=8192):
    """Rescale the image so that:
    1. Both height and width are divisible by `factor`.
    2. The total number of pixels stays within [min_pixels, max_pixels].
    3. The longest side does not exceed `max_long_side`.
    4. The aspect ratio is approximately preserved.
    """
    if height < 2 or width < 2:
        raise ValueError(f"height:{height} and width:{width} must both be at least 2")
    elif max(height, width) / min(height, width) > 200:
        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {height} / {width}")

    if max(height, width) > max_long_side:
        beta = max(height, width) / max_long_side
        height, width = int(height / beta), int(width / beta)

    h_bar = round_by_factor(height, factor)
    w_bar = round_by_factor(width, factor)
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar
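
# Worked example with the defaults (max_pixels = 1,003,520): a 1920x1080 input
# exceeds the pixel budget, so beta = sqrt(2073600 / 1003520) ~= 1.44 and
# smart_resize(1080, 1920) == (728, 1316), both divisible by 28.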


def update_image_size_(image_ele: dict, min_tokens=1, max_tokens=12800, merge_base=2, patch_size=14):
    """Update the size information of `image_ele` according to min_tokens and max_tokens.

    Args:
        image_ele (dict):
            - image_ele["image"]: str, path to the image
            - image_ele["height"]: int, original image height
            - image_ele["width"]: int, original image width

    Returns:
        The updated image_ele, with the following key-value pairs added:
        dict:
            - image_ele["resized_height"]: int, actual height fed to the model
            - image_ele["resized_width"]: int, actual width fed to the model
            - image_ele["seq_len"]: int, sequence length the image occupies in the model
    """
    height, width = image_ele["height"], image_ele["width"]
    pixels_per_token = patch_size * patch_size * merge_base * merge_base
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=merge_base * patch_size,
        min_pixels=pixels_per_token * min_tokens,
        max_pixels=pixels_per_token * max_tokens,
        max_long_side=50000,
    )
    image_ele.update(
        {
            "resized_height": resized_height,
            "resized_width": resized_width,
            "seq_len": resized_height * resized_width // pixels_per_token + 2,
        }
    )
    return image_ele
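
# For a 1920x1080 image with the defaults, pixels_per_token = 14*14*2*2 = 784;
# the image rounds to 1932x1092 (within the 12800-token budget), giving
# seq_len = 1092 * 1932 // 784 + 2 = 2693.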


def _convert_bbox_format_from_abs_origin(bbox, image_ele: dict, *, tgt_format: str):
    x1, y1, x2, y2 = bbox
    if tgt_format == "abs_origin":
        new_bbox = [int(x1), int(y1), int(x2), int(y2)]
    elif tgt_format == "abs_resized":
        new_bbox = [
            int(x1 / image_ele["width"] * image_ele["resized_width"]),
            int(y1 / image_ele["height"] * image_ele["resized_height"]),
            int(x2 / image_ele["width"] * image_ele["resized_width"]),
            int(y2 / image_ele["height"] * image_ele["resized_height"]),
        ]
    elif tgt_format == "qwen-vl":
        new_bbox = [
            int(x1 / image_ele["width"] * 999),
            int(y1 / image_ele["height"] * 999),
            int(x2 / image_ele["width"] * 999),
            int(y2 / image_ele["height"] * 999),
        ]
    elif tgt_format == "rel":
        new_bbox = [
            float(x1 / image_ele["width"]),
            float(y1 / image_ele["height"]),
            float(x2 / image_ele["width"]),
            float(y2 / image_ele["height"]),
        ]
    elif tgt_format == "molmo":
        new_bbox = [
            round(x1 / image_ele["width"] * 100, ndigits=1),
            round(y1 / image_ele["height"] * 100, ndigits=1),
            round(x2 / image_ele["width"] * 100, ndigits=1),
            round(y2 / image_ele["height"] * 100, ndigits=1),
        ]
    else:
        raise ValueError(f"Unknown tgt_format: {tgt_format}")
    return new_bbox


def _convert_bbox_format_to_abs_origin(bbox, image_ele: dict, *, src_format: str):
    x1, y1, x2, y2 = bbox
    if src_format == "abs_origin":
        new_bbox = [int(x1), int(y1), int(x2), int(y2)]
    elif src_format == "abs_resized":
        new_bbox = [
            int(x1 / image_ele["resized_width"] * image_ele["width"]),
            int(y1 / image_ele["resized_height"] * image_ele["height"]),
            int(x2 / image_ele["resized_width"] * image_ele["width"]),
            int(y2 / image_ele["resized_height"] * image_ele["height"]),
        ]
    elif src_format == "qwen-vl":
        new_bbox = [
            int(x1 / 999 * image_ele["width"]),
            int(y1 / 999 * image_ele["height"]),
            int(x2 / 999 * image_ele["width"]),
            int(y2 / 999 * image_ele["height"]),
        ]
    elif src_format == "rel":
        new_bbox = [
            int(x1 * image_ele["width"]),
            int(y1 * image_ele["height"]),
            int(x2 * image_ele["width"]),
            int(y2 * image_ele["height"]),
        ]
    elif src_format == "molmo":
        new_bbox = [
            int(x1 / 100 * image_ele["width"]),
            int(y1 / 100 * image_ele["height"]),
            int(x2 / 100 * image_ele["width"]),
            int(y2 / 100 * image_ele["height"]),
        ]
    else:
        raise ValueError(f"Unknown src_format: {src_format}")
    return new_bbox


def convert_bbox_format(bbox, image_ele: dict, *, src_format: str, tgt_format: str):
    bbox_abs_origin = _convert_bbox_format_to_abs_origin(bbox, image_ele, src_format=src_format)
    bbox_tgt_format = _convert_bbox_format_from_abs_origin(bbox_abs_origin, image_ele, tgt_format=tgt_format)
    return bbox_tgt_format
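
# Example: for a 1920x1080 image, the relative box [0.25, 0.25, 0.75, 0.75]
# with src_format="rel", tgt_format="qwen-vl" passes through abs_origin
# [480, 270, 1440, 810] and becomes [249, 249, 749, 749].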


def _convert_point_format_from_abs_origin(point, image_ele: dict, *, tgt_format: str):
    x, y = point
    if tgt_format == "abs_origin":
        new_point = [int(x), int(y)]
    elif tgt_format == "abs_resized":
        new_point = [
            int(x / image_ele["width"] * image_ele["resized_width"]),
            int(y / image_ele["height"] * image_ele["resized_height"]),
        ]
    elif tgt_format == "qwen-vl":
        new_point = [
            int(x / image_ele["width"] * 999),
            int(y / image_ele["height"] * 999),
        ]
    elif tgt_format == "rel":
        new_point = [
            float(x / image_ele["width"]),
            float(y / image_ele["height"]),
        ]
    elif tgt_format == "molmo":
        new_point = [
            round(x / image_ele["width"] * 100, ndigits=1),
            round(y / image_ele["height"] * 100, ndigits=1),
        ]
    else:
        raise ValueError(f"Unknown tgt_format: {tgt_format}")
    return new_point


def _convert_point_format_to_abs_origin(point, image_ele: dict, *, src_format: str):
    x, y = point
    if src_format == "abs_origin":
        new_point = [int(x), int(y)]
    elif src_format == "abs_resized":
        new_point = [
            int(x / image_ele["resized_width"] * image_ele["width"]),
            int(y / image_ele["resized_height"] * image_ele["height"]),
        ]
    elif src_format == "qwen-vl":
        new_point = [
            int(x / 999 * image_ele["width"]),
            int(y / 999 * image_ele["height"]),
        ]
    elif src_format == "rel":
        new_point = [
            int(x * image_ele["width"]),
            int(y * image_ele["height"]),
        ]
    elif src_format == "molmo":
        new_point = [
            int(x / 100 * image_ele["width"]),
            int(y / 100 * image_ele["height"]),
        ]
    else:
        raise ValueError(f"Unknown src_format: {src_format}")
    return new_point


def convert_point_format(point, image_ele: dict, *, src_format: str, tgt_format: str):
    point_abs_origin = _convert_point_format_to_abs_origin(point, image_ele, src_format=src_format)
    point_tgt_format = _convert_point_format_from_abs_origin(point_abs_origin, image_ele, tgt_format=tgt_format)
    return point_tgt_format
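
# For instance, the demo point below ([0.8379917184, 0.2087912088] on a
# 1920x1080 image resized to 1932x1092) maps through abs_origin (1608, 225)
# to abs_resized (1618, 227).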


__all__ = [
    "update_image_size_",
    "convert_bbox_format",
    "convert_point_format",
]


if __name__ == "__main__":
    from PIL import Image

    def draw_point(image: Image.Image, point: list):
        from copy import deepcopy

        from PIL import ImageDraw

        image = deepcopy(image)
        image_draw = ImageDraw.Draw(image)
        image_draw.ellipse([point[0] - 5, point[1] - 5, point[0] + 5, point[1] + 5], fill="red")
        return image

    # image_ele = {
    #     "image": "http://ofasys-multimodal-wlcb-3.oss-cn-wulanchabu.aliyuncs.com/data/datacomp1b/image/19774238/7218d7ceb39e82e0cafc389f326e218da623a8f2.jpg",
    #     "height": 444,
    #     "width": 592,
    # }
    image_ele = {
        "image": "46d5402b2c183f996f2a13cd2016af15.png",
        "height": 1080,
        "width": 1920,
    }
    point = [0.8379917184, 0.2087912088]  # rel, keyboard 'k' in the image

    # image: Image.Image = Image.open(requests.get(image_ele["image"], stream=True).raw)
    image: Image.Image = Image.open(image_ele["image"])
    assert image.width == image_ele["width"] and image.height == image_ele["height"], f"{image.size=}, {image_ele=}"
    # Compute the resized dimensions before using them: update_image_size_ adds
    # "resized_width"/"resized_height" to image_ele
    image_ele = update_image_size_(image_ele)
    resized_image = image.resize((image_ele["resized_width"], image_ele["resized_height"]))
    draw_point(image, [point[0] * image.width, point[1] * image.height]).save("image_1.png")

    point = convert_point_format(point, image_ele, src_format="rel", tgt_format="abs_resized")
    print(f"{image_ele=}\n{point=}")

    draw_point(resized_image, point).save("image_2.png")

run_multienv_qwen25vl.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""Script to run end-to-end evaluation on the benchmark.
Utils and basic architecture credit to https://github.com/web-arena-x/webarena/blob/main/run.py.
"""

import argparse
import datetime
import json
import logging
import os
import sys
from typing import List, Dict
import math
from tqdm import tqdm
from multiprocessing import Process, Manager
import lib_run_single
from desktop_env.desktop_env import DesktopEnv
from mm_agents.qwen25vl_agent import Qwen25VLAgent

# import wandb

# Logger Configs {{{ #
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

datetime_str: str = datetime.datetime.now().strftime("%Y%m%d@%H%M%S")

os.makedirs("logs", exist_ok=True)  # FileHandler fails if the directory is missing
file_handler = logging.FileHandler(
    os.path.join("logs", "normal-{:}.log".format(datetime_str)), encoding="utf-8"
)
debug_handler = logging.FileHandler(
    os.path.join("logs", "debug-{:}.log".format(datetime_str)), encoding="utf-8"
)
stdout_handler = logging.StreamHandler(sys.stdout)
sdebug_handler = logging.FileHandler(
    os.path.join("logs", "sdebug-{:}.log".format(datetime_str)), encoding="utf-8"
)

file_handler.setLevel(logging.INFO)
debug_handler.setLevel(logging.DEBUG)
stdout_handler.setLevel(logging.INFO)
sdebug_handler.setLevel(logging.DEBUG)

formatter = logging.Formatter(
    fmt="\x1b[1;33m[%(asctime)s \x1b[31m%(levelname)s \x1b[32m%(module)s/%(lineno)d-%(processName)s\x1b[1;33m] \x1b[0m%(message)s"
)
file_handler.setFormatter(formatter)
debug_handler.setFormatter(formatter)
stdout_handler.setFormatter(formatter)
sdebug_handler.setFormatter(formatter)

stdout_handler.addFilter(logging.Filter("desktopenv"))
sdebug_handler.addFilter(logging.Filter("desktopenv"))

logger.addHandler(file_handler)
logger.addHandler(debug_handler)
logger.addHandler(stdout_handler)
logger.addHandler(sdebug_handler)
# }}} Logger Configs #

logger = logging.getLogger("desktopenv.experiment")


def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run end-to-end evaluation on the benchmark"
    )

    # environment config
    parser.add_argument("--path_to_vm", type=str, default=None)
    parser.add_argument(
        "--headless", action="store_true", help="Run in headless machine"
    )
    parser.add_argument(
        "--action_space", type=str, default="pyautogui", help="Action type"
    )
    parser.add_argument(
        "--observation_type",
        choices=["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"],
        default="screenshot",
        help="Observation type",
    )
    parser.add_argument("--screen_width", type=int, default=1920)
    parser.add_argument("--screen_height", type=int, default=1080)
    parser.add_argument("--sleep_after_execution", type=float, default=2.0)
    parser.add_argument("--max_steps", type=int, default=20)

    # agent config
    parser.add_argument(
        "--test_config_base_dir", type=str, default="evaluation_examples"
    )

    # lm config
    parser.add_argument("--planner_model", type=str, default=None)
    parser.add_argument("--executor_model", type=str, default="aguvis-s1-s2-agentnet0105-mo5")
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=0.9)
    parser.add_argument("--max_tokens", type=int, default=1500)
    parser.add_argument("--stop_token", type=str, default=None)

    # example config
    parser.add_argument("--domain", type=str, default="all")
    parser.add_argument(
        "--test_all_meta_path", type=str, default="evaluation_examples/test_all.json"
    )

    # logging related
    parser.add_argument("--result_dir", type=str, default="./results")
    parser.add_argument("--num_envs", type=int, default=1, help="Number of environments to run in parallel")

    args = parser.parse_args()
    return args
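
# A typical invocation (the model name is an example, not a fixed requirement):
#
#   python run_multienv_qwen25vl.py --headless --num_envs 2 \
#       --executor_model qwen2.5vl --observation_type screenshot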


def distribute_tasks(test_all_meta: dict, num_envs: int) -> List[Dict]:
    """Distribute tasks evenly across environments."""
    # Flatten the tasks into a single list
    all_tasks = []
    for domain, examples in test_all_meta.items():
        for example_id in examples:
            all_tasks.append((domain, example_id))

    # Calculate tasks per environment
    tasks_per_env = math.ceil(len(all_tasks) / num_envs)

    # Distribute tasks
    distributed_tasks = []
    for i in range(num_envs):
        env_tasks = {}
        start_idx = i * tasks_per_env
        end_idx = min((i + 1) * tasks_per_env, len(all_tasks))

        for domain, example_id in all_tasks[start_idx:end_idx]:
            if domain not in env_tasks:
                env_tasks[domain] = []
            env_tasks[domain].append(example_id)

        distributed_tasks.append(env_tasks)

    return distributed_tasks
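
# For instance, 10 tasks over num_envs=3 gives tasks_per_env = ceil(10/3) = 4,
# so the environments receive 4, 4, and 2 tasks respectively.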


def run_env_tasks(env_idx: int, env: DesktopEnv, agent, env_tasks: dict, args: argparse.Namespace, shared_scores: list):
    """Run tasks for a single environment."""
    logger.info(f"Executing tasks in environment {env_idx + 1}/{args.num_envs}")

    for domain in tqdm(env_tasks, desc=f"Env{env_idx+1}-Domain"):
        for example_id in tqdm(env_tasks[domain], desc="Example", leave=False):
            config_file = os.path.join(
                args.test_config_base_dir, f"examples/{domain}/{example_id}.json"
            )
            with open(config_file, "r", encoding="utf-8") as f:
                example = json.load(f)

            logger.info(f"[Env {env_idx+1}][Domain]: {domain}")
            logger.info(f"[Env {env_idx+1}][Example ID]: {example_id}")
            logger.info(f"[Env {env_idx+1}][Instruction]: {example['instruction']}")

            example_result_dir = os.path.join(
                args.result_dir,
                args.action_space,
                args.observation_type,
                "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model),
                domain,
                example_id,
            )
            os.makedirs(example_result_dir, exist_ok=True)

            try:
                lib_run_single.run_single_example(
                    agent,
                    env,
                    example,
                    args.max_steps,
                    example["instruction"],
                    args,
                    example_result_dir,
                    shared_scores,
                )
            except Exception as e:
                logger.error(f"Exception in Env{env_idx+1} {domain}/{example_id}: {e}")
                env.controller.end_recording(
                    os.path.join(example_result_dir, "recording.mp4")
                )
                with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
                    # Record the actual error rather than assuming a timeout
                    f.write(
                        json.dumps(
                            {"Error": f"{e} in {domain}/{example_id}"}
                        )
                    )
                    f.write("\n")

    env.close()


def test(args: argparse.Namespace, test_all_meta: dict) -> None:
    logger.info("Args: %s", args)

    distributed_tasks = distribute_tasks(test_all_meta, args.num_envs)

    # First, set up all environments
    logger.info("Setting up all environments...")
    envs = []
    agents = []

    for env_idx in range(args.num_envs):
        logger.info(f"Setting up environment {env_idx + 1}/{args.num_envs}")

        agent = Qwen25VLAgent(
            planner_model=args.planner_model,
            executor_model=args.executor_model,
            max_tokens=args.max_tokens,
            top_p=args.top_p,
            temperature=args.temperature,
            action_space=args.action_space,
        )
        agents.append(agent)

        env = DesktopEnv(
            path_to_vm=args.path_to_vm,
            action_space=agent.action_space,
            screen_size=(args.screen_width, args.screen_height),
            headless=args.headless,
            os_type="Ubuntu",
            require_a11y_tree=args.observation_type
            in ["a11y_tree", "screenshot_a11y_tree", "som"],
            provider_name="docker"
        )
        envs.append(env)

    logger.info("All environments are ready. Starting parallel task execution...")

    # Create a shared list for scores across processes
    with Manager() as manager:
        shared_scores = manager.list()

        # Create and start a process for each environment
        processes = []
        for env_idx, (env, agent, env_tasks) in enumerate(zip(envs, agents, distributed_tasks)):
            p = Process(
                target=run_env_tasks,
                args=(env_idx, env, agent, env_tasks, args, shared_scores)
            )
            processes.append(p)
            p.start()

        # Wait for all processes to complete
        for p in processes:
            p.join()

        # Convert the shared list to a regular list
        scores = list(shared_scores)

        logger.info(f"Average score: {sum(scores) / len(scores) if scores else 0}")


def get_unfinished(
    action_space, use_model, observation_type, result_dir, total_file_json
):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)

    if not os.path.exists(target_dir):
        return total_file_json

    finished = {}
    for domain in os.listdir(target_dir):
        finished[domain] = []
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                if example_id == "onboard":
                    continue
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" not in os.listdir(example_path):
                        # No result yet: clear all files under example_id so it reruns cleanly
                        for file in os.listdir(example_path):
                            os.remove(os.path.join(example_path, file))
                    else:
                        finished[domain].append(example_id)

    if not finished:
        return total_file_json

    for domain, examples in finished.items():
        if domain in total_file_json:
            total_file_json[domain] = [
                x for x in total_file_json[domain] if x not in examples
            ]

    return total_file_json


def get_result(action_space, use_model, observation_type, result_dir, total_file_json):
    target_dir = os.path.join(result_dir, action_space, observation_type, use_model)
    if not os.path.exists(target_dir):
        print("New experiment, no result yet.")
        return None

    all_result = []

    for domain in os.listdir(target_dir):
        domain_path = os.path.join(target_dir, domain)
        if os.path.isdir(domain_path):
            for example_id in os.listdir(domain_path):
                example_path = os.path.join(domain_path, example_id)
                if os.path.isdir(example_path):
                    if "result.txt" in os.listdir(example_path):
                        # Read the stored score; treat unreadable results as 0.0
                        try:
                            with open(
                                os.path.join(example_path, "result.txt"), "r"
                            ) as f:
                                all_result.append(float(f.read()))
                        except (ValueError, OSError):
                            all_result.append(0.0)

    if not all_result:
        print("New experiment, no result yet.")
        return None
    else:
        print("Current Success Rate:", sum(all_result) / len(all_result) * 100, "%")
        return all_result


if __name__ == "__main__":
    ####### The complete version of the list of examples #######
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = config()

    with open(args.test_all_meta_path, "r", encoding="utf-8") as f:
        test_all_meta = json.load(f)

    if args.domain != "all":
        test_all_meta = {args.domain: test_all_meta[args.domain]}

    exp_name = "planner-" + str(args.planner_model) + "-executor-" + str(args.executor_model)

    test_file_list = get_unfinished(
        args.action_space,
        exp_name,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    left_info = ""
    for domain in test_file_list:
        left_info += f"{domain}: {len(test_file_list[domain])}\n"
    logger.info(f"Left tasks:\n{left_info}")

    get_result(
        args.action_space,
        exp_name,
        args.observation_type,
        args.result_dir,
        test_all_meta,
    )
    test(args, test_file_list)