add_os_symphony (#399)
This commit is contained in:
0
mm_agents/os_symphony/utils/__init__.py
Executable file
0
mm_agents/os_symphony/utils/__init__.py
Executable file
448
mm_agents/os_symphony/utils/common_utils.py
Executable file
448
mm_agents/os_symphony/utils/common_utils.py
Executable file
@@ -0,0 +1,448 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Tuple, Dict, List, Union
|
||||
import io
|
||||
import os
|
||||
from PIL import Image, ImageDraw
|
||||
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from mm_agents.os_symphony.utils.process_context import get_current_result_dir
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
def create_pyautogui_code(agent, code: str, obs: Dict) -> Tuple[str, dict | None]:
|
||||
"""
|
||||
Attempts to evaluate the code into a pyautogui code snippet with grounded actions using the observation screenshot.
|
||||
|
||||
Args:
|
||||
agent (ACI): The grounding agent to use for evaluation.
|
||||
code (str): The code string to evaluate.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
|
||||
Returns:
|
||||
exec_code (str): The pyautogui code to execute the grounded action.
|
||||
coordinate (List): The coordinate of the action, a list such as [x1, y1, x2, y2, x3, y3...]. Because may appear more than one coordinate in one action.
|
||||
Modified by Yang.
|
||||
Raises:
|
||||
Exception: If there is an error in evaluating the code.
|
||||
"""
|
||||
agent.assign_screenshot(obs) # Necessary for grounding
|
||||
response = eval(code)
|
||||
if isinstance(response, Tuple):
|
||||
return response
|
||||
elif isinstance(response, str):
|
||||
return response, None
|
||||
else:
|
||||
return "", None
|
||||
|
||||
|
||||
def draw_coordinates(image_bytes: bytes, coordinates: List[Union[int, float]], save_path: str):
|
||||
"""
|
||||
Draw coordinates on the given image and save it to a new file.
|
||||
|
||||
This function receives an image as a byte stream, a list of coordinates in the format [x1, y1, x2, y2, ...],
|
||||
and draws a red 'X' at each (x, y) coordinate point. The resulting image is then saved to the specified path.
|
||||
|
||||
Args:
|
||||
- image_bytes (bytes): The raw byte data of the image (e.g., read from a PNG or JPEG file).
|
||||
- coordinates (List[Union[int, float]]): A flattened list of coordinates, must contain an even number of elements. For example: [x1, y1, x2, y2].
|
||||
- save_path (str): The path where the new image with markings will be saved.
|
||||
"""
|
||||
try:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image = image.convert("RGB")
|
||||
except Exception as e:
|
||||
return
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
cross_size = 15
|
||||
cross_color = "red"
|
||||
cross_width = 3
|
||||
|
||||
for i in range(0, len(coordinates) - 1, 2):
|
||||
x, y = coordinates[i], coordinates[i+1]
|
||||
|
||||
line1_start = (x - cross_size, y - cross_size)
|
||||
line1_end = (x + cross_size, y + cross_size)
|
||||
|
||||
line2_start = (x + cross_size, y - cross_size)
|
||||
line2_end = (x - cross_size, y + cross_size)
|
||||
|
||||
draw.line([line1_start, line1_end], fill=cross_color, width=cross_width)
|
||||
draw.line([line2_start, line2_end], fill=cross_color, width=cross_width)
|
||||
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
image.save(save_path)
|
||||
|
||||
|
||||
def parse_action_from_string(string):
|
||||
'''
|
||||
Parse all strings following "(next action)", including the phrase "next action" itself. If parsing is not possible, return everything.
|
||||
'''
|
||||
marker = "(Next Action)"
|
||||
|
||||
start_index = string.find(marker)
|
||||
|
||||
if start_index != -1:
|
||||
return string[start_index:]
|
||||
else:
|
||||
return string
|
||||
|
||||
|
||||
def call_llm_safe(
|
||||
agent, temperature: float = 0.0, use_thinking: bool = False, **kwargs
|
||||
) -> str:
|
||||
|
||||
try:
|
||||
example_result_dir = get_current_result_dir()
|
||||
except Exception:
|
||||
example_result_dir = "logs/tokens"
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = agent.get_response(
|
||||
temperature=temperature, use_thinking=use_thinking, **kwargs
|
||||
)
|
||||
assert response is not None, "Response from agent should not be None"
|
||||
# print("Response success!")
|
||||
break # If successful, break out of the loop
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"{agent.engine} Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
# record token cost
|
||||
if isinstance(response, tuple):
|
||||
response, usage = response
|
||||
agent_name = agent.agent_name
|
||||
with open(os.path.join(example_result_dir, "token.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({
|
||||
"agent_name": agent_name,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"total_tokens": usage.total_tokens
|
||||
}))
|
||||
f.write("\n")
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def call_func_safe(
|
||||
func, **kwargs
|
||||
) -> str:
|
||||
# Retry if fails
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
while attempt < max_retries:
|
||||
try:
|
||||
response = func(**kwargs)
|
||||
break
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
print(f"Attempt {attempt} failed: {e}")
|
||||
if attempt == max_retries:
|
||||
print("Max retries reached. Handling failure.")
|
||||
time.sleep(1.0)
|
||||
|
||||
return response if response is not None else ""
|
||||
|
||||
|
||||
def extract_coords_from_action_dict(action_dict: Dict | None) -> List:
|
||||
coords = []
|
||||
coords_num = 0
|
||||
if action_dict:
|
||||
for k, v in action_dict["args"].items():
|
||||
if (k == "x" and v) or (k == "y" and v) or (k == "x1" and v) or (k == "x2" and v) or (k == "y1" and v) or (k == "y2" and v):
|
||||
coords_num += 1
|
||||
if coords_num == 2:
|
||||
coords.append(action_dict["args"]["x"])
|
||||
coords.append(action_dict["args"]["y"])
|
||||
if coords_num == 4:
|
||||
coords.append(action_dict["args"]["x1"])
|
||||
coords.append(action_dict["args"]["y1"])
|
||||
coords.append(action_dict["args"]["x2"])
|
||||
coords.append(action_dict["args"]["y2"])
|
||||
return coords
|
||||
|
||||
|
||||
def call_llm_formatted(generator, format_checkers, **kwargs):
|
||||
"""
|
||||
Calls the generator agent's LLM and ensures correct formatting.
|
||||
|
||||
Args:
|
||||
generator (ACI): The generator agent to call.
|
||||
obs (Dict): The current observation containing the screenshot.
|
||||
format_checkers (Callable): Functions that take the response and return a tuple of (success, feedback).
|
||||
**kwargs: Additional keyword arguments for the LLM call.
|
||||
|
||||
Returns:
|
||||
response (str): The formatted response from the generator agent.
|
||||
"""
|
||||
max_retries = 3 # Set the maximum number of retries
|
||||
attempt = 0
|
||||
response = ""
|
||||
if kwargs.get("messages") is None:
|
||||
messages = (
|
||||
generator.messages.copy()
|
||||
) # Copy messages to avoid modifying the original
|
||||
else:
|
||||
messages = kwargs["messages"]
|
||||
del kwargs["messages"] # Remove messages from kwargs to avoid passing it twice
|
||||
while attempt < max_retries:
|
||||
response = call_llm_safe(generator, messages=messages, **kwargs)
|
||||
# Prepare feedback messages for incorrect formatting
|
||||
feedback_msgs = []
|
||||
for format_checker in format_checkers:
|
||||
success, feedback = format_checker(response)
|
||||
if not success:
|
||||
feedback_msgs.append(feedback)
|
||||
if not feedback_msgs:
|
||||
# logger.info(f"Response formatted correctly on attempt {attempt} for {generator.engine.model}")
|
||||
break
|
||||
logger.error(
|
||||
f"Response formatting error on attempt {attempt} for {generator.engine.model}. Response: {response} {', '.join(feedback_msgs)}"
|
||||
)
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": response}],
|
||||
}
|
||||
)
|
||||
logger.info(f"Bad response: {response}")
|
||||
delimiter = "\n- "
|
||||
formatting_feedback = f"- {delimiter.join(feedback_msgs)}"
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": PROCEDURAL_MEMORY.FORMATTING_FEEDBACK_PROMPT.replace(
|
||||
"FORMATTING_FEEDBACK", formatting_feedback
|
||||
),
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
logger.info("Feedback:\n%s", formatting_feedback)
|
||||
|
||||
attempt += 1
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
"Max retries reached when formatting response. Handling failure."
|
||||
)
|
||||
time.sleep(1.0)
|
||||
return response
|
||||
|
||||
|
||||
def split_thinking_response(full_response: str) -> Tuple[str, str]:
|
||||
try:
|
||||
# Extract thoughts section
|
||||
thoughts = full_response.split("<thoughts>")[-1].split("</thoughts>")[0].strip()
|
||||
|
||||
# Extract answer section
|
||||
answer = full_response.split("<answer>")[-1].split("</answer>")[0].strip()
|
||||
|
||||
return answer, thoughts
|
||||
except Exception as e:
|
||||
return full_response, ""
|
||||
|
||||
|
||||
def parse_code_from_string(input_string):
|
||||
"""Parses a string to extract each line of code enclosed in triple backticks (```)
|
||||
|
||||
Args:
|
||||
input_string (str): The input string containing code snippets.
|
||||
|
||||
Returns:
|
||||
str: The last code snippet found in the input string, or an empty string if no code is found.
|
||||
"""
|
||||
input_string = input_string.strip()
|
||||
|
||||
# This regular expression will match both ```code``` and ```python code```
|
||||
# and capture the `code` part. It uses a non-greedy match for the content inside.
|
||||
pattern = r"```(?:\w+\s+)?(.*?)```"
|
||||
# print(f'[parse_code_from_string].input_string: {input_string}')
|
||||
# Find all non-overlapping matches in the string
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
if len(matches) == 0:
|
||||
# return []
|
||||
return ""
|
||||
relevant_code = matches[
|
||||
-1
|
||||
] # We only care about the last match given it is the grounded action
|
||||
# print(f'[parse_code_from_string].relevant_code: {relevant_code}')
|
||||
return relevant_code
|
||||
|
||||
|
||||
def extract_agent_functions(code):
|
||||
"""
|
||||
Extracts all agent function names from the given code.
|
||||
|
||||
Args:
|
||||
code (str): The code string to search.
|
||||
|
||||
Returns:
|
||||
list: A list of strings like ['agent.click', 'agent.type'].
|
||||
"""
|
||||
pattern = r"agent\.\w+"
|
||||
|
||||
return re.findall(pattern, code)
|
||||
|
||||
|
||||
def compress_image(image_bytes: bytes = None, image: Image = None) -> bytes:
|
||||
"""Compresses an image represented as bytes.
|
||||
|
||||
Compression involves resizing image into half its original size and saving to webp format.
|
||||
|
||||
Args:
|
||||
image_bytes (bytes): The image data to compress.
|
||||
|
||||
Returns:
|
||||
bytes: The compressed image data.
|
||||
"""
|
||||
if not image:
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
output = BytesIO()
|
||||
image.save(output, format="WEBP")
|
||||
compressed_image_bytes = output.getvalue()
|
||||
return compressed_image_bytes
|
||||
|
||||
import math
|
||||
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = IMAGE_FACTOR,
|
||||
min_pixels: int = MIN_PIXELS,
|
||||
max_pixels: int = MAX_PIXELS,
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
min_pixels = MIN_PIXELS if not min_pixels else min_pixels
|
||||
max_pixels = MAX_PIXELS if not max_pixels else max_pixels
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
def enhance_observation(image_data: bytes, coordinates: List, expansion_pixels: int = 400, draw=True) -> Tuple[bytes, int, int, int, int]:
|
||||
"""
|
||||
According to the given coordinates, draw markers on the screenshot and crop a "focused" area.
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, int, int, int, int]:
|
||||
- new_image_data (bytes): Data of the cropped image
|
||||
- crop_left (int): X-axis offset
|
||||
- crop_top (int): Y-axis offset
|
||||
- new_width (int): Width of the cropped image
|
||||
- new_height (int): Height of the cropped image
|
||||
"""
|
||||
image = Image.open(io.BytesIO(image_data)).convert("RGBA")
|
||||
draw_ctx = ImageDraw.Draw(image)
|
||||
|
||||
img_width, img_height = image.size
|
||||
|
||||
X_MARKER_SIZE = 40
|
||||
X_MARKER_WIDTH = 5
|
||||
|
||||
def _draw_x(draw_context, center_x, center_y, size=X_MARKER_SIZE, color="red", width=X_MARKER_WIDTH):
|
||||
half_size = size // 2
|
||||
draw_context.line((center_x - half_size, center_y - half_size, center_x + half_size, center_y + half_size), fill=color, width=width)
|
||||
draw_context.line((center_x - half_size, center_y + half_size, center_x + half_size, center_y - half_size), fill=color, width=width)
|
||||
|
||||
crop_left, crop_top, crop_right, crop_bottom = 0, 0, img_width, img_height
|
||||
|
||||
if len(coordinates) == 2:
|
||||
x, y = coordinates[0], coordinates[1]
|
||||
if draw:
|
||||
_draw_x(draw_ctx, x, y)
|
||||
|
||||
crop_left = x - expansion_pixels
|
||||
crop_top = y - expansion_pixels
|
||||
crop_right = x + expansion_pixels
|
||||
crop_bottom = y + expansion_pixels
|
||||
|
||||
elif len(coordinates) >= 4:
|
||||
x1, y1 = coordinates[0], coordinates[1]
|
||||
x2, y2 = coordinates[2], coordinates[3]
|
||||
|
||||
if draw:
|
||||
_draw_x(draw_ctx, x1, y1, color="red")
|
||||
_draw_x(draw_ctx, x2, y2, color="blue")
|
||||
draw_ctx.line((x1, y1, x2, y2), fill="green", width=5)
|
||||
|
||||
box_left = min(x1, x2)
|
||||
box_top = min(y1, y2)
|
||||
box_right = max(x1, x2)
|
||||
box_bottom = max(y1, y2)
|
||||
|
||||
crop_left = box_left - expansion_pixels
|
||||
crop_top = box_top - expansion_pixels
|
||||
crop_right = box_right + expansion_pixels
|
||||
crop_bottom = box_bottom + expansion_pixels
|
||||
|
||||
# check boundary
|
||||
crop_left = max(0, int(crop_left))
|
||||
crop_top = max(0, int(crop_top))
|
||||
crop_right = min(img_width, int(crop_right))
|
||||
crop_bottom = min(img_height, int(crop_bottom))
|
||||
|
||||
crop_box = (crop_left, crop_top, crop_right, crop_bottom)
|
||||
cropped_image = image.crop(crop_box)
|
||||
|
||||
new_width, new_height = cropped_image.size
|
||||
|
||||
buffered = io.BytesIO()
|
||||
cropped_image.save(buffered, format="PNG")
|
||||
|
||||
return buffered.getvalue(), crop_left, crop_top, new_width, new_height
|
||||
|
||||
106
mm_agents/os_symphony/utils/formatters.py
Executable file
106
mm_agents/os_symphony/utils/formatters.py
Executable file
@@ -0,0 +1,106 @@
|
||||
"""This file contains various formatting checks used to reprompt an agent for correctly formatted responses."""
|
||||
from typing import List
|
||||
import json
|
||||
import yaml
|
||||
import re
|
||||
from mm_agents.os_symphony.utils.common_utils import (
|
||||
extract_agent_functions,
|
||||
parse_code_from_string,
|
||||
split_thinking_response,
|
||||
)
|
||||
|
||||
|
||||
single_action_check = (
|
||||
lambda response: len(extract_agent_functions(parse_code_from_string(response))) == 1
|
||||
)
|
||||
single_action_error_msg = (
|
||||
"Incorrect code: There must be a single agent action in the code response."
|
||||
)
|
||||
SINGLE_ACTION_FORMATTER = lambda response: (
|
||||
single_action_check(response),
|
||||
single_action_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def code_valid_check(tool_config, response):
|
||||
code = parse_code_from_string(response)
|
||||
print(f'[code_valid_check] parsed code is: {code}')
|
||||
|
||||
# check if the action is pre-defined
|
||||
with open(tool_config, 'r', encoding='utf-8') as f:
|
||||
config = yaml.safe_load(f)
|
||||
valid_methods = set(config['tools'].keys())
|
||||
|
||||
pattern = r"^agent\.(\w+)\(.*\)$"
|
||||
|
||||
match = re.match(pattern, code.strip(), re.DOTALL)
|
||||
|
||||
if match:
|
||||
method_name = match.group(1)
|
||||
print(f'[code_valid_check]: method is {method_name}')
|
||||
if method_name in valid_methods:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
code_valid_error_msg = "Incorrect code: The agent action must be a SINGLE and VALID function and use valid parameters from the docstring list."
|
||||
CODE_VALID_FORMATTER = lambda tool_config, response: (
|
||||
code_valid_check(tool_config, response),
|
||||
code_valid_error_msg,
|
||||
)
|
||||
|
||||
thoughts_answer_tag_check = lambda response: split_thinking_response(response)[1] != ""
|
||||
thoughts_answer_tag_error_msg = "Incorrect response: The response must contain both <thoughts>...</thoughts> and <answer>...</answer> tags."
|
||||
THOUGHTS_ANSWER_TAG_FORMATTER = lambda response: (
|
||||
thoughts_answer_tag_check(response),
|
||||
thoughts_answer_tag_error_msg,
|
||||
)
|
||||
|
||||
integer_answer_check = (
|
||||
lambda response: split_thinking_response(response)[0].strip().isdigit()
|
||||
)
|
||||
integer_answer_error_msg = (
|
||||
"Incorrect response: The <answer>...</answer> tag must contain a single integer."
|
||||
)
|
||||
INTEGER_ANSWER_FORMATTER = lambda response: (
|
||||
integer_answer_check(response),
|
||||
integer_answer_error_msg,
|
||||
)
|
||||
|
||||
|
||||
def json_answer_check(response: str, required_fields: List[str]) -> bool:
|
||||
"""
|
||||
一个只返回 True/False 的检查函数。
|
||||
"""
|
||||
try:
|
||||
answer_str = parse_code_from_string(response)
|
||||
|
||||
if len(answer_str) == 0:
|
||||
return False
|
||||
|
||||
data = json.loads(answer_str)
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
|
||||
if set(required_fields) - set(data.keys()):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
json_answer_error_msg = (
|
||||
"Incorrect response: The (Answer) part must contain a valid JSON object that includes ALL required keys and need to be wrapped by ```json and ```"
|
||||
)
|
||||
|
||||
|
||||
JSON_ANSWER_FORMATTER = lambda response, required_fields: (
|
||||
json_answer_check(required_fields, response),
|
||||
json_answer_error_msg,
|
||||
)
|
||||
216
mm_agents/os_symphony/utils/loop_detection.py
Executable file
216
mm_agents/os_symphony/utils/loop_detection.py
Executable file
@@ -0,0 +1,216 @@
|
||||
import io
|
||||
import math
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Dict, Any, Optional
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from rapidfuzz import fuzz
|
||||
import logging
|
||||
from mm_agents.os_symphony.agents.memoryer_agent import StepBehavior
|
||||
|
||||
logger = logging.getLogger("desktopenv.loop_detection")
|
||||
|
||||
def _are_actions_similar(
|
||||
action1: Dict[str, Any],
|
||||
action2: Dict[str, Any],
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
relative_coord_threshold: float,
|
||||
fuzzy_text_threshold: float,
|
||||
) -> bool:
|
||||
"""
|
||||
[Internal Auxiliary] Determine if two actions are similar based on detailed rules.
|
||||
|
||||
Args:
|
||||
action1: The first action.
|
||||
action2: The second action.
|
||||
image_width: The width of the screenshot.
|
||||
image_height: The height of the screenshot.
|
||||
relative_coord_threshold: A relative distance threshold for coordinate comparison.
|
||||
fuzzy_text_threshold: A similarity threshold (0-100) for fuzzy text matching.
|
||||
|
||||
Returns:
|
||||
Return True if the actions are similar, otherwise return False.
|
||||
"""
|
||||
# ensure same action
|
||||
if action1.get("function") != action2.get("function"):
|
||||
return False
|
||||
|
||||
func = action1.get("function")
|
||||
args1 = action1.get("args", {})
|
||||
args2 = action2.get("args", {})
|
||||
|
||||
diagonal = math.sqrt(image_width**2 + image_height**2)
|
||||
abs_coord_thresh = relative_coord_threshold * diagonal
|
||||
|
||||
def are_coords_close(x1, y1, x2, y2):
|
||||
if None in [x1, y1, x2, y2]: return False
|
||||
distance = math.sqrt((x1 - x2)**2 + (y1 - y2)**2)
|
||||
return distance < abs_coord_thresh
|
||||
|
||||
if func == "click":
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
args1.get("button") == args2.get("button") and
|
||||
args1.get("clicks") == args2.get("clicks")
|
||||
)
|
||||
|
||||
elif func == "open":
|
||||
return args1.get("name") == args2.get("name")
|
||||
|
||||
elif func == "type":
|
||||
if args1.get("x") and args1.get("y") and args2.get("x") and args2.get("y"):
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
args1.get("text") == args2.get("text")
|
||||
)
|
||||
else:
|
||||
return args1.get("text") == args2.get("text")
|
||||
|
||||
elif func == "drag":
|
||||
return (
|
||||
are_coords_close(args1.get("x1"), args1.get("y1"), args2.get("x1"), args2.get("y1")) and
|
||||
are_coords_close(args1.get("x2"), args1.get("y2"), args2.get("x2"), args2.get("y2"))
|
||||
)
|
||||
|
||||
elif func == "set_cell_values":
|
||||
return args1.get("text") == args2.get("text")
|
||||
|
||||
elif func == "scroll":
|
||||
clicks1 = args1.get("clicks", 0)
|
||||
clicks2 = args2.get("clicks", 0)
|
||||
if (clicks1 == 0 and clicks2 != 0) or (clicks1 != 0 and clicks2 == 0):
|
||||
same_direction = False
|
||||
else:
|
||||
same_direction = math.copysign(1, clicks1) == math.copysign(1, clicks2)
|
||||
|
||||
return (
|
||||
are_coords_close(args1.get("x"), args1.get("y"), args2.get("x"), args2.get("y")) and
|
||||
same_direction and
|
||||
args1.get("shift") == args2.get("shift")
|
||||
)
|
||||
|
||||
elif func == "key":
|
||||
return args1.get("keys") == args2.get("keys")
|
||||
|
||||
elif func == "wait":
|
||||
return True
|
||||
|
||||
elif func in ["call_code_agent", "call_search_agent"]:
|
||||
query1 = args1.get("query", "")
|
||||
query2 = args2.get("query", "")
|
||||
# use Levenshtein distance to calculate fuzzy similarity
|
||||
query_similarity = fuzz.token_set_ratio(query1, query2)
|
||||
# print(f'query_sim: {query_similarity}')
|
||||
return (
|
||||
query_similarity >= fuzzy_text_threshold and
|
||||
args1.get("result") == args2.get("result")
|
||||
)
|
||||
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _are_steps_similar_optimized(
|
||||
step1: StepBehavior,
|
||||
step2: StepBehavior,
|
||||
idx1: int,
|
||||
idx2: int,
|
||||
full_trajectory: List[StepBehavior],
|
||||
phash_threshold: int,
|
||||
ssim_threshold: float,
|
||||
# 动作比较所需的参数
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
relative_coord_threshold: float,
|
||||
fuzzy_text_threshold: float,
|
||||
) -> bool:
|
||||
"""
|
||||
[Internal Auxiliary] use pre-calculated data to quickly determine if the two actions are similar/
|
||||
"""
|
||||
|
||||
if step1.phash is None or step2.phash is None:
|
||||
return False
|
||||
|
||||
if (step1.phash - step2.phash) > phash_threshold:
|
||||
return False
|
||||
|
||||
|
||||
later_step_idx = max(idx1, idx2)
|
||||
earlier_step_idx = min(idx1, idx2)
|
||||
|
||||
ssim_score = full_trajectory[later_step_idx].ssim_list[earlier_step_idx]
|
||||
|
||||
if ssim_score < ssim_threshold:
|
||||
return False
|
||||
|
||||
if not _are_actions_similar(
|
||||
step1.action_dict, step2.action_dict,
|
||||
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def detect_loop(
|
||||
full_trajectory: List[StepBehavior],
|
||||
image_width: int = 1920,
|
||||
image_height: int = 1080,
|
||||
N: int = 3,
|
||||
phash_threshold: int = 1,
|
||||
ssim_threshold: float = 0.99,
|
||||
relative_coord_threshold: float = 0.02,
|
||||
fuzzy_text_threshold: float = 85.0,
|
||||
) -> Tuple[bool, Optional[Dict[str, List[int]]]]:
|
||||
"""
|
||||
Efficiently detect the presence of looping patterns based on precomputed data.
|
||||
|
||||
Args:
|
||||
full_trajectory (List[StepBehavior]): Full history including the current step.
|
||||
image_width (int): Width of the screenshot.
|
||||
image_height (int): Height of the screenshot.
|
||||
N (int): Number of steps in the candidate loop (sequence length).
|
||||
phash_threshold (int): Hamming distance threshold for pHash similarity. Recommended: 0–2.
|
||||
ssim_threshold (float): SSIM similarity threshold for image comparison. Recommended: 0.95–0.99.
|
||||
relative_coord_threshold (float): Relative threshold for coordinate similarity. Recommended: 0.01–0.05.
|
||||
fuzzy_text_threshold (float): Fuzzy text matching similarity threshold (0–100) for agent queries.
|
||||
|
||||
Returns:
|
||||
A tuple (is_loop_detected, loop_info):
|
||||
- is_loop_detected (bool): Whether a loop is detected.
|
||||
- loop_info (Dict | None): If a loop is detected, contains the indices of the two matching sequences.
|
||||
"""
|
||||
L = len(full_trajectory)
|
||||
|
||||
if not isinstance(N, int) or N <= 0 or L < 2 * N:
|
||||
return False, None
|
||||
|
||||
max_start_index = L - 2 * N
|
||||
for i in range(max_start_index, -1, -1):
|
||||
is_potential_match = True
|
||||
|
||||
for j in range(N):
|
||||
idx_prev = i + j
|
||||
idx_curr = (L - N) + j
|
||||
|
||||
step_prev = full_trajectory[idx_prev]
|
||||
step_curr = full_trajectory[idx_curr]
|
||||
|
||||
if not _are_steps_similar_optimized(
|
||||
step_prev, step_curr, idx_prev, idx_curr, full_trajectory,
|
||||
phash_threshold, ssim_threshold,
|
||||
image_width, image_height, relative_coord_threshold, fuzzy_text_threshold
|
||||
):
|
||||
is_potential_match = False
|
||||
break
|
||||
|
||||
if is_potential_match:
|
||||
previous_sequence_indices = list(range(i, i + N))
|
||||
loop_info = {
|
||||
"match_sequence_indices": previous_sequence_indices
|
||||
}
|
||||
return True, loop_info
|
||||
|
||||
return False, None
|
||||
|
||||
30
mm_agents/os_symphony/utils/process_context.py
Executable file
30
mm_agents/os_symphony/utils/process_context.py
Executable file
@@ -0,0 +1,30 @@
|
||||
# process_context.py
|
||||
# This module provides an independent context storage for each process.
|
||||
|
||||
from multiprocessing import current_process
|
||||
|
||||
# We will store process-specific contexts here.
|
||||
# Since each process has its own separate memory space, when accessing this variable,
|
||||
# each process accesses its own copy, without conflicting with others.
|
||||
_context_storage = {}
|
||||
|
||||
def set_context(key, value):
|
||||
"""Set a value in the context of the current process."""
|
||||
_context_storage[key] = value
|
||||
# print(f"[{current_process().name}] Set context: {key} = {value}") # For debugging
|
||||
|
||||
def get_context(key, default=None):
|
||||
"""Retrieve a value from the context of the current process."""
|
||||
value = _context_storage.get(key, default)
|
||||
# print(f"[{current_process().name}] Get context: {key} -> {value}") # For debugging
|
||||
if value is None and default is None:
|
||||
raise NameError(f"'{key}' not found in the current process context. Ensure it is set at the process entry point.")
|
||||
return value
|
||||
|
||||
# For convenience, we can create a specialized getter for result_dir
|
||||
def get_current_result_dir():
|
||||
"""Get the result_dir specific to the current process."""
|
||||
return get_context('current_result_dir')
|
||||
|
||||
def set_current_result_dir(example_result_dir):
|
||||
set_context("current_result_dir", example_result_dir)
|
||||
Reference in New Issue
Block a user