Files
sci-gui-agent-benchmark/mm_agents/os_symphony/agents/searcher_agent.py
2025-12-23 14:30:44 +08:00

479 lines
19 KiB
Python
Executable File

import logging
import urllib.parse
from typing import Any, Dict, List, Optional
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
from mm_agents.os_symphony.utils.common_utils import (
draw_coordinates,
call_llm_formatted,
parse_code_from_string,
create_pyautogui_code
)
from mm_agents.os_symphony.core.mllm import LMMAgent
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
import os
import time
import json
logger = logging.getLogger("desktopenv.searcher_agent")
# Agent action decorator
def searcher_agent_action(func):
    """Tag ``func`` as a searcher-agent action and return it unchanged.

    The decorator does not wrap the callable; it only sets the
    ``is_searcher_agent_action`` attribute to ``True`` on the function object.
    """
    setattr(func, "is_searcher_agent_action", True)
    return func
# --- Abstract Base Class and Factory ---
class SearcherAgent:
    """Base class and factory for agents that look up tutorials for a main GUI agent.

    Concrete agents override :meth:`search`; :meth:`create` picks the
    implementation from ``engine_params["type"]``.
    """

    def __init__(self, engine_params: Dict, platform: str):
        """Store the configuration shared by every searcher implementation.

        Args:
            engine_params: Engine configuration. The optional ``"budget"`` key
                is read here (default 20).
            platform: Operating-system name kept on the instance for subclasses.
        """
        self.engine_params = engine_params
        # Parent directory for the per-search "search_<n>" folders. Starts empty;
        # nothing in this class assigns it, so the owner is expected to set it.
        self.result_dir = ""
        self.tutorial_or_hint = ""
        self.tutorial_notes = []
        self.max_trajectory_length = 8
        self.platform = platform
        self.budget = engine_params.get("budget", 20)

    @staticmethod
    def create(engine_params: Dict, search_env, grounder_agent: "GrounderAgent", platform: str, client_password: str = "password"):
        """Build the searcher selected by ``engine_params["type"]`` (default ``"vlm"``).

        Args:
            engine_params: Engine configuration, forwarded to the agent.
            search_env: Environment object, forwarded to the agent.
            grounder_agent: Grounding agent, forwarded to the agent.
            platform: Operating-system name, forwarded to the agent.
            client_password: Password forwarded to the agent.

        Returns:
            A ``VLMSearcherAgent`` when the type is ``"vlm"``.

        Raises:
            NotImplementedError: For any other searcher type.
        """
        searcher_type = engine_params.get("type", "vlm")
        if searcher_type == "vlm":
            return VLMSearcherAgent(
                engine_params=engine_params,
                search_env=search_env,
                grounder_agent=grounder_agent,
                platform=platform,
                client_password=client_password,
            )
        # Name the rejected type so a misconfiguration is diagnosable from the traceback.
        raise NotImplementedError(f"Unsupported searcher type: {searcher_type!r}")

    def _get_search_time(self) -> int:
        """Return the index to use for the next ``search_<n>`` result directory.

        Scans ``self.result_dir`` for sub-directories named ``search_<int>`` and
        returns the largest index plus one. Returns 1 when ``result_dir`` is
        empty, missing, unreadable, or holds no such directory.
        """
        if not self.result_dir:
            return 1
        try:
            entries = os.listdir(self.result_dir)
        except OSError:
            # Best effort: a missing or unreadable directory just means we start at 1.
            return 1

        prefix = "search_"
        used_indices = []
        for name in entries:
            if not name.startswith(prefix):
                continue
            if not os.path.isdir(os.path.join(self.result_dir, name)):
                continue
            try:
                used_indices.append(int(name[len(prefix):]))
            except ValueError:
                # e.g. "search_tmp": not one of our numbered folders.
                continue
        return max(used_indices, default=0) + 1

    def search(self, query: str, obs) -> Dict[str, Any]:
        """Search for a tutorial answering ``query``; must be overridden.

        Args:
            query: Format like "How to xxxx?", must be a detailed subtask.
            obs: Current screenshot observation of the main agent.

        Returns:
            The VLM implementation in this file returns a result dictionary
            (not a plain string).

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement the 'search' method")
class VLMSearcherAgent(SearcherAgent):
"""
Start a new, isolated vm, and open chrome in advance
"""
def __init__(self, engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str):
    """Bind the searcher to its environment, grounder and engine settings.

    Args:
        engine_params: Engine configuration; the ``"model"`` and ``"engine"``
            keys are read here, the rest is kept for later use.
        search_env: Environment object stored as ``self.env``.
        grounder_agent: Agent stored as ``self.grounder_agent``; action methods
            call its ``generate_coords``.
        platform: Operating-system name, passed to the base initializer.
        client_password: Password stored for the generated ``type`` command.
    """
    SearcherAgent.__init__(self, engine_params=engine_params, platform=platform)
    self.grounder_agent = grounder_agent
    self.client_password = client_password
    self.env = search_env

    # Models for which the LLM call is made with use_thinking=True.
    thinking_models = [
        "claude-opus-4-20250514",
        "claude-sonnet-4-20250514",
        "claude-3-7-sonnet-20250219",
        "claude-sonnet-4-5-20250929",
    ]
    self.use_thinking = engine_params.get("model", "") in thinking_models
    self.engine = engine_params.get("engine", "google")

    # Reuse OSWorld's initialization script to set up Chrome, then directly perform a
    # search using the query. The URL below is a placeholder that reset() overwrites.
    launch_chrome = {
        "type": "launch",
        "parameters": {"command": ["google-chrome", "--remote-debugging-port=1337"]},
    }
    forward_debug_port = {
        "type": "launch",
        "parameters": {"command": ["socat", "tcp-listen:9222,fork", "tcp:localhost:1337"]},
    }
    open_search_tab = {
        "type": "chrome_open_tabs",
        "parameters": {"urls_to_open": ["GOOGLE_SEARCH_URL"]},
    }
    focus_chrome = {
        "type": "activate_window",
        "parameters": {"window_name": "Google Chrome"},
    }
    self.task_config = {
        "id": "searcher",
        "instruction": "searcher",
        # Order matters: reset() rewrites the URL of the entry at index 2.
        "config": [launch_chrome, forward_debug_port, open_search_tab, focus_chrome],
        "proxy": True,
    }
    self.obs = None
def reset(self, query):
    """Prepare a fresh search session for ``query``.

    Clears the collected notes and result, rebuilds the system prompt and the
    LLM wrapper, starts the environment, points the Chrome tab at the search
    results page for ``query``, resets the environment with that task config,
    and then sleeps for five seconds.

    Args:
        query: The search query; substituted into the prompt and URL-encoded
            into the search URL.
    """
    # When the search function is invoked, a new agent is created; the environment is
    # instantiated only upon the first call, but it must be reset on every invocation.
    self.tutorial_notes = []
    self.tutorial_or_hint = ""

    prompt_template = PROCEDURAL_MEMORY.construct_vlm_searcher_procedural_memory(
        agent_class=type(self)
    )
    self.system_prompt = prompt_template.replace("CURRENT_OS", self.platform).replace("QUERY", query)
    self.searcher_agent = LMMAgent(
        engine_params=self.engine_params,
        system_prompt=self.system_prompt,
    )
    self.env.start()

    # config URL and initialize search environment (google/duckduckgo)
    encoded_query = urllib.parse.quote_plus(query)
    if self.engine == "google":
        search_url = "https://www.google.com/search?q=" + encoded_query
    else:
        search_url = "https://www.duckduckgo.com/?q=" + encoded_query
    # Index 2 is the "chrome_open_tabs" entry of the task config.
    self.task_config["config"][2]["parameters"]["urls_to_open"][0] = search_url
    self.env.reset(task_config=self.task_config)

    print("[Searcher] sleeping...")
    time.sleep(5)
def flush_messages(self):
    """Trim the searcher LLM's message history to fit the context limits.

    Two strategies, selected by ``engine_params["engine_type"]``:

    * ``anthropic`` / ``openai`` / ``gemini``: keep all text, but keep only the
      newest ``max_trajectory_length`` images among messages at index 2 and
      above. Messages 0 and 1 are never touched, so the main agent's initial
      screenshot survives.
    * any other engine: once the history is longer than system prompt +
      initial context message + ``max_trajectory_length`` rounds, drop the
      oldest user/assistant round (indices 2 and 3). Indices 0 and 1 are kept
      so the main-agent context is preserved, consistent with the image
      strategy, and the history never starts with an orphaned assistant turn.

    Side Effects:
        Mutates ``self.searcher_agent.messages`` in place. Does nothing when
        ``self.searcher_agent`` is ``None``.
    """
    engine_type = self.engine_params.get("engine_type", "")
    agent = self.searcher_agent
    if agent is None:
        return
    messages = agent.messages

    # Flush strategy for long-context models: keep all text, only keep latest images
    if engine_type in ["anthropic", "openai", "gemini"]:
        max_images = self.max_trajectory_length
        img_count = 0
        # keep latest k images
        # @Yang: keep the first main agent image (the range stops before index 1)
        for i in range(len(messages) - 1, 1, -1):
            content = messages[i]["content"]
            # Walk backwards so deleting an entry does not shift unvisited indices.
            for j in range(len(content) - 1, -1, -1):
                if "image" in content[j].get("type", ""):
                    img_count += 1
                    if img_count > max_images:
                        del content[j]
    # Flush strategy for non-long-context models: drop full turns
    else:
        # Layout: [system, initial context, (user, assistant) per round].
        if len(messages) > 2 * self.max_trajectory_length + 2:
            del messages[2:4]
def assign_screenshot(self, obs):
    """Remember ``obs`` as the observation the grounding-based actions work on.

    Args:
        obs: Observation object; stored on ``self.obs`` and later passed to
            ``grounder_agent.generate_coords`` by the action methods.
    """
    self.obs = obs
def search(self, query: str, main_obs):
    """Run one tutorial search and return its outcome.

    Resets the search environment for ``query``, then loops for at most
    ``self.budget`` steps: ask the LLM for a plan, turn it into an action,
    record the step under ``<result_dir>/search_<n>/`` and execute it. The loop
    ends early when the action is ``"DONE"`` or ``"FAIL"``.

    Args:
        query: Format like "How to xxxx?", must be a detailed subtask.
        main_obs: Observation of the main agent; its ``"screenshot"`` entry is
            sent to the LLM as context.

    Returns:
        A dict with the keys ``query``, ``completion_reason`` (``"DONE"``,
        ``"FAIL"`` or ``"BUDGET_EXHAUSTED"``), ``tutorial_notes``,
        ``execution_history``, ``steps_executed``, ``budget`` and
        ``final_answer``. ``steps_executed`` counts only steps sent to the
        environment; the terminating DONE/FAIL step is not included.
    """
    # only create vm when search is called
    self.reset(query=query)

    search_result_dir = os.path.join(self.result_dir, f"search_{self._get_search_time()}")
    os.makedirs(search_result_dir, exist_ok=True)

    obs = self.env._get_obs()  # Get the initial observation
    step_idx = 0

    initial_state_text = (
        "This screenshot shows the current visual context of the main GUI Agent you are assisting. "
        "Use this image to understand the application, the current view, and the overall environment. "
        "Your primary goal is to find a tutorial that is highly relevant and well-aligned with this specific context, "
        "ensuring the instructions you find are applicable to what the main agent is currently seeing."
    )
    self.searcher_agent.add_message(
        text_content=initial_state_text,
        image_content=main_obs["screenshot"],
        role="user",
    )

    execution_history = []
    completion_reason = ""
    final_answer = ""

    # Honor the correctly spelled key; the misspelled "temperture" key the code used to
    # read is still accepted so existing configurations keep working.
    temperature = self.engine_params.get(
        "temperature", self.engine_params.get("temperture", 0.1)
    )

    while step_idx < self.budget:
        # update system_prompt dynamically
        tutorial_notes_str = "".join(
            f"Tutorial Note {i}: {note}\n\n"
            for i, note in enumerate(self.tutorial_notes, 1)
        )

        if step_idx == self.budget - 1:
            # eager mode: last step of the budget switches to the eager prompt
            self.system_prompt = PROCEDURAL_MEMORY.construct_searcher_eager_mode_procedural_memory(
                agent_class=type(self)
            ).replace("CURRENT_OS", self.platform).replace("QUERY", query)

        system_prompt = self.system_prompt.replace("TUTORIAL_PLACEHOLDER", tutorial_notes_str)
        self.searcher_agent.add_system_prompt(system_prompt=system_prompt)

        # start a new turn
        self.assign_screenshot(obs=obs)
        self.searcher_agent.add_message(
            "", image_content=obs["screenshot"], role="user"
        )

        # predict action (no format checkers are applied)
        plan = call_llm_formatted(
            self.searcher_agent,
            [],
            temperature=temperature,
            use_thinking=self.use_thinking,
        )
        self.searcher_agent.add_message(plan, role="assistant")
        execution_history.append(plan)
        logger.info("SEARCHER PLAN:\n %s", plan)

        plan_code = parse_code_from_string(plan)
        try:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if not plan_code:
                raise ValueError("Plan code should not be empty")
            # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
            exec_code, coords = create_pyautogui_code(self, plan_code, obs)
        except Exception as e:
            logger.error(
                "Could not evaluate the following plan code:\n%s\nError: %s", plan_code, e
            )
            # Skip a turn if the code cannot be evaluated
            exec_code = self.wait(1.333)
            # No coordinates belong to this step. Without this reset the code below
            # would hit an unbound name on a first-step failure, or reuse the
            # previous step's coordinates.
            coords = None

        self.flush_messages()

        # execute action
        action = exec_code
        logger.info("Step %d: %s", step_idx + 1, action)

        # Save screenshot and trajectory information
        step_no = step_idx + 1
        screenshot_name = f"step_{step_no}.png"
        with open(os.path.join(search_result_dir, screenshot_name), "wb") as _f:
            _f.write(obs["screenshot"])

        if coords is not None and isinstance(coords, list):
            draw_coordinates(
                image_bytes=obs["screenshot"],
                coordinates=coords,
                save_path=os.path.join(search_result_dir, f"step_{step_no}_draw.png"),
            )

        # One record, written both to the cumulative JSONL log and to a per-step file.
        record = {
            "query": query,
            "step_num": step_no,
            "action": action,
            "response": {
                "plan": plan,
                "plan_code": plan_code,
                "coordinates": coords,
            },
            "screenshot_file": screenshot_name,
        }
        with open(os.path.join(search_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False))
            f.write("\n")
        with open(os.path.join(search_result_dir, f"traj_{step_no}.json"), "w", encoding="utf-8") as f:
            json.dump(record, f, indent=4, ensure_ascii=False)

        if exec_code in ["DONE", "FAIL"]:
            # terminate loop
            completion_reason = exec_code
            final_answer = self.tutorial_or_hint
            break

        obs, _, _, _ = self.env.step(action, 5)
        step_idx += 1

    if completion_reason == "":
        completion_reason = "BUDGET_EXHAUSTED"
        final_answer = "Sorry, can't get the useful tutorial about the GUI task you provided."

    return {
        "query": query,
        "completion_reason": completion_reason,
        "tutorial_notes": self.tutorial_notes,
        "execution_history": execution_history,
        "steps_executed": step_idx,
        "budget": self.budget,
        "final_answer": final_answer,
    }
@searcher_agent_action
def click(
    self,
    element_description: str,
    num_clicks: int = 1,
    button_type: str = "left",
):
    """Click on the element
    Args:
        element_description:str, a detailed descriptions of which element to click on. This description should be at least a full sentence.
        num_clicks:int, number of times to click the element
        button_type:str, which mouse button to press can be "left", "middle", or "right"
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # Returns a (command, [x, y]) tuple: the pyautogui snippet to execute and the
    # coordinates the grounder resolved for `element_description` on self.obs.
    x, y = self.grounder_agent.generate_coords(element_description, self.obs)
    # The import is emitted once; the generated command used to contain it twice.
    command = (
        "import pyautogui; "
        f"pyautogui.click({x}, {y}, clicks={num_clicks}, button={repr(button_type)}); "
    )
    return (command, [x, y])
