add_os_symphony (#399)
This commit is contained in:
478
mm_agents/os_symphony/agents/searcher_agent.py
Executable file
478
mm_agents/os_symphony/agents/searcher_agent.py
Executable file
@@ -0,0 +1,478 @@
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Optional
|
||||
from mm_agents.os_symphony.memory.procedural_memory import PROCEDURAL_MEMORY
|
||||
from mm_agents.os_symphony.utils.common_utils import (
|
||||
draw_coordinates,
|
||||
call_llm_formatted,
|
||||
parse_code_from_string,
|
||||
create_pyautogui_code
|
||||
)
|
||||
from mm_agents.os_symphony.core.mllm import LMMAgent
|
||||
from mm_agents.os_symphony.agents.grounder_agent import GrounderAgent
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
|
||||
|
||||
# Module-level logger for the searcher agent; a child of the "desktopenv" logger.
logger = logging.getLogger("desktopenv.searcher_agent")
|
||||
|
||||
# Agent action decorator
|
||||
def searcher_agent_action(func):
    """Mark ``func`` as an action of the searcher agent.

    The callable is tagged in place with ``is_searcher_agent_action = True``
    and the very same object is returned (no wrapper is created).
    """
    setattr(func, "is_searcher_agent_action", True)
    return func
|
||||
|
||||
|
||||
# --- Abstract Base Class and Factory ---
|
||||
class SearcherAgent:
    """Base class and factory for searcher agents.

    Holds the state shared by every searcher (engine parameters, step budget,
    collected tutorial notes, output directory) and provides ``create`` to
    build the concrete implementation selected by ``engine_params["type"]``.
    """

    def __init__(self, engine_params: Dict, platform: str):
        """Store the shared configuration.

        Args:
            engine_params: LLM/engine configuration; ``"budget"`` (default 20)
                is read here as the step budget.
            platform: Name of the OS the main agent is running on.
        """
        self.engine_params = engine_params
        # Directory under which per-search "search_<n>" folders are created.
        self.result_dir = ""
        # Final tutorial (on success) or hint (on failure) of the last search.
        self.tutorial_or_hint = ""
        self.tutorial_notes = []
        # Upper bound on retained history when subclasses flush messages.
        self.max_trajectory_length = 8
        self.platform = platform
        self.budget = engine_params.get("budget", 20)

    @staticmethod
    def create(
        engine_params: Dict,
        search_env,
        grounder_agent: "GrounderAgent",
        platform: str,
        client_password: str = "password",
    ):
        """Build the searcher selected by ``engine_params["type"]``.

        Args:
            engine_params: Engine configuration; ``"type"`` defaults to ``"vlm"``.
            search_env: Environment the searcher will drive.
            grounder_agent: Agent used to turn element descriptions into coordinates.
            platform: Name of the OS the main agent is running on.
            client_password: Password forwarded to the concrete searcher.

        Raises:
            NotImplementedError: If the requested searcher type is unknown.
        """
        searcher_type = engine_params.get("type", "vlm")
        if searcher_type == "vlm":
            return VLMSearcherAgent(
                engine_params=engine_params,
                search_env=search_env,
                grounder_agent=grounder_agent,
                platform=platform,
                client_password=client_password,
            )
        raise NotImplementedError(f"Unsupported searcher type: {searcher_type!r}")

    def _get_search_time(self) -> int:
        """Return the 1-based index to use for the next ``search_<n>`` directory.

        Scans ``self.result_dir`` for sub-directories named ``search_<int>``
        and returns ``max + 1``. Returns 1 when ``result_dir`` is unset, does
        not exist, cannot be listed, or contains no such directory.
        """
        if not self.result_dir or not os.path.exists(self.result_dir):
            return 1

        prefix = "search_"
        search_times: List[int] = []
        try:
            for item_name in os.listdir(self.result_dir):
                if not item_name.startswith(prefix):
                    continue
                if not os.path.isdir(os.path.join(self.result_dir, item_name)):
                    continue
                try:
                    search_times.append(int(item_name[len(prefix):]))
                except ValueError:
                    # Suffix is not an integer (e.g. "search_tmp"); ignore it.
                    continue
        except OSError as exc:
            # Best effort: numbering falls back to 1 if the directory is unreadable.
            logger.warning("Could not list result dir %r: %s", self.result_dir, exc)
            return 1

        return max(search_times, default=0) + 1

    def search(self, query: str, obs) -> Dict[str, Any]:
        """
        Args:
            query: Format like "How to xxxx?", must be a detailed subtask
            obs: Current screenshot

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError("Subclasses must implement the 'search' method")
|
||||
|
||||
class VLMSearcherAgent(SearcherAgent):
    """
    Start a new, isolated vm, and open chrome in advance

    Each call to ``search`` resets the search environment so that Chrome opens
    on a results page for the query, then repeats "screenshot -> LLM plan ->
    pyautogui code -> env.step" until the model calls ``done``/``fail`` or the
    step budget is used up.

    NOTE(review): the signatures and docstrings of the methods decorated with
    ``@searcher_agent_action`` are presumably surfaced to the model by
    ``PROCEDURAL_MEMORY`` when the system prompt is built, so they are kept
    verbatim; confirm before rewording them.
    """

    def __init__(self, engine_params: Dict, search_env, grounder_agent: GrounderAgent, platform: str, client_password: str):
        """Store collaborators and prepare the environment task template.

        Args:
            engine_params: Engine configuration; ``"model"`` and ``"engine"``
                (search engine, default ``"google"``) are read here.
            search_env: Environment used for browsing; must provide ``start``,
                ``reset``, ``step`` and ``_get_obs``.
            grounder_agent: Converts element descriptions into screen coordinates.
            platform: Name of the OS the main agent is running on.
            client_password: sudo password embedded in the generated ``type`` command.
        """
        SearcherAgent.__init__(self, engine_params=engine_params, platform=platform)

        self.grounder_agent = grounder_agent
        self.client_password = client_password
        self.env = search_env

        self.use_thinking = engine_params.get("model", "") in [
            "claude-opus-4-20250514",
            "claude-sonnet-4-20250514",
            "claude-3-7-sonnet-20250219",
            "claude-sonnet-4-5-20250929",
        ]

        self.engine = engine_params.get("engine", "google")

        # Reuse OSWorld's initialization script to set up Chrome, then directly
        # open a search-results page for the query. The URL below is a
        # placeholder that ``reset`` overwrites on every search.
        self.task_config = {
            "id": "searcher",
            "instruction": "searcher",
            "config": [
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "google-chrome",
                            "--remote-debugging-port=1337"
                        ]
                    }
                },
                {
                    "type": "launch",
                    "parameters": {
                        "command": [
                            "socat",
                            "tcp-listen:9222,fork",
                            "tcp:localhost:1337"
                        ]
                    }
                },
                {
                    "type": "chrome_open_tabs",
                    "parameters": {
                        "urls_to_open": [
                            "GOOGLE_SEARCH_URL"
                        ]
                    }
                },
                {
                    "type": "activate_window",
                    "parameters": {
                        "window_name": "Google Chrome"
                    }
                }
            ],
            "proxy": True
        }
        self.obs = None
        # Both are (re)built by ``reset`` at the start of every search.
        self.system_prompt = ""
        self.searcher_agent = None

    def reset(self, query):
        """Prepare a fresh LLM agent and search environment for ``query``.

        Clears notes from the previous search, rebuilds the system prompt,
        creates a new ``LMMAgent``, starts and resets the environment so that
        the browser opens on the search URL, then waits 5 seconds.
        """
        # A new LLM agent is created on every search; the environment is
        # started and reset on every invocation as well.
        self.tutorial_notes = []
        self.tutorial_or_hint = ""
        self.system_prompt = PROCEDURAL_MEMORY.construct_vlm_searcher_procedural_memory(
            agent_class=type(self)
        ).replace("CURRENT_OS", self.platform).replace("QUERY", query)
        self.searcher_agent = LMMAgent(
            engine_params=self.engine_params,
            system_prompt=self.system_prompt
        )
        self.env.start()

        # Any engine other than "google" falls back to DuckDuckGo.
        encoded_query = urllib.parse.quote_plus(query)
        if self.engine == "google":
            search_url = "https://www.google.com/search?q=" + encoded_query
        else:
            search_url = "https://www.duckduckgo.com/?q=" + encoded_query

        # Locate the tab-opening step by type rather than by list position.
        for setup_step in self.task_config["config"]:
            if setup_step.get("type") == "chrome_open_tabs":
                setup_step["parameters"]["urls_to_open"][0] = search_url
                break

        self.env.reset(task_config=self.task_config)
        logger.info("[Searcher] sleeping...")
        time.sleep(5)

    def flush_messages(self):
        """Trim the searcher agent's message history in place.

        For ``engine_type`` in anthropic/openai/gemini all text is kept and
        only the newest ``max_trajectory_length`` images are retained (messages
        at index 0 and 1 are never touched). For other engines one full
        user/assistant round is dropped once the history exceeds
        ``2 * max_trajectory_length + 1`` messages.
        """
        agent = self.searcher_agent
        if agent is None:
            return

        engine_type = self.engine_params.get("engine_type", "")

        # Flush strategy for long-context models: keep all text, only keep latest images
        if engine_type in ["anthropic", "openai", "gemini"]:
            max_images = self.max_trajectory_length
            img_count = 0
            # Walk newest -> oldest so deletions do not shift unvisited indices.
            # Stopping before index 1 keeps the first main-agent image.
            for i in range(len(agent.messages) - 1, 1, -1):
                content = agent.messages[i]["content"]
                if not isinstance(content, list):
                    continue
                for j in range(len(content) - 1, -1, -1):
                    if "image" in content[j].get("type", ""):
                        img_count += 1
                        if img_count > max_images:
                            del content[j]

        # Flush strategy for non-long-context models: drop full turns
        else:
            # messages alternate [user, assistant], so 2 per round
            if len(agent.messages) > 2 * self.max_trajectory_length + 1:
                del agent.messages[1:3]

    def assign_screenshot(self, obs):
        """Remember ``obs`` as the observation used for grounding actions."""
        self.obs = obs

    def _save_step_artifacts(self, search_result_dir, step_num, screenshot, coords, record):
        """Write the screenshot, optional coordinate overlay and trajectory record of one step."""
        with open(os.path.join(search_result_dir, f"step_{step_num}.png"), "wb") as image_file:
            image_file.write(screenshot)

        if isinstance(coords, list):
            draw_coordinates(
                image_bytes=screenshot,
                coordinates=coords,
                save_path=os.path.join(search_result_dir, f"step_{step_num}_draw.png")
            )

        with open(os.path.join(search_result_dir, "traj.jsonl"), "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False))
            f.write("\n")

        with open(os.path.join(search_result_dir, f"traj_{step_num}.json"), "w", encoding="utf-8") as f:
            json.dump(record, f, indent=4, ensure_ascii=False)

    def search(self, query: str, main_obs) -> Dict[str, Any]:
        """Run one web search session for ``query``.

        Args:
            query: Format like "How to xxxx?", must be a detailed subtask
            main_obs: Observation of the main agent; its ``"screenshot"`` is
                shown to the model as context.

        Returns:
            A dict with keys ``query``, ``completion_reason`` (``"DONE"``,
            ``"FAIL"`` or ``"BUDGET_EXHAUSTED"``), ``tutorial_notes``,
            ``execution_history``, ``steps_executed``, ``budget`` and
            ``final_answer``.
        """
        # only create vm when search is called
        self.reset(query=query)
        search_result_dir = os.path.join(self.result_dir, f"search_{self._get_search_time()}")
        os.makedirs(search_result_dir, exist_ok=True)

        obs = self.env._get_obs()  # Get the initial observation
        step_idx = 0
        initial_state_text = (
            "This screenshot shows the current visual context of the main GUI Agent you are assisting. "
            "Use this image to understand the application, the current view, and the overall environment. "
            "Your primary goal is to find a tutorial that is highly relevant and well-aligned with this specific context, "
            "ensuring the instructions you find are applicable to what the main agent is currently seeing."
        )
        self.searcher_agent.add_message(
            text_content=initial_state_text,
            image_content=main_obs["screenshot"],
            role="user"
        )
        execution_history = []
        completion_reason = ""
        final_answer = ""

        # Prefer the correctly spelled key; "temperture" is kept as a fallback
        # so existing configurations continue to work.
        temperature = self.engine_params.get(
            "temperature", self.engine_params.get("temperture", 0.1)
        )

        while step_idx < self.budget:
            # update system_prompt dynamically
            tutorial_notes_str = "".join(
                f"Tutorial Note {i}: {note}\n\n"
                for i, note in enumerate(self.tutorial_notes, 1)
            )

            if step_idx == self.budget - 1:
                # eager mode: last step of the budget uses a different prompt
                self.system_prompt = PROCEDURAL_MEMORY.construct_searcher_eager_mode_procedural_memory(
                    agent_class=type(self)
                ).replace("CURRENT_OS", self.platform).replace("QUERY", query)

            system_prompt = self.system_prompt.replace("TUTORIAL_PLACEHOLDER", tutorial_notes_str)
            self.searcher_agent.add_system_prompt(system_prompt=system_prompt)

            # start a new turn
            self.assign_screenshot(obs=obs)
            self.searcher_agent.add_message(
                "", image_content=obs["screenshot"], role="user"
            )

            # predict action (no format checkers are applied)
            plan = call_llm_formatted(
                self.searcher_agent,
                [],
                temperature=temperature,
                use_thinking=self.use_thinking,
            )

            self.searcher_agent.add_message(plan, role="assistant")
            execution_history.append(plan)
            logger.info("SEARCHER PLAN:\n %s", plan)

            plan_code = parse_code_from_string(plan)
            # Reset every step so a failed evaluation never reuses stale
            # coordinates (or hits an unbound name on the first step).
            coords = None
            try:
                if not plan_code:
                    raise ValueError("Plan code should not be empty")
                # exec_code e.g. import pyautogui; pyautogui.click(1, 2);
                exec_code, coords = create_pyautogui_code(self, plan_code, obs)
            except Exception as e:
                logger.error(
                    f"Could not evaluate the following plan code:\n{plan_code}\nError: {e}"
                )
                coords = None
                # Skip a turn if the code cannot be evaluated
                exec_code = self.wait(1.333)

            self.flush_messages()

            # execute action
            action = exec_code
            step_num = step_idx + 1
            logger.info("Step %d: %s", step_num, action)

            # Save screenshot and trajectory information
            record = {
                "query": query,
                "step_num": step_num,
                "action": action,
                "response": {
                    "plan": plan,
                    "plan_code": plan_code,
                    "coordinates": coords
                },
                "screenshot_file": f"step_{step_num}.png"
            }
            self._save_step_artifacts(
                search_result_dir, step_num, obs['screenshot'], coords, record
            )

            if exec_code in ["DONE", "FAIL"]:
                # terminate loop
                completion_reason = exec_code
                final_answer = self.tutorial_or_hint
                break

            obs, _, _, _ = self.env.step(action, 5)
            step_idx += 1

        if completion_reason == "":
            completion_reason = "BUDGET_EXHAUSTED"
            final_answer = "Sorry, can't get the useful tutorial about the GUI task you provided."

        return {
            "query": query,
            "completion_reason": completion_reason,
            "tutorial_notes": self.tutorial_notes,
            "execution_history": execution_history,
            "steps_executed": step_idx,
            "budget": self.budget,
            "final_answer": final_answer,
        }

    @searcher_agent_action
    def click(
        self,
        element_description: str,
        num_clicks: int = 1,
        button_type: str = "left",
    ):
        """Click on the element
        Args:
            element_description:str, a detailed descriptions of which element to click on. This description should be at least a full sentence.
            num_clicks:int, number of times to click the element
            button_type:str, which mouse button to press can be "left", "middle", or "right"
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)
        command = f"""import pyautogui; pyautogui.click({x}, {y}, clicks={num_clicks}, button={button_type!r}); """

        # Return pyautoguicode to click on the element
        return (command, [x, y])

    @searcher_agent_action
    def type(
        self,
        element_description: Optional[str] = None,
        text: str = "",
        overwrite: bool = True,
        enter: bool = False
    ):
        """Type text/unicode into a specific element
        Args:
            element_description:str, a detailed description of which element to enter text in. This description should be at least a full sentence.
            text:str, the text to type
            overwrite:bool, Default is True, assign it to False if the text should not overwrite the existing text. Using this argument clears all text in an element.
            enter:bool, Assign it to True if the enter key should be pressed after typing the text, otherwise assign it to False.
        """
        # NOTE(review): the sudo password is interpolated verbatim into a shell
        # command; a password containing quotes or braces would break (or alter)
        # the generated code. Left unchanged to keep the emitted command stable.
        commands = (
            "import os;"
            "import pyautogui;"
            "import pyperclip;"
            "import subprocess;"
            "import time;"
            "p_http = os.environ.get('http_proxy') or os.environ.get('HTTP_PROXY');"
            "p_https = os.environ.get('https_proxy') or os.environ.get('HTTPS_PROXY');"
            "proxy_prefix = (f'http_proxy={p_http} ' if p_http else '') + (f'https_proxy={p_https} ' if p_https else '');"
            f"subprocess.run(f'echo \"{self.client_password}\" | sudo -S {{proxy_prefix}}apt-get install -y xclip xsel', shell=True, check=True);"
        )

        click_coords = None
        if element_description is not None:
            x, y = self.grounder_agent.generate_coords(element_description, self.obs)
            click_coords = [x, y]
            commands += f"pyautogui.click({x}, {y});"

        if overwrite:
            commands += (
                "pyautogui.hotkey('ctrl', 'a');"
                "pyautogui.press('backspace');"
            )

        # use paste to input, restoring the previous clipboard afterwards
        commands += (
            "original_clipboard = pyperclip.paste();"
            f"pyperclip.copy({text!r});"
            "pyautogui.hotkey('ctrl', 'v');"
            "pyperclip.copy(original_clipboard);"
        )

        if enter:
            commands += "pyautogui.press('enter');"

        # Coordinates are only reported when an element was clicked first.
        if click_coords is not None:
            return (commands, click_coords)
        return commands

    @searcher_agent_action
    def scroll(self, element_description: str, clicks: int, shift: bool = False):
        """Scroll the element in the specified direction
        Args:
            element_description:str, a very detailed description of which element to enter scroll in. This description should be at least a full sentence.
            clicks:int, the number of clicks to scroll can be positive (up) or negative (down).
            shift:bool, whether to use shift+scroll for horizontal scrolling
        """
        x, y = self.grounder_agent.generate_coords(element_description, self.obs)

        if shift:
            return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.hscroll({clicks})", [x, y])
        return (f"import pyautogui; import time; pyautogui.moveTo({x}, {y}); time.sleep(0.5); pyautogui.vscroll({clicks})", [x, y])

    @searcher_agent_action
    def hotkey(self, keys: List):
        """Press a hotkey combination (can press a single key as well)
        Args:
            keys: List the keys to press in combination in a list format (e.g. ['ctrl', 'c'], ['enter'])
        """
        # repr() quotes each key safely, even if it contains a quote character
        quoted_keys = [repr(str(key)) for key in keys]
        return f"import pyautogui; pyautogui.hotkey({', '.join(quoted_keys)})"

    @searcher_agent_action
    def save_to_tutorial_notes(self, text: str):
        """Save high quality and useful information to a long-term knowledge bank for reuse during this search task.
        Args:
            text:str, the text to save to the tutorial notes
        """
        self.tutorial_notes.append(text)
        return """WAIT"""

    @searcher_agent_action
    def wait(self, time: float):
        """Wait for a specified amount of time
        Args:
            time:float the amount of time to wait in seconds
        """
        return f"""import time; time.sleep({time})"""

    @searcher_agent_action
    def done(
        self,
        tutorial: str
    ):
        """End the current task with a success. Use this when you believe the entire task has been fully completed.
        Args:
            tutorial:str, A detailed, step-by-step tutorial compiled from the search results to be passed to the main agent.
        """
        self.tutorial_or_hint = tutorial
        return """DONE"""

    @searcher_agent_action
    def fail(
        self,
        hint: str
    ):
        """End the current task with a failure. Use this when you believe the entire task is impossible to complete.
        Args:
            hint:str, A hint or reason explaining why the search failed, or what kind of information was missing.
        """
        self.tutorial_or_hint = hint
        return """FAIL"""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user