Add multiple new modules and tools to enhance the functionality and extensibility of the Maestro project (#333)
* Added a **pyproject.toml** file to define project metadata and dependencies. * Added **run_maestro.py** and **osworld_run_maestro.py** to provide the main execution logic. * Introduced multiple new modules, including **Evaluator**, **Controller**, **Manager**, and **Sub-Worker**, supporting task planning, state management, and data analysis. * Added a **tools module** containing utility functions and tool configurations to improve code reusability. * Updated the **README** and documentation with usage examples and module descriptions. These changes lay the foundation for expanding the Maestro project’s functionality and improving the user experience. Co-authored-by: Hiroid <guoliangxuan@deepmatrix.com>
This commit is contained in:
826
mm_agents/maestro/tools/new_tools.py
Normal file
826
mm_agents/maestro/tools/new_tools.py
Normal file
@@ -0,0 +1,826 @@
|
||||
"""
|
||||
Tools module for GUI agents.
|
||||
|
||||
This module provides various tools for GUI agents to perform tasks such as web search,
|
||||
context fusion, subtask planning, trajectory reflection, memory retrieval, grounding,
|
||||
evaluation, and action generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import requests
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List, Union, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
import logging
|
||||
from ..core.mllm import LLMAgent, WebSearchAgent, EmbeddingAgent
|
||||
import threading
|
||||
from ..prompts import get_prompt, module
|
||||
|
||||
logger = logging.getLogger("desktopenv.tools")
|
||||
|
||||
class BaseTool(ABC):
    """Base class for all tools.

    Each tool wraps an LLMAgent configured with a tool-specific system prompt
    resolved from the shared prompts module. Subclasses implement execute().
    """

    # Class-level prompt cache, guarded by a lock (double-checked locking).
    # Deprecated: retained only so legacy code reading _prompts_dict still works.
    _prompts_dict = None
    _prompts_dict_lock = threading.Lock()
    # Directory retained for backward compatibility; no longer scanned directly
    _prompts_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "prompts")

    @classmethod
    def _load_prompts_dict(cls):
        # Deprecated: kept for compatibility if other code accesses _prompts_dict.
        # Now pull prompts via the registry to avoid direct filesystem coupling.
        if cls._prompts_dict is None:
            with cls._prompts_dict_lock:
                # Re-check under the lock so only one thread initializes.
                if cls._prompts_dict is None:
                    cls._prompts_dict = {}

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Initialize the base tool.

        Args:
            provider: API provider name (e.g., "gemini", "openai")
            model_name: Model name to use (e.g., "gemini-2.5-pro")
            tool_name: Name of the tool (used as key in prompts files)
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name
        self._load_prompts_dict()
        self._prompt_template = self._get_prompt_template()
        # Create LLMAgent instance for tool usage
        self.engine_params = {
            "engine_type": provider,
            "model": model_name
        }
        self.llm_agent = LLMAgent(engine_params=self.engine_params, system_prompt=self._prompt_template)

    def _get_prompt_template(self) -> str:
        """Resolve the system-prompt text for this tool from the prompts module.

        Returns "" when no prompt is configured (e.g. search/embedding tools)
        or when lookup fails for any reason.
        """
        if self.tool_name is None:
            return ""
        # Prefer reading prompt text directly from gui_agents.prompts.module
        try:
            # Maps tool name -> (category attribute on `module`, key within it).
            prompt_category_map = {
                # manager prompts
                "query_formulator": ("manager", "query_formulator"),
                "narrative_summarization": ("manager", "narrative_summarization"),
                "context_fusion": ("manager", "context_fusion"),
                "planner_role": ("manager", "planner_role"),
                "supplement_role": ("manager", "supplement_role"),
                "dag_translator": ("manager", "dag_translator"),
                "objective_alignment": ("manager", "objective_alignment"),
                # worker prompts
                "operator_role": ("worker", "operator_role"),
                "technician_role": ("worker", "technician_role"),
                "analyst_role": ("worker", "analyst_role"),
                "grounding": ("worker", "grounding"),
                "text_span": ("worker", "text_span"),
                "episode_summarization": ("worker", "episode_summarization"),
                # evaluator prompts
                "worker_success_role": ("evaluator", "worker_success_role"),
                "worker_stale_role": ("evaluator", "worker_stale_role"),
                "periodic_role": ("evaluator", "periodic_role"),
                "final_check_role": ("evaluator", "final_check_role"),
            }

            # Tools that should be prefixed with system architecture info
            tools_require_system_prefix = {
                "planner_role",
                "supplement_role",
                "dag_translator",
                "operator_role",
                "technician_role",
                "analyst_role",
                "worker_success_role",
                "worker_stale_role",
                "periodic_role",
                "final_check_role",
                "objective_alignment",
            }

            category_tuple = prompt_category_map.get(self.tool_name)

            prompt_text = ""
            if category_tuple is None:
                # Try root-level attribute on module (e.g., system_architecture)
                if hasattr(module, self.tool_name):
                    prompt_text = getattr(module, self.tool_name)
                else:
                    return ""
            else:
                category_name, key_name = category_tuple
                category_obj = getattr(module, category_name, None)
                if category_obj is None:
                    return ""
                value = getattr(category_obj, key_name, None)
                # Only accept a non-empty string; anything else means the
                # prompt is missing or malformed.
                if isinstance(value, str) and value:
                    prompt_text = value
                else:
                    return ""

            # Optionally prefix with system architecture information for selected tools
            if (
                isinstance(prompt_text, str)
                and prompt_text
                and self.tool_name in tools_require_system_prefix
            ):
                system_info = getattr(module, "system_architecture", "")
                if isinstance(system_info, str) and system_info:
                    return f"{system_info}\n\n{prompt_text}"

            return prompt_text
        except Exception:
            # Fallback to registry to allow central overrides if available
            return ""

    def _call_lmm(self, input_data: Dict[str, Any], temperature: float = 0.0):
        """
        Call the LMM model for inference using the prompt template with retry mechanism

        Args:
            input_data: Dictionary containing input data to format the prompt template
                (keys 'str_input' for text and optional 'img_input' for an image)
            temperature: Temperature parameter to control randomness of output

        Returns:
            Tuple of (response text, token counts, cost string); on repeated
            failure the text is an "Error: ..." message with zero tokens/cost.
        """
        # self.llm_agent.reset()

        # Extract text and image inputs
        text_input = input_data.get('str_input', '')
        image_input = input_data.get('img_input', None)

        # Add the message with the formatted prompt
        self.llm_agent.reset()
        self.llm_agent.add_message(text_input, image_content=image_input, role="user")

        # Implement safe retry mechanism
        max_retries = 3
        attempt = 0
        content, total_tokens, cost_string = "", [0, 0, 0], ""

        while attempt < max_retries:
            try:
                content, total_tokens, cost_string = self.llm_agent.get_response(temperature=temperature)
                break  # If successful, break out of the loop
            except Exception as e:
                attempt += 1
                logger.error(f"LLM call attempt {attempt} failed: {str(e)}")
                if attempt == max_retries:
                    logger.error("Max retries reached. Returning error message.")
                    return f"Error: LLM call failed after {max_retries} attempts: {str(e)}", [0, 0, 0], ""
                # Brief back-off before the next attempt.
                time.sleep(1.0)
        return content, total_tokens, cost_string

    @abstractmethod
    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute the tool with the given input.

        Args:
            tool_input: Dictionary containing the input for the tool
                Expected to have 'str_input' and/or 'img_input' keys

        Returns:
            Tuple of (output text, token counts, cost string)
        """
        pass
|
||||
|
||||
|
||||
class ToolFactory:
    """Factory class for creating tools."""

    @staticmethod
    def create_tool(tool_name: str, provider: str, model_name: str, **kwargs) -> 'BaseTool':
        """
        Create a tool instance based on the tool name.

        Args:
            tool_name: Name of the tool to create
            provider: API provider name
            model_name: Model name to use
            **kwargs: Additional parameters to pass to the tool

        Returns:
            An instance of the specified tool

        Raises:
            ValueError: If the tool name is not recognized
        """
        # Each entry maps a public tool name to (tool class, prompt key).
        # A prompt key of None means the tool needs no system prompt.
        tool_map = {
            "embedding": (EmbeddingTool, None),  # all

            "query_formulator": (QueryFormulatorTool, "query_formulator"),  # manager
            "websearch": (WebSearchTool, None),  # manager
            "narrative_summarization": (NarrativeSummarizationTool, "narrative_summarization"),  # manager
            "context_fusion": (ContextFusionTool, "context_fusion"),  # manager
            "planner_role": (SubtaskPlannerTool, "planner_role"),  # manager
            "supplement_role": (SubtaskPlannerTool, "supplement_role"),  # manager
            "dag_translator": (DAGTranslatorTool, "dag_translator"),  # manager
            "objective_alignment": (ObjectiveAlignmentTool, "objective_alignment"),  # manager

            "operator_role": (ActionGeneratorTool, "operator_role"),  # worker
            "technician_role": (ActionGeneratorTool, "technician_role"),  # worker
            "analyst_role": (ActionGeneratorTool, "analyst_role"),  # worker
            "grounding": (GroundingTool, "grounding"),  # worker
            "text_span": (TextSpanTool, "text_span"),  # worker
            "episode_summarization": (EpisodeSummarizationTool, "episode_summarization"),  # worker

            "worker_success_role": (EvaluatorTool, "worker_success_role"),  # evaluator
            "worker_stale_role": (EvaluatorTool, "worker_stale_role"),  # evaluator
            "periodic_role": (EvaluatorTool, "periodic_role"),  # evaluator
            "final_check_role": (EvaluatorTool, "final_check_role"),  # evaluator
        }

        entry = tool_map.get(tool_name)
        if entry is None:
            raise ValueError(f"Unknown tool name: {tool_name}")

        tool_class, prompt_key = entry
        # websearch/embedding carry prompt_key=None, so they are constructed
        # without a prompt key exactly as every other tool is constructed
        # with theirs.
        return tool_class(provider, model_name, prompt_key, **kwargs)
|
||||
|
||||
|
||||
class WebSearchTool(BaseTool):
    """Tool for performing web searches."""

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Initialize the web search tool.

        Deliberately does NOT call BaseTool.__init__: no system prompt or
        chat LLMAgent is needed for plain search.

        Args:
            provider: API provider name (e.g., "bocha", "exa")
            model_name: Model name to use (not used for WebSearchAgent)
            tool_name: Name of the tool (unused here; accepted for interface parity)
        """
        self.provider = provider
        # Fix: also record model_name/tool_name, which BaseTool.__init__ would
        # normally set; code that inspects these common attributes would
        # otherwise hit AttributeError on WebSearchTool instances.
        self.model_name = model_name
        self.tool_name = tool_name

        # Create WebSearchAgent instance for search
        self.engine_params = {
            "engine_type": provider,
            "model": model_name,
        }

        # Initialize WebSearchAgent
        self.search_agent = WebSearchAgent(engine_params=self.engine_params)

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Execute a web search with the given query.

        Args:
            tool_input: Dictionary containing the search query
                Expected to have 'str_input' key with the search query

        Returns:
            Tuple of (search answer text, token counts, cost string)
        """
        query = tool_input.get('str_input', '')
        if not query:
            return "Error: No search query provided", [0, 0, 0], ""

        try:
            # Get the answer from the search results
            answer, total_tokens, cost = self.search_agent.get_answer(query)

            # Return just the answer
            return answer, total_tokens, cost  # type: ignore

        except Exception as e:
            logger.error(f"Error during web search: {str(e)}")
            return f"Error: Web search failed: {str(e)}", [0, 0, 0], ""
|
||||
|
||||
|
||||
class ContextFusionTool(BaseTool):
    """Tool for fusing multiple contexts together."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Fuse multiple contexts together.

        Args:
            tool_input: Dictionary containing the contexts to fuse
                Expected to have 'str_input' key with JSON-formatted contexts

        Returns:
            Tuple of (fused context text, token counts, cost string)
        """
        contexts = tool_input.get('str_input', '')
        if not contexts:
            # Fix: return the same 3-tuple shape as _call_lmm so callers can
            # always unpack (content, tokens, cost).
            return "Error: No contexts provided", [0, 0, 0], ""

        # Use the prompt template and LMM for context fusion
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class SubtaskPlannerTool(BaseTool):
    """Tool for planning subtasks."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Plan subtasks for a given task.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (subtask plan text, token counts, cost string)
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No task description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for subtask planning
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class NarrativeSummarizationTool(BaseTool):
    """Tool for summarizing narrative memories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Summarize narrative memories.

        Args:
            tool_input: Dictionary containing the narrative memory data
                Expected to have 'str_input' key with the narrative memory data
                May also have 'img_input' key with relevant images

        Returns:
            Tuple of (summarized narrative text, token counts, cost string)
        """
        narrative_data = tool_input.get('str_input', '')
        if not narrative_data:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No narrative memory data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for narrative summarization
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class EpisodeSummarizationTool(BaseTool):
    """Tool for summarizing episodic memories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Summarize episodic memories.

        Args:
            tool_input: Dictionary containing the episodic memory data
                Expected to have 'str_input' key with the episodic memory data
                May also have 'img_input' key with relevant images

        Returns:
            Tuple of (summarized episode text, token counts, cost string)
        """
        episode_data = tool_input.get('str_input', '')
        if not episode_data:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No episodic memory data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for episode summarization
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class TextSpanTool(BaseTool):
    """Tool for processing text spans."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Process text spans for a given input.

        Args:
            tool_input: Dictionary containing the text input
                Expected to have 'str_input' key with the text content
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (processed text spans, token counts, cost string)
        """
        text = tool_input.get('str_input', '')
        if not text:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No text content provided", [0, 0, 0], ""

        # Use the prompt template and LMM for text span processing
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class DAGTranslatorTool(BaseTool):
    """Tool for translating task descriptions into a DAG (Directed Acyclic Graph) structure."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Translate task descriptions into a DAG structure.

        Args:
            tool_input: Dictionary containing the task description
                Expected to have 'str_input' key with the task description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (DAG representation text, token counts, cost string)
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No task description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for DAG translation
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class ObjectiveAlignmentTool(BaseTool):
    """Tool for aligning and rewriting user objective with current screen context."""

    def execute(self, tool_input: Dict[str, Any]):
        """
        Align an ambiguous or high-level user objective with the current
        desktop screenshot context and produce a refined objective plus
        assumptions.

        Args:
            tool_input: Dict with keys:
                - 'str_input': the raw user objective or context text
                - 'img_input': optional screenshot image content

        Returns:
            Refined objective text (ideally JSON-structured), token counts,
            and cost string.
        """
        objective_text = tool_input.get('str_input', '')
        if objective_text:
            # Forward to LMM with the prompt template
            return self._call_lmm(tool_input)
        return "Error: No objective text provided", [0, 0, 0], ""
|
||||
|
||||
|
||||
class TrajReflectorTool(BaseTool):
    """Tool for reflecting on execution trajectories."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Reflect on an execution trajectory.

        Args:
            tool_input: Dictionary containing the trajectory
                Expected to have 'str_input' key with the trajectory

        Returns:
            Tuple of (reflection text, token counts, cost string)
        """
        trajectory = tool_input.get('str_input', '')
        if not trajectory:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No trajectory provided", [0, 0, 0], ""

        # Use the prompt template and LMM for trajectory reflection
        return self._call_lmm(tool_input)
|
||||
|
||||
class GroundingTool(BaseTool):
    """Tool for grounding agent actions in the environment."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Ground agent actions in the environment.

        Args:
            tool_input: Dictionary containing the action and environment state
                Expected to have 'str_input' key with the action
                Expected to have 'img_input' key with a screenshot

        Returns:
            Tuple of (grounded action text, token counts, cost string)
        """
        action = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')

        # Fix: error paths now return the same 3-tuple shape as _call_lmm so
        # callers can always unpack (content, tokens, cost).
        if not action:
            return "Error: No action provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""

        # Use the prompt template and LMM for action grounding
        return self._call_lmm(tool_input)

    def get_grounding_wh(self):
        """
        Get grounding width and height based on provider and model name.

        Returns:
            (1000, 1000) when provider is "doubao" and the model name contains
            'ui-tars' or 'ep-'; otherwise (None, None).
            (Docstring fixed: it previously claimed 1024x768, but the code
            has always returned 1000x1000.)
        """
        if self.provider == "doubao" and ("ui-tars" in self.model_name or "ep-" in self.model_name):
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None
|
||||
|
||||
|
||||
class EvaluatorTool(BaseTool):
    """Tool for evaluating agent performance."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Evaluate agent performance.

        Args:
            tool_input: Dictionary containing the evaluation data
                Expected to have 'str_input' key with the evaluation data

        Returns:
            Tuple of (evaluation result text, token counts, cost string)
        """
        eval_data = tool_input.get('str_input', '')
        if not eval_data:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No evaluation data provided", [0, 0, 0], ""

        # Use the prompt template and LMM for performance evaluation
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class ActionGeneratorTool(BaseTool):
    """Tool for generating executable actions."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as key in prompts.json)
            **kwargs: Optional settings:
                enable_search: whether to augment requests with web search
                search_provider: web-search provider (defaults to "bocha")
                search_model: web-search model (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)

        # Web-search augmentation is opt-in via kwargs.
        self.enable_search = kwargs.get("enable_search", False)
        self.search_tool = None
        if self.enable_search:
            search_provider = kwargs.get("search_provider", "bocha")
            search_model = kwargs.get("search_model", "")
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate executable actions.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the action request
                May also have 'img_input' key with a screenshot

        Returns:
            Generated action text, token counts, and cost string.
        """
        action_request = tool_input.get('str_input', '')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""

        # When enabled, prepend web-search findings to the request before the
        # LMM call; a failed search falls back to the unmodified request.
        if self.enable_search and self.search_tool:
            try:
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                search_results, tokens, cost = self.search_tool.execute({"str_input": search_query})

                enhanced_request = f"[Action Request]\n{action_request}\n[End of Action Request]\n\n[Web Search Results for '{action_request}']\n{search_results}\n\n[End of Web Search Results]"
                tool_input["str_input"] = enhanced_request

                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(tool_input)
|
||||
|
||||
|
||||
class FastActionGeneratorTool(BaseTool):
    """Tool for directly generating executable actions without intermediate planning."""

    def __init__(self, provider: str, model_name: str, tool_name: str, **kwargs):
        """
        Initialize the fast action generator tool.

        Args:
            provider: API provider name
            model_name: Model name to use
            tool_name: Name of the tool (used as key in prompts.json)
            **kwargs: Additional parameters, including:
                enable_search: Whether to enable web search functionality
                search_provider: Provider for web search (defaults to "bocha")
                search_model: Model for web search (defaults to "")
        """
        super().__init__(provider, model_name, tool_name)

        # Extract search-related parameters
        self.enable_search = kwargs.get("enable_search", False)
        search_provider = kwargs.get("search_provider", "bocha")
        search_model = kwargs.get("search_model", "")

        # Initialize search tool if enabled
        self.search_tool = None
        if self.enable_search:
            self.search_tool = WebSearchTool(search_provider, search_model, "")
            logger.info(f"Web search enabled for {tool_name} using provider: {search_provider}")

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate executable actions directly from the instruction and screenshot.

        Unlike ActionGeneratorTool.execute, this also requires a screenshot
        ('img_input') and errors out without one.

        Args:
            tool_input: Dictionary containing the action request
                Expected to have 'str_input' key with the instruction
                Expected to have 'img_input' key with a screenshot

        Returns:
            Generated action as a string, token count, and cost
        """
        action_request = tool_input.get('str_input', '')
        screenshot = tool_input.get('img_input')
        if not action_request:
            return "Error: No action request provided", [0, 0, 0], ""
        if not screenshot:
            return "Error: No screenshot provided", [0, 0, 0], ""
        # Check if search is enabled
        if self.enable_search and self.search_tool:
            try:
                # Use the input text directly as search query
                search_query = action_request
                logger.info(f"Performing web search for query: {search_query}")
                search_results, tokens, cost = self.search_tool.execute({"str_input": search_query})

                # Enhance the action request with search results
                enhanced_request = f"[Action Request]\n{action_request}\n[End of Action Request]\n\n[Web Search Results for '{action_request}']\n{search_results}\n\n[End of Web Search Results]"
                tool_input["str_input"] = enhanced_request

                logger.info(f"Search completed. Found information: {len(search_results)} characters")
            except Exception as e:
                logger.error(f"Error during web search: {e}")
                # Continue with original request if search fails

        # Use the prompt template and LMM for action generation
        return self._call_lmm(tool_input)

    def get_grounding_wh(self):
        """
        Get grounding width and height based on provider and model name.

        Returns:
            (1000, 1000) when provider is "doubao" and model_name contains
            'ui-tars'; otherwise (None, None).
            (Docstring fixed: it previously claimed 1024x768, but the code
            returns 1000x1000.)

        NOTE(review): unlike GroundingTool.get_grounding_wh, this variant does
        not match 'ep-' model names — confirm whether that difference is
        intentional.
        """
        if self.provider == "doubao" and "ui-tars" in self.model_name:
            grounding_width = 1000
            grounding_height = 1000
            return grounding_width, grounding_height
        return None, None
|
||||
|
||||
class EmbeddingTool(BaseTool):
    """Tool for generating text embeddings."""

    def __init__(self, provider: str, model_name: str, tool_name: str):
        """
        Initialize the embedding tool.

        Skips BaseTool.__init__ on purpose: embedding needs no system prompt
        or chat LLMAgent, only an EmbeddingAgent.

        Args:
            provider: API provider name (e.g., "openai", "gemini")
            model_name: Model name to use
            tool_name: Name of the tool (used as key in prompts.json)
        """
        self.provider = provider
        self.model_name = model_name
        self.tool_name = tool_name

        # EmbeddingAgent takes the model under the "embedding_model" key.
        self.engine_params = {
            "engine_type": provider,
            "embedding_model": model_name
        }
        self.embedding_agent = EmbeddingAgent(engine_params=self.engine_params)

    def execute(self, tool_input: Dict[str, Any]):
        """
        Generate embeddings for the given text.

        Args:
            tool_input: Dictionary containing the text to embed
                Expected to have 'str_input' key with the text

        Returns:
            Embeddings, token counts, and cost string; an "Error: ..." text
            triple when no input is given or the embedding call fails.
        """
        text = tool_input.get('str_input', '')
        if not text:
            return "Error: No text provided for embedding", [0, 0, 0], ""

        try:
            embeddings, total_tokens, cost_string = self.embedding_agent.get_embeddings(text)
        except Exception as e:
            logger.error(f"Error during embedding operation: {str(e)}")
            return f"Error: Embedding operation failed: {str(e)}", [0, 0, 0], ""
        return embeddings, total_tokens, cost_string
|
||||
|
||||
class QueryFormulatorTool(BaseTool):
    """Tool for formulating queries from tasks or contexts."""

    def execute(self, tool_input: Dict[str, Any]) -> Tuple[str, List[int], str]:
        """
        Formulate a query for a given task or context.

        Args:
            tool_input: Dictionary containing the task or context description
                Expected to have 'str_input' key with the description
                May also have 'img_input' key with a screenshot

        Returns:
            Tuple of (formulated query text, token counts, cost string)
        """
        task = tool_input.get('str_input', '')
        if not task:
            # Fix: return a 3-tuple matching the execute() contract instead
            # of a bare string.
            return "Error: No task or context description provided", [0, 0, 0], ""

        # Use the prompt template and LMM for query formulation
        return self._call_lmm(tool_input)
|
||||
|
||||
class NewTools:
    """Registry that owns tool instances and dispatches execution to them."""

    def __init__(self):
        """Create an empty tool registry."""
        self.tools = {}

    def register_tool(self, tool_name: str, provider: str, model_name: str, **kwargs):
        """
        Build a tool via ToolFactory and store it under its name.

        Args:
            tool_name: Name of the tool to register
            provider: API provider name
            model_name: Model name to use
            **kwargs: Additional parameters forwarded to the tool
        """
        self.tools[tool_name] = ToolFactory.create_tool(tool_name, provider, model_name, **kwargs)

    def execute_tool(self, tool_name: str, tool_input: Dict[str, Any]):
        """
        Run a registered tool on the given input.

        Args:
            tool_name: Name of the tool to execute
            tool_input: Input for the tool

        Returns:
            Whatever the tool's execute() returns.

        Raises:
            ValueError: If the tool is not registered
        """
        tool = self.tools.get(tool_name)
        if tool is None:
            raise ValueError(f"Tool {tool_name} is not registered")
        return tool.execute(tool_input)

    def reset(self, tool_name: Optional[str] = None):
        """
        Reset tools by resetting their llm_agent if available.

        Args:
            tool_name: Optional name of the specific tool to reset.
                If None, resets all tools.

        Raises:
            ValueError: If a specific tool name is given but not registered
        """
        if tool_name is None:
            targets = list(self.tools.values())
        else:
            if tool_name not in self.tools:
                raise ValueError(f"Tool {tool_name} is not registered")
            targets = [self.tools[tool_name]]

        # Only tools that carry an llm_agent (i.e. went through
        # BaseTool.__init__) have chat state worth clearing.
        for tool in targets:
            agent = getattr(tool, 'llm_agent', None)
            if agent is not None:
                agent.reset()
|
||||
Reference in New Issue
Block a user