@searcher_agent_action
def type(
    self,
    element_description: Optional[str] = None,
    text: str = "",
    overwrite: bool = True,
    enter: bool = False
):
    """Type text/unicode into a specific element
    Args:
        element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
        text:str, the text to type
        overwrite:bool, Default is True, assign it to False if the text should not overwrite the existing text. Using this argument clears all text in an element.
        enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # Returns the generated command string, or a (command, [x, y]) tuple when an
    # element was grounded and clicked first.

    # The sudo password is handed to `sudo -S` on stdin as a repr()-quoted Python
    # literal. It used to be spliced raw into an `echo "<password>" | sudo -S ...`
    # shell string inside a generated f-string, so quotes, braces or shell
    # metacharacters in the password broke the generated code or were run by the shell.
    sudo_stdin = repr(f"{self.client_password}\n")
    commands = (
        "import os;"
        "import pyautogui;"
        "import pyperclip;"
        "import subprocess;"
        "import time;"
        "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
        "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
        "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
        # This literal is deliberately NOT an f-string: `{proxy_prefix}` must reach the
        # generated code untouched, where it is evaluated.
        "subprocess.run(f'sudo -S {proxy_prefix}apt-get install -y xclip xsel', "
        f"shell=True, check=True, input={sudo_stdin}, text=True);"
    )

    click_coords = None
    if element_description is not None:
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        click_coords = [x, y]
        commands += f"pyautogui.click({x}, {y});"

    if overwrite:
        commands += (
            "pyautogui.hotkey('ctrl', 'a');"
            "pyautogui.press('backspace');"
        )

    # use paste to input; the previous clipboard content is put back afterwards
    commands += (
        "original_clipboard = pyperclip.paste();"
        f"pyperclip.copy({repr(text)});"
        "pyautogui.hotkey('ctrl', 'v');"
        "pyperclip.copy(original_clipboard);"
    )

    if enter:
        commands += "pyautogui.press('enter');"

    if click_coords is not None:
        return (commands, click_coords)
    return commands
@searcher_agent_action
def scroll(self, element_description: str, clicks: int, shift: bool = False):
    """Scroll the element in the specified direction
    Args:
        element_description:str, a very detailed description of which element to enter scroll in. This description should be at least a full sentence.
        clicks:int, the number of clicks to scroll can be positive (up) or negative (down).
        shift:bool, whether to use shift+scroll for horizontal scrolling
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # Returns (command, [x, y]); the command moves the pointer to the grounded
    # coordinates, waits half a second, then scrolls.
    x, y = self.grounder_agent.generate_coords(element_description, self.obs)
    # shift=True selects pyautogui's horizontal scroll, otherwise the vertical one.
    scroll_call = "hscroll" if shift else "vscroll"
    command = (
        f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); "
        f"time.sleep(0.5); pyautogui.{scroll_call}({clicks})"
    )
    return (command, [x, y])
