* Added a **pyproject.toml** file to define project metadata and dependencies.
* Added **run_maestro.py** and **osworld_run_maestro.py** to provide the main execution logic.
* Introduced multiple new modules, including **Evaluator**, **Controller**, **Manager**, and **Sub-Worker**, supporting task planning, state management, and data analysis.
* Added a **tools module** containing utility functions and tool configurations to improve code reusability.
* Updated the **README** and documentation with usage examples and module descriptions.

These changes lay the foundation for expanding the Maestro project's functionality and improving the user experience.

Co-authored-by: Hiroid <guoliangxuan@deepmatrix.com>
"""
|
|
Tools module for GUI agents.
|
|
|
|
This module provides various tools for GUI agents to perform tasks such as web search,
|
|
context fusion, subtask planning, trajectory reflection, memory retrieval, grounding,
|
|
evaluation, and action generation.
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import base64
|
|
import requests
|
|
import time
|
|
from typing import Dict, Any, Optional, List, Union, Tuple
|
|
from abc import ABC, abstractmethod
|
|
import logging
|
|
from ..core.mllm import LLMAgent, WebSearchAgent, EmbeddingAgent
|
|
import threading
|
|
from ..prompts import get_prompt, module
|
|
|
|
logger = logging.getLogger("desktopenv.tools")
|
|
|
|


class BaseTool(ABC):
    """Base class for all tools."""

    _prompts_dict = None
    _prompts_dict_lock = threading.Lock()
    # Directory retained for backward compatibility; no longer scanned directly
    _prompts_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompts")

    @classmethod
    def _load_prompts_dict(cls):
        # Deprecated: kept for compatibility if other code accesses _prompts_dict.
        # Prompts are now pulled via the registry to avoid direct filesystem coupling.
        if cls._prompts_dict is None:
            with cls._prompts_dict_lock:
                if cls._prompts_dict is None:
                    cls._prompts_dict = {}

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Initialize the base tool.

        Args:
            provider: API provider name (e.g., "gemini", "openai")
            model_name: Model name to use (e.g., "gemini-2.5-pro")
            tool_name: Name of the tool (used as key in the prompt registry)
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name
        self._load_prompts_dict()
        self._prompt_template = self._get_prompt_template()
        # Create an LLMAgent instance for tool usage
        self.engine_params = {
            "engine_type": provider,
            "model": model_name,
        }
        self.llm_agent = LLMAgent(engine_params=self.engine_params, system_prompt=self._prompt_template)

    def _get_prompt_template(self) -> str:
        if self.tool_name is None:
            return ""
        # Prefer reading prompt text directly from gui_agents.prompts.module
        try:
            prompt_category_map = {
                # manager prompts
                "query_formulator": ("manager", "query_formulator"),
                "narrative_summarization": ("manager", "narrative_summarization"),
                "context_fusion": ("manager", "context_fusion"),
                "planner_role": ("manager", "planner_role"),
                "supplement_role": ("manager", "supplement_role"),
                "dag_translator": ("manager", "dag_translator"),
                "objective_alignment": ("manager", "objective_alignment"),
                # worker prompts
                "operator_role": ("worker", "operator_role"),
                "technician_role": ("worker", "technician_role"),
                "analyst_role": ("worker", "analyst_role"),
                "grounding": ("worker", "grounding"),
                "text_span": ("worker", "text_span"),
                "episode_summarization": ("worker", "episode_summarization"),
                # evaluator prompts
                "worker_success_role": ("evaluator", "worker_success_role"),
                "worker_stale_role": ("evaluator", "worker_stale_role"),
                "periodic_role": ("evaluator", "periodic_role"),
                "final_check_role": ("evaluator", "final_check_role"),
            }

            # Tools whose prompts should be prefixed with system architecture info
            tools_require_system_prefix = {
                "planner_role",
                "supplement_role",
                "dag_translator",
                "operator_role",
                "technician_role",
                "analyst_role",
                "worker_success_role",
                "worker_stale_role",
                "periodic_role",
                "final_check_role",
                "objective_alignment",
            }

            category_tuple = prompt_category_map.get(self.tool_name)

            prompt_text = ""
            if category_tuple is None:
                # Try a root-level attribute on the module (e.g., system_architecture)
                if hasattr(module, self.tool_name):
                    prompt_text = getattr(module, self.tool_name)
                else:
                    return ""
            else:
                category_name, key_name = category_tuple
                category_obj = getattr(module, category_name, None)
                if category_obj is None:
                    return ""
                value = getattr(category_obj, key_name, None)
                if isinstance(value, str) and value:
                    prompt_text = value
                else:
                    return ""

            # Optionally prefix with system architecture information for selected tools
            if (
                isinstance(prompt_text, str)
                and prompt_text
                and self.tool_name in tools_require_system_prefix
            ):
                system_info = getattr(module, "system_architecture", "")
                if isinstance(system_info, str) and system_info:
                    return f"{system_info}\n\n{prompt_text}"

            return prompt_text
        except Exception:
            # Fallback to registry to allow central overrides if available
            return ""

    def _call_lmm(self, input_data: Dict[str, Any], temperature: float = 0.0):
        """
        Call the LMM for inference using the prompt template, with a retry mechanism.

        Args:
            input_data: Dictionary containing input data to format the prompt template
            temperature: Temperature parameter to control randomness of output

        Returns:
            A tuple of (response text, token counts, cost string)
        """
        # Extract text and image inputs
        text_input = input_data.get('str_input', '')
        image_input = input_data.get('img_input', None)

        # Add the message with the formatted prompt
        self.llm_agent.reset()
        self.llm_agent.add_message(text_input, image_content=image_input, role="user")

        # Safe retry mechanism
        max_retries = 3
        attempt = 0
        content, total_tokens, cost_string = "", [0, 0, 0], ""

        while attempt < max_retries:
            try:
                content, total_tokens, cost_string = self.llm_agent.get_response(temperature=temperature)
                break  # Success: break out of the retry loop
            except Exception as e:
                attempt += 1
                logger.error(f"LLM call attempt {attempt} failed: {str(e)}")
                if attempt == max_retries:
                    logger.error("Max retries reached. Returning error message.")
                    return f"Error: LLM call failed after {max_retries} attempts: {str(e)}", [0, 0, 0], ""
                time.sleep(1.0)
        return content, total_tokens, cost_string

    @abstractmethod
    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute the tool with the given input.

        Args:
            tool_input: Dictionary containing the input for the tool
                Expected to have 'str_input' and/or 'img_input' keys

        Returns:
            A tuple of (output text, token counts, cost string)
        """
        pass


class ToolFactory:
    """Factory class for creating tools."""

    @staticmethod
    def create_tool(tool_name: str, provider: str, model_name: str, **kwargs) -> 'BaseTool':
        """
        Create a tool instance based on the tool name.

        Args:
            tool_name: Name of the tool to create
            provider: API provider name
            model_name: Model name to use
            **kwargs: Additional parameters to pass to the tool

        Returns:
            An instance of the specified tool

        Raises:
            ValueError: If the tool name is not recognized
        """
        tool_map = {
            "embedding": (EmbeddingTool, None),  # all

            "query_formulator": (QueryFormulatorTool, "query_formulator"),  # manager
            "websearch": (WebSearchTool, None),  # manager
            "narrative_summarization": (NarrativeSummarizationTool, "narrative_summarization"),  # manager
            "context_fusion": (ContextFusionTool, "context_fusion"),  # manager
            "planner_role": (SubtaskPlannerTool, "planner_role"),  # manager
            "supplement_role": (SubtaskPlannerTool, "supplement_role"),  # manager
            "dag_translator": (DAGTranslatorTool, "dag_translator"),  # manager
            "objective_alignment": (ObjectiveAlignmentTool, "objective_alignment"),  # manager

            "operator_role": (ActionGeneratorTool, "operator_role"),  # worker
            "technician_role": (ActionGeneratorTool, "technician_role"),  # worker
            "analyst_role": (ActionGeneratorTool, "analyst_role"),  # worker
            "grounding": (GroundingTool, "grounding"),  # worker
            "text_span": (TextSpanTool, "text_span"),  # worker
            "episode_summarization": (EpisodeSummarizationTool, "episode_summarization"),  # worker

            "worker_success_role": (EvaluatorTool, "worker_success_role"),  # evaluator
            "worker_stale_role": (EvaluatorTool, "worker_stale_role"),  # evaluator
            "periodic_role": (EvaluatorTool, "periodic_role"),  # evaluator
            "final_check_role": (EvaluatorTool, "final_check_role"),  # evaluator
        }

        if tool_name not in tool_map:
            raise ValueError(f"Unknown tool name: {tool_name}")

        tool_class, prompt_key = tool_map[tool_name]

        # WebSearchTool and EmbeddingTool don't need a prompt
        if tool_name in ("websearch", "embedding"):
            return tool_class(provider, model_name, None, **kwargs)

        return tool_class(provider, model_name, prompt_key, **kwargs)
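
# Usage sketch (illustrative only): the provider and model names below are
# placeholders, not a recommended configuration.
#
#     grounding_tool = ToolFactory.create_tool("grounding", "openai", "gpt-4o")
#     output, tokens, cost = grounding_tool.execute({
#         "str_input": "Click the Save button",
#         "img_input": screenshot,  # hypothetical image payload from the caller
#     })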


class WebSearchTool(BaseTool):
    """Tool for performing web searches."""

    def __init__(self, provider: str, model_name: str, tool_name: Optional[str]):
        """
        Initialize the web search tool.

        Args:
            provider: API provider name (e.g., "bocha", "exa")
            model_name: Model name to use (not used for WebSearchAgent)
            tool_name: Unused; accepted for interface compatibility
        """
        self.provider = provider

        # Engine parameters for the WebSearchAgent
        self.engine_params = {
            "engine_type": provider,
            "model": model_name,
        }

        # Initialize WebSearchAgent
        self.search_agent = WebSearchAgent(engine_params=self.engine_params)

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute a web search with the given query.

        Args:
            tool_input: Dictionary containing the search query
                Expected to have 'str_input' key with the search query

        Returns:
            Search results, token counts, and cost string
        """
        query = tool_input.get('str_input', '')
        if not query:
            return "Error: No search query provided", [0, 0, 0], ""

        try:
            # Get the answer from the search results
            answer, total_tokens, cost = self.search_agent.get_answer(query)

            # Return just the answer
            return answer, total_tokens, cost  # type: ignore

        except Exception as e:
            logger.error(f"Error during web search: {str(e)}")
            return f"Error: Web search failed: {str(e)}", [0, 0, 0], ""


class ContextFusionTool(BaseTool):
    """Tool for fusing multiple contexts together."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Fuse multiple contexts together.

        Args:
            tool_input: Dictionary containing the contexts to fuse
                Expected to have 'str_input' key with JSON-formatted contexts

        Returns:
            Fused context, token counts, and cost string
        """
        contexts = tool_input.get('str_input', '')
        if not contexts:
            return "Error: No contexts provided", [0, 0, 0], ""

        # Use the prompt template and LMM for context fusion
        return self._call_lmm(tool_input)


class SubtaskPlannerTool(BaseTool):
    """Tool for planning subtasks."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Plan subtasks for a given task.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            Subtask plan, token counts, and cost string
        """
        task = tool_input.get('str_input', '')
        if not task:
            return "Error: No task description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for subtask planning
        return self._call_lmm(tool_input)


class NarrativeSummarizationTool(BaseTool):
    """Tool for summarizing narrative memories."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Summarize narrative memories.

        Args:
            tool_input: Dictionary containing the narrative memory data
                Expected to have 'str_input' key with the narrative memory data
                May also have 'img_input' key with relevant images

        Returns:
            Summarized narrative, token counts, and cost string
        """
        narrative_data = tool_input.get('str_input', '')
        if not narrative_data:
            return "Error: No narrative memory data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for narrative summarization
        return self._call_lmm(tool_input)


class EpisodeSummarizationTool(BaseTool):
    """Tool for summarizing episodic memories."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Summarize episodic memories.

        Args:
            tool_input: Dictionary containing the episodic memory data
                Expected to have 'str_input' key with the episodic memory data
                May also have 'img_input' key with relevant images

        Returns:
            Summarized episode, token counts, and cost string
        """
        episode_data = tool_input.get('str_input', '')
        if not episode_data:
            return "Error: No episodic memory data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for episode summarization
        return self._call_lmm(tool_input)


class TextSpanTool(BaseTool):
    """Tool for processing text spans."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Process text spans for a given input.

        Args:
            tool_input: Dictionary containing the text input
                Expected to have 'str_input' key with the text content
                May also have 'img_input' key with a screenshot

        Returns:
            Processed text spans, token counts, and cost string
        """
        text = tool_input.get('str_input', '')
        if not text:
            return "Error: No text content provided", [0, 0, 0], ""

        # Use the prompt template and LMM for text span processing
        return self._call_lmm(tool_input)


class DAGTranslatorTool(BaseTool):
    """Tool for translating task descriptions into a DAG (Directed Acyclic Graph) structure."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Translate task descriptions into a DAG structure.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            DAG representation, token counts, and cost string
        """
        task = tool_input.get('str_input', '')
        if not task:
            return "Error: No task description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for DAG translation
        return self._call_lmm(tool_input)


class ObjectiveAlignmentTool(BaseTool):
    """Tool for aligning and rewriting a user objective against the current screen context."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Align an ambiguous or high-level user objective with the current desktop
        screenshot context, and output a refined objective plus assumptions.

        Args:
            tool_input: Dict with keys:
                - 'str_input': the raw user objective or context text
                - 'img_input': optional screenshot image content

        Returns:
            Refined objective as text (ideally JSON-structured), token counts, and cost string
        """
        text = tool_input.get('str_input', '')
        if not text:
            return "Error: No objective text provided", [0, 0, 0], ""
        # Forward to the LMM with the prompt template
        return self._call_lmm(tool_input)


class TrajReflectorTool(BaseTool):
    """Tool for reflecting on execution trajectories."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Reflect on an execution trajectory.

        Args:
            tool_input: Dictionary containing the trajectory
                Expected to have 'str_input' key with the trajectory

        Returns:
            Reflection, token counts, and cost string
        """
        trajectory = tool_input.get('str_input', '')
        if not trajectory:
            return "Error: No trajectory provided", [0, 0, 0], ""

        # Use the prompt template and LMM for trajectory reflection
        return self._call_lmm(tool_input)


class GroundingTool(BaseTool):
    """Tool for grounding agent actions in the environment."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Ground agent actions in the environment.

        Args:
            tool_input: Dictionary containing the action and environment state
                Expected to have 'str_input' key with the action
                Expected to have 'img_input' key with a screenshot

        Returns:
            Grounded action, token counts, and cost string
        """
        action = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')

        if not action:
            return "Error: No action provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""

        # Use the prompt template and LMM for action grounding
        return self._call_lmm(tool_input)

    def get_grounding_wh(self):
        """
        Get the grounding coordinate space based on provider and model name.

        Returns:
            If the provider is doubao and the model name contains 'ui-tars'
            (or an 'ep-' endpoint prefix), returns (1000, 1000) as the
            grounding width and height. Otherwise returns (None, None).
        """
        if self.provider == "doubao" and ("ui-tars" in self.model_name or "ep-" in self.model_name):
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None
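
# Note: callers that receive a fixed coordinate space from get_grounding_wh()
# (e.g., 1000x1000 for ui-tars-style models) are expected to rescale model
# coordinates to the real screen. A minimal sketch, assuming integer pixel
# coordinates (this helper is illustrative, not part of this module's API):
#
#     def rescale(x, y, gw, gh, screen_w, screen_h):
#         return int(x * screen_w / gw), int(y * screen_h / gh)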


class EvaluatorTool(BaseTool):
    """Tool for evaluating agent performance."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Evaluate agent performance.

        Args:
            tool_input: Dictionary containing the evaluation data
                Expected to have 'str_input' key with the evaluation data

        Returns:
            Evaluation result, token counts, and cost string
        """
        eval_data = tool_input.get('str_input', '')
        if not eval_data:
            return "Error: No evaluation data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for performance evaluation
        return self._call_lmm(tool_input)


class ActionGeneratorTool(BaseTool):
    """Tool for generating executable actions."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as key in the prompt registry)
            **kwargs: Additional parameters, including:
                enable_search: Whether to enable web search functionality
                search_provider: Provider for web search (defaults to "bocha")
                search_model: Model for web search (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)

        # Extract search-related parameters
        self.enable_search = kwargs.get("enable_search", False)
        search_provider = kwargs.get("search_provider", "bocha")
        search_model = kwargs.get("search_model", "")

        # Initialize search tool if enabled
        self.search_tool = None
        if self.enable_search:
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate executable actions.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the action request
                May also have 'img_input' key with a screenshot

        Returns:
            Generated action, token counts, and cost string
        """
        action_request = tool_input.get('str_input', '')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""

        # Check if search is enabled
        if self.enable_search and self.search_tool:
            try:
                # Use the input text directly as the search query
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                search_results, tokens, cost = self.search_tool.execute({"str_input": search_query})

                # Enhance the action request with the search results
                enhanced_request = (
                    f"[Action Request]\n{action_request}\n[End of Action Request]\n\n"
                    f"[Web Search Results for '{action_request}']\n{search_results}\n\n"
                    f"[End of Web Search Results]"
                )
                tool_input["str_input"] = enhanced_request

                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with the original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(tool_input)


class FastActionGeneratorTool(BaseTool):
    """Tool for directly generating executable actions without intermediate planning."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the fast action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as key in the prompt registry)
            **kwargs: Additional parameters, including:
                enable_search: Whether to enable web search functionality
                search_provider: Provider for web search (defaults to "bocha")
                search_model: Model for web search (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)

        # Extract search-related parameters
        self.enable_search = kwargs.get("enable_search", False)
        search_provider = kwargs.get("search_provider", "bocha")
        search_model = kwargs.get("search_model", "")

        # Initialize search tool if enabled
        self.search_tool = None
        if self.enable_search:
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate executable actions directly from the instruction and screenshot.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the instruction
                Expected to have 'img_input' key with a screenshot

        Returns:
            Generated action, token counts, and cost string
        """
        action_request = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""

        # Check if search is enabled
        if self.enable_search and self.search_tool:
            try:
                # Use the input text directly as the search query
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                search_results, tokens, cost = self.search_tool.execute({"str_input": search_query})

                # Enhance the action request with the search results
                enhanced_request = (
                    f"[Action Request]\n{action_request}\n[End of Action Request]\n\n"
                    f"[Web Search Results for '{action_request}']\n{search_results}\n\n"
                    f"[End of Web Search Results]"
                )
                tool_input["str_input"] = enhanced_request

                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with the original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(tool_input)

    def get_grounding_wh(self):
        """
        Get the grounding coordinate space based on provider and model name.

        Returns:
            If the provider is doubao and the model name contains 'ui-tars',
            returns (1000, 1000) as the grounding width and height.
            Otherwise returns (None, None).
        """
        if self.provider == "doubao" and "ui-tars" in self.model_name:
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None


class EmbeddingTool(BaseTool):
    """Tool for generating text embeddings."""

    def __init__(self, provider: str, model_name: str, tool_name: Optional[str]):
        """
        Initialize the embedding tool.

        Args:
            provider: API provider name (e.g., "openai", "gemini")
            model_name: Model name to use
            tool_name: Unused; accepted for interface compatibility
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name

        # Engine parameters for the EmbeddingAgent
        self.engine_params = {
            "engine_type": provider,
            "embedding_model": model_name,
        }

        # Initialize EmbeddingAgent
        self.embedding_agent = EmbeddingAgent(engine_params=self.engine_params)

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate embeddings for the given text.

        Args:
            tool_input: Dictionary containing the text to embed
                Expected to have 'str_input' key with the text

        Returns:
            Embeddings, token counts, and cost string
        """
        text = tool_input.get('str_input', '')

        if not text:
            return "Error: No text provided for embedding", [0, 0, 0], ""

        try:
            # Get embeddings for the text
            embeddings, total_tokens, cost_string = self.embedding_agent.get_embeddings(text)
            return embeddings, total_tokens, cost_string

        except Exception as e:
            logger.error(f"Error during embedding operation: {str(e)}")
            return f"Error: Embedding operation failed: {str(e)}", [0, 0, 0], ""


class QueryFormulatorTool(BaseTool):
    """Tool for formulating queries from tasks or contexts."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Formulate a query for a given task or context.

        Args:
            tool_input: Dictionary containing the task or context description
                Expected to have 'str_input' key with the description
                May also have 'img_input' key with a screenshot

        Returns:
            Formulated query, token counts, and cost string
        """
        task = tool_input.get('str_input', '')
        if not task:
            return "Error: No task or context description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for query formulation
        return self._call_lmm(tool_input)


class NewTools:
    """Main Tools class that provides access to all available tools."""

    def __init__(self):
        """Initialize the Tools class."""
        self.tools = {}

    def register_tool(self, tool_name: str, provider: str, model_name: str, **kwargs):
        """
        Register a tool with the specified parameters.

        Args:
            tool_name: Name of the tool to register
            provider: API provider name
            model_name: Model name to use
            **kwargs: Additional parameters to pass to the tool
        """
        tool: BaseTool = ToolFactory.create_tool(tool_name, provider, model_name, **kwargs)
        self.tools[tool_name] = tool

    def execute_tool(self, tool_name: str, tool_input: Dict[str, Any]):
        """
        Execute a tool with the given input.

        Args:
            tool_name: Name of the tool to execute
            tool_input: Input for the tool

        Returns:
            The tool's output, token counts, and cost string

        Raises:
            ValueError: If the tool is not registered
        """
        if tool_name not in self.tools:
            raise ValueError(f"Tool {tool_name} is not registered")

        return self.tools[tool_name].execute(tool_input)

    def reset(self, tool_name: Optional[str] = None):
        """
        Reset tools by resetting their llm_agent if available.

        Args:
            tool_name: Optional name of the specific tool to reset. If None, resets all tools.
        """
        if tool_name is not None:
            # Reset a specific tool
            if tool_name not in self.tools:
                raise ValueError(f"Tool {tool_name} is not registered")

            tool = self.tools[tool_name]
            if hasattr(tool, 'llm_agent') and tool.llm_agent is not None:
                tool.llm_agent.reset()
        else:
            # Reset all tools
            for tool in self.tools.values():
                # Only reset if the tool has an llm_agent attribute
                if hasattr(tool, 'llm_agent') and tool.llm_agent is not None:
                    tool.llm_agent.reset()
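
# Usage sketch (illustrative only; provider and model names are placeholders):
#
#     tools = NewTools()
#     tools.register_tool("query_formulator", "openai", "gpt-4o")
#     query, tokens, cost = tools.execute_tool(
#         "query_formulator", {"str_input": "Find the default save shortcut in LibreOffice"}
#     )
#     tools.reset()  # clears each tool's llm_agent conversation state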