@searcher_agent_action
def hotkey(self, keys: List):
    """Press a hotkey combination (can press a single key as well)
    Args:
        keys: List the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # Returns the pyautogui snippet as a string. Keys are quoted with repr() rather
    # than hand-written quotes, so keys such as "'" or "\\" still yield valid Python;
    # the output for ordinary key names is unchanged.
    quoted_keys = ", ".join(repr(str(key)) for key in keys)
    return f"import pyautogui; pyautogui.hotkey({quoted_keys})"
@searcher_agent_action
def save_to_tutorial_notes(self, text: str):
    """Save high quality and useful information to a long-term knowledge bank for reuse during this search task.
    Args:
        text:str, the text to save to the tutorial notes
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # The note is appended to self.tutorial_notes; the returned action is the
    # literal "WAIT", so nothing is clicked or typed for this step.
    notes = self.tutorial_notes
    notes.append(text)
    return "WAIT"
@searcher_agent_action
def wait(self, time: float):
    """Wait for a specified amount of time
    Args:
        time:float the amount of time to wait in seconds
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # Nothing sleeps here: the method only returns the snippet to execute. The
    # parameter is named `time`, which is unrelated to the `time` module that the
    # generated snippet imports.
    return "import time; time.sleep({})".format(time)
@searcher_agent_action
def done(self, tutorial: str):
    """End the current task with a success. Use this when you believe the entire task has been fully completed.
    Args:
        tutorial:str, A detailed, step-by-step tutorial compiled from the search results to be passed to the main agent.
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # The tutorial is stored on the instance and the sentinel "DONE" is returned.
    self.tutorial_or_hint = tutorial
    return "DONE"
@searcher_agent_action
def fail(self, hint: str):
    """End the current task with a failure. Use this when you believe the entire task is impossible to complete.
    Args:
        hint:str, A hint or reason explaining why the search failed, or what kind of information was missing.
    """
    # NOTE(review): the docstring above is kept verbatim because it presumably is
    # rendered into the LLM prompt by PROCEDURAL_MEMORY — confirm before rewording.
    #
    # The hint is stored on the instance and the sentinel "FAIL" is returned.
    self.tutorial_or_hint = hint
    return "FAIL"