CoACT initialize (#292)
This commit is contained in:
48
mm_agents/coact/autogen/tools/experimental/__init__.py
Normal file
48
mm_agents/coact/autogen/tools/experimental/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .browser_use import BrowserUseTool
|
||||
from .crawl4ai import Crawl4AITool
|
||||
from .deep_research import DeepResearchTool
|
||||
from .duckduckgo import DuckDuckGoSearchTool
|
||||
from .google_search import GoogleSearchTool, YoutubeSearchTool
|
||||
from .messageplatform import (
|
||||
DiscordRetrieveTool,
|
||||
DiscordSendTool,
|
||||
SlackRetrieveRepliesTool,
|
||||
SlackRetrieveTool,
|
||||
SlackSendTool,
|
||||
TelegramRetrieveTool,
|
||||
TelegramSendTool,
|
||||
)
|
||||
from .perplexity import PerplexitySearchTool
|
||||
from .reliable import ReliableTool, ReliableToolError, SuccessfulExecutionParameters, ToolExecutionDetails
|
||||
from .tavily import TavilySearchTool
|
||||
from .web_search_preview import WebSearchPreviewTool
|
||||
from .wikipedia import WikipediaPageLoadTool, WikipediaQueryRunTool
|
||||
|
||||
__all__ = [
|
||||
"BrowserUseTool",
|
||||
"Crawl4AITool",
|
||||
"DeepResearchTool",
|
||||
"DiscordRetrieveTool",
|
||||
"DiscordSendTool",
|
||||
"DuckDuckGoSearchTool",
|
||||
"GoogleSearchTool",
|
||||
"PerplexitySearchTool",
|
||||
"ReliableTool",
|
||||
"ReliableToolError",
|
||||
"SlackRetrieveRepliesTool",
|
||||
"SlackRetrieveTool",
|
||||
"SlackSendTool",
|
||||
"SuccessfulExecutionParameters",
|
||||
"TavilySearchTool",
|
||||
"TelegramRetrieveTool",
|
||||
"TelegramSendTool",
|
||||
"ToolExecutionDetails",
|
||||
"WebSearchPreviewTool",
|
||||
"WikipediaPageLoadTool",
|
||||
"WikipediaQueryRunTool",
|
||||
"YoutubeSearchTool",
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .browser_use import BrowserUseResult, BrowserUseTool, ExtractedContent
|
||||
|
||||
__all__ = ["BrowserUseResult", "BrowserUseTool", "ExtractedContent"]
|
||||
@@ -0,0 +1,161 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from ... import Depends, Tool
|
||||
from ...dependency_injection import on
|
||||
|
||||
with optional_import_block():
|
||||
from browser_use import Agent, Controller
|
||||
from browser_use.browser.browser import Browser, BrowserConfig
|
||||
|
||||
from ....interop.langchain.langchain_chat_model_factory import LangChainChatModelFactory
|
||||
|
||||
|
||||
__all__ = ["BrowserUseResult", "BrowserUseTool", "ExtractedContent"]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.browser_use")
class ExtractedContent(BaseModel):
    """A single piece of content captured during a browser session.

    Attributes:
        content: The extracted text content.
        url: Source URL of the content, or None when no real page was loaded.
    """

    content: str
    url: Optional[str]

    @field_validator("url")
    @classmethod
    def check_url(cls, v: str) -> Optional[str]:
        """Normalize the URL field.

        Args:
            v: The raw URL value.

        Returns:
            None when the URL is the blank-page placeholder "about:blank",
            otherwise the URL unchanged.
        """
        return None if v == "about:blank" else v
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.browser_use")
class BrowserUseResult(BaseModel):
    """The result of using the browser to perform a task.

    Attributes:
        extracted_content: List of extracted content.
        final_result: The final result.
    """

    # One entry per piece of content the browsing agent extracted during the run.
    extracted_content: list[ExtractedContent]
    # Final answer produced by the browsing agent; None when the run yielded no result.
    final_result: Optional[str]
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "langchain_anthropic",
        "langchain_google_genai",
        "langchain_ollama",
        "langchain_openai",
        "langchain_core",
        "browser_use",
    ],
    "browser-use",
)
@export_module("autogen.tools.experimental")
class BrowserUseTool(Tool):
    """BrowserUseTool is a tool that uses the browser to perform a task."""

    def __init__(  # type: ignore[no-any-unimported]
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        browser: Optional["Browser"] = None,
        agent_kwargs: Optional[dict[str, Any]] = None,
        browser_config: Optional[dict[str, Any]] = None,
    ):
        """Use the browser to perform a task.

        Args:
            llm_config: The LLM configuration.
            browser: The browser to use. If defined, browser_config must be None.
            agent_kwargs: Additional keyword arguments to pass to the Agent.
            browser_config: The browser configuration to use. If defined, browser must be None.

        Raises:
            ValueError: If both ``browser`` and a non-empty ``browser_config`` are provided.
        """
        if agent_kwargs is None:
            agent_kwargs = {}

        if browser_config is None:
            browser_config = {}

        # A pre-built browser and a config to build one are mutually exclusive.
        if browser is not None and browser_config:
            raise ValueError(
                f"Cannot provide both browser and additional keyword parameters: {browser=}, {browser_config=}"
            )

        # Tool entry point. The llm_config/browser/agent_kwargs/browser_config
        # arguments are dependency-injected from the values captured above.
        async def browser_use(  # type: ignore[no-any-unimported]
            task: Annotated[str, "The task to perform."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            browser: Annotated[Optional[Browser], Depends(on(browser))],
            agent_kwargs: Annotated[dict[str, Any], Depends(on(agent_kwargs))],
            browser_config: Annotated[dict[str, Any], Depends(on(browser_config))],
        ) -> BrowserUseResult:
            # Copy so pops/defaults below don't mutate the injected dicts shared across calls.
            agent_kwargs = agent_kwargs.copy()
            browser_config = browser_config.copy()
            if browser is None:
                # set default value for headless
                headless = browser_config.pop("headless", True)

                browser_config = BrowserConfig(headless=headless, **browser_config)
                browser = Browser(config=browser_config)

            # set default value for generate_gif
            if "generate_gif" not in agent_kwargs:
                agent_kwargs["generate_gif"] = False

            # Bridge the AG2 LLM config to a LangChain chat model for browser_use.
            llm = LangChainChatModelFactory.create_base_chat_model(llm_config)

            # max_steps is consumed here, not forwarded to the Agent constructor.
            max_steps = agent_kwargs.pop("max_steps", 100)

            agent = Agent(
                task=task,
                llm=llm,
                browser=browser,
                controller=BrowserUseTool._get_controller(llm_config),
                **agent_kwargs,
            )

            result = await agent.run(max_steps=max_steps)

            # Pair each extracted content chunk with the URL recorded for that step.
            extracted_content = [
                ExtractedContent(content=content, url=url)
                for content, url in zip(result.extracted_content(), result.urls())
            ]
            return BrowserUseResult(
                extracted_content=extracted_content,
                final_result=result.final_result(),
            )

        super().__init__(
            name="browser_use",
            description="Use the browser to perform a task.",
            func_or_tool=browser_use,
        )

    @staticmethod
    def _get_controller(llm_config: Union[LLMConfig, dict[str, Any]]) -> Any:
        """Build a browser_use Controller whose output model mirrors the
        ``response_format`` found in the LLM config (first config_list entry
        when present, else the top-level key); None yields free-form output.
        """
        response_format = (
            llm_config["config_list"][0].get("response_format", None)
            if "config_list" in llm_config
            else llm_config.get("response_format")
        )
        return Controller(output_model=response_format)
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .crawl4ai import Crawl4AITool
|
||||
|
||||
__all__ = ["Crawl4AITool"]
|
||||
153
mm_agents/coact/autogen/tools/experimental/crawl4ai/crawl4ai.py
Normal file
153
mm_agents/coact/autogen/tools/experimental/crawl4ai/crawl4ai.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....interop import LiteLLmConfigFactory
|
||||
from ....llm_config import LLMConfig
|
||||
from ... import Tool
|
||||
from ...dependency_injection import Depends, on
|
||||
|
||||
with optional_import_block():
|
||||
from crawl4ai import AsyncWebCrawler, BrowserConfig, CacheMode, CrawlerRunConfig
|
||||
from crawl4ai.extraction_strategy import LLMExtractionStrategy
|
||||
|
||||
__all__ = ["Crawl4AITool"]
|
||||
|
||||
|
||||
@require_optional_import(["crawl4ai"], "crawl4ai")
@export_module("autogen.tools.experimental")
class Crawl4AITool(Tool):
    """
    Crawl a website and extract information using the crawl4ai library.
    """

    def __init__(
        self,
        llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None,
        extraction_model: Optional[type[BaseModel]] = None,
        llm_strategy_kwargs: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Initialize the Crawl4AITool.

        Args:
            llm_config: The config dictionary for the LLM model. If None, the tool will run without LLM.
            extraction_model: The Pydantic model to use for extraction. If None, the tool will use the default schema.
            llm_strategy_kwargs: The keyword arguments to pass to the LLM extraction strategy.

        Raises:
            ValueError: If llm_strategy_kwargs is given without llm_config, or contains
                keys that this tool sets automatically ('provider', 'api_token',
                'schema', 'instruction').
        """
        # Fail fast on kwargs combinations that would silently be overridden later.
        Crawl4AITool._validate_llm_strategy_kwargs(llm_strategy_kwargs, llm_config_provided=(llm_config is not None))

        # Shared crawl driver: with no crawl_config it returns raw markdown;
        # with one, the LLM-extracted content (or the error message on failure).
        async def crawl4ai_helper(  # type: ignore[no-any-unimported]
            url: str,
            browser_cfg: Optional["BrowserConfig"] = None,
            crawl_config: Optional["CrawlerRunConfig"] = None,
        ) -> Any:
            async with AsyncWebCrawler(config=browser_cfg) as crawler:
                result = await crawler.arun(
                    url=url,
                    config=crawl_config,
                )

            if crawl_config is None:
                response = result.markdown
            else:
                response = result.extracted_content if result.success else result.error_message

            return response

        # Variant registered when no LLM is configured: plain markdown crawl.
        async def crawl4ai_without_llm(
            url: Annotated[str, "The url to crawl and extract information from."],
        ) -> Any:
            return await crawl4ai_helper(url=url)

        # Variant registered when an LLM is configured: crawl + LLM extraction.
        # llm_config / llm_strategy_kwargs / extraction_model are dependency-injected
        # from the constructor arguments captured above.
        async def crawl4ai_with_llm(
            url: Annotated[str, "The url to crawl and extract information from."],
            instruction: Annotated[str, "The instruction to provide on how and what to extract."],
            llm_config: Annotated[Any, Depends(on(llm_config))],
            llm_strategy_kwargs: Annotated[Optional[dict[str, Any]], Depends(on(llm_strategy_kwargs))],
            extraction_model: Annotated[Optional[type[BaseModel]], Depends(on(extraction_model))],
        ) -> Any:
            browser_cfg = BrowserConfig(headless=True)
            crawl_config = Crawl4AITool._get_crawl_config(
                llm_config=llm_config,
                instruction=instruction,
                extraction_model=extraction_model,
                llm_strategy_kwargs=llm_strategy_kwargs,
            )

            return await crawl4ai_helper(url=url, browser_cfg=browser_cfg, crawl_config=crawl_config)

        super().__init__(
            name="crawl4ai",
            description="Crawl a website and extract information.",
            # The tool's callable is chosen once, at construction time.
            func_or_tool=crawl4ai_without_llm if llm_config is None else crawl4ai_with_llm,
        )

    @staticmethod
    def _validate_llm_strategy_kwargs(llm_strategy_kwargs: Optional[dict[str, Any]], llm_config_provided: bool) -> None:
        """Reject llm_strategy_kwargs entries the tool derives automatically.

        Args:
            llm_strategy_kwargs: User-provided extraction-strategy kwargs (may be None).
            llm_config_provided: Whether an llm_config was passed to the constructor.

        Raises:
            ValueError: On any forbidden key or when kwargs are given without llm_config.
        """
        if not llm_strategy_kwargs:
            return

        if not llm_config_provided:
            raise ValueError("llm_strategy_kwargs can only be provided if llm_config is also provided.")

        # Accumulate all violations into one message so the user sees everything at once.
        check_parameters_error_msg = "".join(
            f"'{key}' should not be provided in llm_strategy_kwargs. It is automatically set based on llm_config.\n"
            for key in ["provider", "api_token"]
            if key in llm_strategy_kwargs
        )

        check_parameters_error_msg += "".join(
            "'schema' should not be provided in llm_strategy_kwargs. It is automatically set based on extraction_model type.\n"
            if "schema" in llm_strategy_kwargs
            else ""
        )

        check_parameters_error_msg += "".join(
            "'instruction' should not be provided in llm_strategy_kwargs. It is provided at the time of calling the tool.\n"
            if "instruction" in llm_strategy_kwargs
            else ""
        )

        if check_parameters_error_msg:
            raise ValueError(check_parameters_error_msg)

    @staticmethod
    def _get_crawl_config(  # type: ignore[no-any-unimported]
        llm_config: Union[LLMConfig, dict[str, Any]],
        instruction: str,
        llm_strategy_kwargs: Optional[dict[str, Any]] = None,
        extraction_model: Optional[type[BaseModel]] = None,
    ) -> "CrawlerRunConfig":
        """Build the CrawlerRunConfig carrying the LLM extraction strategy.

        Args:
            llm_config: The AG2 LLM configuration to translate for LiteLLM.
            instruction: The extraction instruction for this call.
            llm_strategy_kwargs: Extra kwargs forwarded to LLMExtractionStrategy.
            extraction_model: Optional Pydantic model defining the output schema.

        Returns:
            CrawlerRunConfig: Config with extraction strategy and cache bypassed.
        """
        # Translate AG2-style llm_config into LiteLLM provider/token kwargs.
        lite_llm_config = LiteLLmConfigFactory.create_lite_llm_config(llm_config)

        if llm_strategy_kwargs is None:
            llm_strategy_kwargs = {}

        schema = (
            extraction_model.model_json_schema()
            if (extraction_model and issubclass(extraction_model, BaseModel))
            else None
        )

        # With a schema default to structured extraction, otherwise block extraction;
        # the user may still override via llm_strategy_kwargs["extraction_type"].
        extraction_type = llm_strategy_kwargs.pop("extraction_type", "schema" if schema else "block")

        # 1. Define the LLM extraction strategy
        llm_strategy = LLMExtractionStrategy(
            **lite_llm_config,
            schema=schema,
            extraction_type=extraction_type,
            instruction=instruction,
            **llm_strategy_kwargs,
        )

        # 2. Build the crawler config
        crawl_config = CrawlerRunConfig(extraction_strategy=llm_strategy, cache_mode=CacheMode.BYPASS)

        return crawl_config
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .deep_research import DeepResearchTool
|
||||
|
||||
__all__ = ["DeepResearchTool"]
|
||||
@@ -0,0 +1,328 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import copy
|
||||
from typing import Annotated, Any, Callable, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ....agentchat import ConversableAgent
|
||||
from ....doc_utils import export_module
|
||||
from ....llm_config import LLMConfig
|
||||
from ... import Depends, Tool
|
||||
from ...dependency_injection import on
|
||||
|
||||
__all__ = ["DeepResearchTool"]
|
||||
|
||||
|
||||
class Subquestion(BaseModel):
    """A single research subquestion, without an answer yet."""

    question: Annotated[str, Field(description="The original question.")]

    def format(self) -> str:
        """Render the subquestion as a one-line 'Question: ...' string."""
        return "Question: " + self.question + "\n"
|
||||
|
||||
|
||||
class SubquestionAnswer(Subquestion):
    """A subquestion paired with the answer that was found for it."""

    answer: Annotated[str, Field(description="The answer to the question.")]

    def format(self) -> str:
        """Render the question followed by its answer."""
        return "Question: {}\n{}\n".format(self.question, self.answer)
|
||||
|
||||
|
||||
class Task(BaseModel):
    """A research task decomposed into subquestions to be answered."""

    question: Annotated[str, Field(description="The original question.")]
    subquestions: Annotated[list[Subquestion], Field(description="The subquestions that need to be answered.")]

    def format(self) -> str:
        """Render the task header followed by each numbered subquestion."""
        numbered = [
            f"Subquestion {idx}:\n{sub.format()}"
            for idx, sub in enumerate(self.subquestions, start=1)
        ]
        return f"Task: {self.question}\n\n" + "\n".join(numbered)
|
||||
|
||||
|
||||
class CompletedTask(BaseModel):
    """A research task whose subquestions have all been answered."""

    question: Annotated[str, Field(description="The original question.")]
    subquestions: Annotated[list[SubquestionAnswer], Field(description="The subquestions and their answers")]

    def format(self) -> str:
        """Render the task header followed by each numbered answered subquestion."""
        numbered = [
            f"Subquestion {idx}:\n{sub.format()}"
            for idx, sub in enumerate(self.subquestions, start=1)
        ]
        return f"Task: {self.question}\n\n" + "\n".join(numbered)
|
||||
|
||||
|
||||
class InformationCrumb(BaseModel):
    """One piece of web-sourced evidence gathered while researching a question."""

    # URL of the page the information came from.
    source_url: str
    # Title of the source page.
    source_title: str
    # Short summary of the source page.
    source_summary: str
    # The portion of the page content relevant to the question.
    relevant_info: str
|
||||
|
||||
|
||||
class GatheredInformation(BaseModel):
    """A collection of information crumbs gathered from the web."""

    information: list[InformationCrumb]

    def format(self) -> str:
        """Render all crumbs as a single human-readable report."""
        entries = []
        for crumb in self.information:
            entries.append(
                "URL: {0}\nTitle: {1}\nSummary: {2}\nRelevant Information: {3}\n\n".format(
                    crumb.source_url,
                    crumb.source_title,
                    crumb.source_summary,
                    crumb.relevant_info,
                )
            )
        return "Here is the gathered information: \n" + "\n".join(entries)
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental")
class DeepResearchTool(Tool):
    """A tool that delegates a web research task to the subteams of agents."""

    # Marker prefix used by termination predicates to detect a confirmed answer.
    ANSWER_CONFIRMED_PREFIX = "Answer confirmed:"

    def __init__(
        self,
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int = 30,
    ):
        """Initialize the DeepResearchTool.

        Args:
            llm_config (LLMConfig, dict[str, Any]): The LLM configuration.
            max_web_steps (int, optional): The maximum number of web steps. Defaults to 30.
        """
        self.llm_config = llm_config

        # Splits the user question into subquestions, then combines the answers.
        self.summarizer_agent = ConversableAgent(
            name="SummarizerAgent",
            system_message=(
                "You are an agent with a task of answering the question provided by the user."
                "First you need to split the question into subquestions by calling the 'split_question_and_answer_subquestions' method."
                "Then you need to sintesize the answers the original question by combining the answers to the subquestions."
            ),
            # Stop once a message starting with the confirmation prefix appears.
            is_termination_msg=lambda x: x.get("content", "")
            and x.get("content", "").startswith(self.ANSWER_CONFIRMED_PREFIX),
            llm_config=llm_config,
            human_input_mode="NEVER",
        )

        # Reviews the summarizer's answer and confirms it when satisfactory.
        self.critic_agent = ConversableAgent(
            name="CriticAgent",
            system_message=(
                "You are a critic agent responsible for evaluating the answer provided by the summarizer agent.\n"
                "Your task is to assess the quality of the answer based on its coherence, relevance, and completeness.\n"
                "Provide constructive feedback on how the answer can be improved.\n"
                "If the answer is satisfactory, call the 'confirm_answer' method to end the task.\n"
            ),
            is_termination_msg=lambda x: x.get("content", "")
            and x.get("content", "").startswith(self.ANSWER_CONFIRMED_PREFIX),
            llm_config=llm_config,
            human_input_mode="NEVER",
        )

        # Tool entry point; llm_config and max_web_steps are dependency-injected
        # from the constructor arguments captured above.
        def delegate_research_task(
            task: Annotated[str, "The task to perform a research on."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            """Delegate a research task to the agent.

            Args:
                task (str): The task to perform a research on.
                llm_config (LLMConfig, dict[str, Any]): The LLM configuration.
                max_web_steps (int): The maximum number of web steps.

            Returns:
                str: The answer to the research task.
            """

            # Critic proposes confirmation via LLM; summarizer executes it,
            # producing the prefixed message both termination predicates look for.
            @self.summarizer_agent.register_for_execution()
            @self.critic_agent.register_for_llm(description="Call this method to confirm the final answer.")
            def confirm_summary(answer: str, reasoning: str) -> str:
                return f"{self.ANSWER_CONFIRMED_PREFIX}" + answer + "\nReasoning: " + reasoning

            split_question_and_answer_subquestions = DeepResearchTool._get_split_question_and_answer_subquestions(
                llm_config=llm_config,
                max_web_steps=max_web_steps,
            )

            # Summarizer may call the splitter through the LLM; critic executes it.
            self.summarizer_agent.register_for_llm(description="Split the question into subquestions and get answers.")(
                split_question_and_answer_subquestions
            )
            self.critic_agent.register_for_execution()(split_question_and_answer_subquestions)

            result = self.critic_agent.initiate_chat(
                self.summarizer_agent,
                message="Please answer the following question: " + task,
                # This outer chat should preserve the history of the conversation
                clear_history=False,
            )

            return result.summary

        super().__init__(
            name=delegate_research_task.__name__,
            description="Delegate a research task to the deep research agent.",
            func_or_tool=delegate_research_task,
        )

    # Marker prefix signalling that all subquestions have been answered.
    SUBQUESTIONS_ANSWER_PREFIX = "Subquestions answered:"

    @staticmethod
    def _get_split_question_and_answer_subquestions(
        llm_config: Union[LLMConfig, dict[str, Any]], max_web_steps: int
    ) -> Callable[..., Any]:
        """Build the closure that decomposes a question and answers its subquestions.

        Args:
            llm_config (Union[LLMConfig, dict[str, Any]]): The LLM configuration.
            max_web_steps (int): The maximum number of web steps.

        Returns:
            Callable[..., Any]: The split_question_and_answer_subquestions function.
        """

        def split_question_and_answer_subquestions(
            question: Annotated[str, "The question to split and answer."],
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            # Proposes the decomposition of the question into subquestions.
            decomposition_agent = ConversableAgent(
                name="DecompositionAgent",
                system_message=(
                    "You are an expert at breaking down complex questions into smaller, focused subquestions.\n"
                    "Your task is to take any question provided and divide it into clear, actionable subquestions that can be individually answered.\n"
                    "Ensure the subquestions are logical, non-redundant, and cover all key aspects of the original question.\n"
                    "Avoid providing answers or interpretations—focus solely on decomposition.\n"
                    "Do not include banal, general knowledge questions\n"
                    "Do not include questions that go into unnecessary detail that is not relevant to the original question\n"
                    "Do not include question that require knowledge of the original or other subquestions to answer\n"
                    "Some rule of thumb is to have only one subquestion for easy questions, 3 for medium questions, and 5 for hard questions.\n"
                ),
                llm_config=llm_config,
                is_termination_msg=lambda x: x.get("content", "")
                and x.get("content", "").startswith(DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX),
                human_input_mode="NEVER",
            )

            # Example embedded into the critic's prompt to show the expected
            # argument shape for the 'generate_subquestions' call.
            example_task = Task(
                question="What is the capital of France?",
                subquestions=[Subquestion(question="What is the capital of France?")],
            )
            # Vets the proposed subquestions and triggers answering them.
            decomposition_critic = ConversableAgent(
                name="DecompositionCritic",
                system_message=(
                    "You are a critic agent responsible for evaluating the subquestions provided by the initial analysis agent.\n"
                    "You need to confirm whether the subquestions are clear, actionable, and cover all key aspects of the original question.\n"
                    "Do not accept redundant or unnecessary subquestions, focus solely on the minimal viable subset of subqestions necessary to answer the original question. \n"
                    "Do not accept banal, general knowledge questions\n"
                    "Do not accept questions that go into unnecessary detail that is not relevant to the original question\n"
                    "Remove questions that can be answered with combining knowledge from other questions\n"
                    "After you are satisfied with the subquestions, call the 'generate_subquestions' method to answer each subquestion.\n"
                    "This is an example of an argument that can be passed to the 'generate_subquestions' method:\n"
                    f"{{'task': {example_task.model_dump()}}}\n"
                    "Some rule of thumb is to have only one subquestion for easy questions, 3 for medium questions, and 5 for hard questions.\n"
                ),
                llm_config=llm_config,
                is_termination_msg=lambda x: x.get("content", "")
                and x.get("content", "").startswith(DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX),
                human_input_mode="NEVER",
            )

            generate_subquestions = DeepResearchTool._get_generate_subquestions(
                llm_config=llm_config, max_web_steps=max_web_steps
            )
            # Critic calls the generator via LLM; decomposition agent executes it.
            decomposition_agent.register_for_execution()(generate_subquestions)
            decomposition_critic.register_for_llm(description="Generate subquestions for a task.")(
                generate_subquestions
            )

            result = decomposition_critic.initiate_chat(
                decomposition_agent,
                message="Analyse and gather subqestions for the following question: " + question,
            )

            return result.summary

        return split_question_and_answer_subquestions

    @staticmethod
    def _get_generate_subquestions(
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int,
    ) -> Callable[..., str]:
        """Get the generate_subquestions method.

        Args:
            llm_config (Union[LLMConfig, dict[str, Any]]): The LLM configuration.
            max_web_steps (int): The maximum number of web steps.

        Returns:
            Callable[..., str]: The generate_subquestions method.
        """

        def generate_subquestions(
            task: Task,
            llm_config: Annotated[Union[LLMConfig, dict[str, Any]], Depends(on(llm_config))],
            max_web_steps: Annotated[int, Depends(on(max_web_steps))],
        ) -> str:
            # Degenerate case: treat the whole question as its only subquestion.
            if not task.subquestions:
                task.subquestions = [Subquestion(question=task.question)]

            # Answer each subquestion independently via a web-surfing subteam.
            subquestions_answers: list[SubquestionAnswer] = []
            for subquestion in task.subquestions:
                answer = DeepResearchTool._answer_question(
                    subquestion.question, llm_config=llm_config, max_web_steps=max_web_steps
                )
                subquestions_answers.append(SubquestionAnswer(question=subquestion.question, answer=answer))

            completed_task = CompletedTask(question=task.question, subquestions=subquestions_answers)

            # Prefix triggers the decomposition agents' termination predicates.
            return f"{DeepResearchTool.SUBQUESTIONS_ANSWER_PREFIX} \n" + completed_task.format()

        return generate_subquestions

    @staticmethod
    def _answer_question(
        question: str,
        llm_config: Union[LLMConfig, dict[str, Any]],
        max_web_steps: int,
    ) -> str:
        """Answer a single subquestion with a WebSurferAgent/critic pair.

        Args:
            question (str): The subquestion to answer.
            llm_config (Union[LLMConfig, dict[str, Any]]): The LLM configuration.
            max_web_steps (int): The maximum number of web steps.

        Returns:
            str: Summary of the chat that answered the question.
        """
        # Local import to avoid a circular import at module load time —
        # TODO confirm; the agents package may import tools.
        from ....agents.experimental.websurfer import WebSurferAgent

        # Deep copy so the response_format override doesn't leak into the caller's config.
        websurfer_config = copy.deepcopy(llm_config)

        # Force structured output for the web tool so results parse into GatheredInformation.
        websurfer_config["config_list"][0]["response_format"] = GatheredInformation

        def is_termination_msg(x: dict[str, Any]) -> bool:
            content = x.get("content", "")
            return (content is not None) and content.startswith(DeepResearchTool.ANSWER_CONFIRMED_PREFIX)

        websurfer_agent = WebSurferAgent(
            llm_config=llm_config,
            web_tool_llm_config=websurfer_config,
            name="WebSurferAgent",
            system_message=(
                "You are a web surfer agent responsible for gathering information from the web to provide information for answering a question\n"
                "You will be asked to find information related to the question and provide a summary of the information gathered.\n"
                "The summary should include the URL, title, summary, and relevant information for each piece of information gathered.\n"
            ),
            is_termination_msg=is_termination_msg,
            human_input_mode="NEVER",
            web_tool_kwargs={
                # Cap the number of browsing steps for each web query.
                "agent_kwargs": {"max_steps": max_web_steps},
            },
        )

        # Evaluates the surfer's findings and confirms or requests more.
        websurfer_critic = ConversableAgent(
            name="WebSurferCritic",
            system_message=(
                "You are a critic agent responsible for evaluating the answer provided by the web surfer agent.\n"
                "You need to confirm whether the information provided by the websurfer is correct and sufficient to answer the question.\n"
                "You can ask the web surfer to provide more information or provide and confirm the answer.\n"
            ),
            llm_config=llm_config,
            is_termination_msg=is_termination_msg,
            human_input_mode="NEVER",
        )

        # Critic proposes confirmation via LLM; surfer executes it, emitting the
        # prefixed message that ends the chat.
        @websurfer_agent.register_for_execution()
        @websurfer_critic.register_for_llm(
            description="Call this method when you agree that the original question can be answered with the gathered information and provide the answer."
        )
        def confirm_answer(answer: str) -> str:
            return f"{DeepResearchTool.ANSWER_CONFIRMED_PREFIX} " + answer

        # Let the critic execute the surfer's own web tool on its behalf.
        websurfer_critic.register_for_execution()(websurfer_agent.tool)

        result = websurfer_critic.initiate_chat(
            websurfer_agent,
            message="Please find the answer to this question: " + question,
        )

        return result.summary
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .duckduckgo_search import DuckDuckGoSearchTool
|
||||
|
||||
__all__ = ["DuckDuckGoSearchTool"]
|
||||
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from typing import Annotated, Any
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ... import Tool
|
||||
|
||||
with optional_import_block():
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "duckduckgo_search",
    ],
    "duckduckgo_search",
)
def _execute_duckduckgo_query(
    query: str,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """Run a text search against the DuckDuckGo Search API.

    Args:
        query (str): The search query string.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: Raw result records from DuckDuckGo; empty on failure.
    """
    hits: list[dict[str, Any]] = []
    with DDGS() as ddgs:
        try:
            # region='wt-wt' means worldwide
            hits = list(ddgs.text(query, region="wt-wt", max_results=num_results))
        except Exception as e:
            # Best-effort: report the failure and fall through with an empty result set.
            print(f"DuckDuckGo Search failed: {e}")
    return hits
|
||||
|
||||
|
||||
def _duckduckgo_search(
    query: str,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """Perform a DuckDuckGo search and normalize the results.

    Runs the query through `_execute_duckduckgo_query` and reshapes each raw
    record into a dictionary with 'title', 'link', and 'snippet' keys.

    Args:
        query (str): The search query string.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: Normalized search results; empty if nothing was found.
    """
    raw_results = _execute_duckduckgo_query(
        query=query,
        num_results=num_results,
    )

    formatted: list[dict[str, Any]] = []
    for entry in raw_results:
        formatted.append({
            "title": entry.get("title", ""),
            "link": entry.get("href", ""),
            "snippet": entry.get("body", ""),
        })
    return formatted
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental")
class DuckDuckGoSearchTool(Tool):
    """
    DuckDuckGoSearchTool is a tool that uses DuckDuckGo to perform a search.

    This tool allows agents to leverage the DuckDuckGo search engine for information retrieval.
    DuckDuckGo does not require an API key, making it easy to use.
    """

    def __init__(self) -> None:
        """
        Initializes the DuckDuckGoSearchTool.
        """

        def duckduckgo_search(
            query: Annotated[str, "The search query."],
            num_results: Annotated[int, "The number of results to return."] = 5,
        ) -> list[dict[str, Any]]:
            """
            Runs a DuckDuckGo search and returns the normalized result list.

            Args:
                query: The search query string.
                num_results: The maximum number of results to return. Defaults to 5.

            Returns:
                A list of dictionaries, each containing 'title', 'link', and 'snippet' of a search result.
            """
            results = _duckduckgo_search(query, num_results)
            return results

        super().__init__(
            name="duckduckgo_search",
            description="Use the DuckDuckGo Search API to perform a search.",
            func_or_tool=duckduckgo_search,
        )
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .authentication import GoogleCredentialsLocalProvider, GoogleCredentialsProvider
|
||||
from .drive import GoogleDriveToolkit
|
||||
from .toolkit_protocol import GoogleToolkitProtocol
|
||||
|
||||
__all__ = [
|
||||
"GoogleCredentialsLocalProvider",
|
||||
"GoogleCredentialsProvider",
|
||||
"GoogleDriveToolkit",
|
||||
"GoogleToolkitProtocol",
|
||||
]
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .credentials_local_provider import GoogleCredentialsLocalProvider
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
__all__ = [
|
||||
"GoogleCredentialsLocalProvider",
|
||||
"GoogleCredentialsProvider",
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
__all__ = ["GoogleCredenentialsHostedProvider"]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredenentialsHostedProvider(GoogleCredentialsProvider):
    """Placeholder provider that will fetch Google credentials from a remote host.

    Not implemented yet: the constructor stores its arguments and then raises
    `NotImplementedError` unconditionally.
    """

    # NOTE(review): the class name misspells "Credentials" ("Credenentials") and
    # is exported under that name in __all__; renaming would be a breaking
    # change and should be done as a coordinated rename with a deprecation alias.
    def __init__(
        self,
        host: str,
        port: int = 8080,
        *,
        kwargs: dict[str, str],
    ) -> None:
        """Store connection parameters for the (future) hosted credential flow.

        Args:
            host: The host from which to get the credentials.
            port: The port from which to get the credentials. Defaults to 8080.
            kwargs: Extra provider options. Note this is a plain keyword-only
                dict parameter named ``kwargs``, not a ``**kwargs`` catch-all.

        Raises:
            NotImplementedError: Always — this provider is a stub.
        """
        self._host = host
        self._port = port
        self._kwargs = kwargs

        raise NotImplementedError("This class is not implemented yet.")

    @property
    def host(self) -> str:
        """The host from which to get the credentials."""
        return self._host

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        return self._port

    def get_credentials(self) -> "Credentials":  # type: ignore[no-any-unimported]
        """Not implemented yet; always raises."""
        raise NotImplementedError("This class is not implemented yet.")
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from .credentials_provider import GoogleCredentialsProvider
|
||||
|
||||
with optional_import_block():
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
|
||||
|
||||
__all__ = ["GoogleCredentialsLocalProvider"]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredentialsLocalProvider(GoogleCredentialsProvider):
    def __init__(
        self,
        client_secret_file: str,
        scopes: list[str],  # e.g. ['https://www.googleapis.com/auth/drive/readonly']
        token_file: Optional[str] = None,
        port: int = 8080,
    ) -> None:
        """A Google credentials provider that obtains credentials via a local OAuth flow.

        Args:
            client_secret_file (str): The path to the client secret file.
            scopes (list[str]): The scopes to request.
            token_file (str): Optional path to the token file. If not provided, the token will not be saved.
            port (int): The port from which to get the credentials.
        """
        self.client_secret_file = client_secret_file
        self.scopes = scopes
        self.token_file = token_file
        self._port = port

    @property
    def host(self) -> str:
        """Localhost is the default host."""
        return "localhost"

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        return self._port

    @require_optional_import(
        [
            "google_auth_httplib2",
            "google_auth_oauthlib",
        ],
        "google-api",
    )
    def _refresh_or_get_new_credentials(self, creds: Optional["Credentials"]) -> "Credentials":  # type: ignore[no-any-unimported]
        """Refresh an expired token when possible, otherwise run a fresh OAuth flow."""
        if creds and creds.expired and creds.refresh_token:
            # A refresh token is available — renew in place without user interaction.
            creds.refresh(Request())  # type: ignore[no-untyped-call]
            return creds  # type: ignore[return-value]

        # No usable token: launch the interactive browser-based consent flow.
        oauth_flow = InstalledAppFlow.from_client_secrets_file(self.client_secret_file, self.scopes)
        return oauth_flow.run_local_server(host=self.host, port=self.port)  # type: ignore[no-any-return]

    @require_optional_import(
        [
            "google_auth_httplib2",
            "google_auth_oauthlib",
        ],
        "google-api",
    )
    def get_credentials(self) -> "Credentials":  # type: ignore[no-any-unimported]
        """Get the Google credentials."""
        cached: Optional["Credentials"] = None
        if self.token_file and os.path.exists(self.token_file):
            cached = Credentials.from_authorized_user_file(self.token_file)  # type: ignore[no-untyped-call]

        if cached and cached.valid:
            # Cached token is still good — no refresh or re-auth needed.
            return cached  # type: ignore[no-any-return]

        fresh = self._refresh_or_get_new_credentials(cached)

        if self.token_file:
            # Persist the (possibly refreshed) credentials for the next run.
            with open(self.token_file, "w") as token_fh:
                token_fh.write(fresh.to_json())  # type: ignore[no-untyped-call]

        return fresh  # type: ignore[no-any-return]
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
from typing import Optional, Protocol, runtime_checkable
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
|
||||
__all__ = ["GoogleCredentialsProvider"]
|
||||
|
||||
|
||||
@runtime_checkable
@export_module("autogen.tools.experimental.google.authentication")
class GoogleCredentialsProvider(Protocol):
    """A protocol for Google credentials provider.

    Implementations (e.g. a local OAuth flow or a hosted provider) must produce
    `google.oauth2.credentials.Credentials` and report where the credentials
    are obtained from via `host` and `port`.
    """

    def get_credentials(self) -> Optional["Credentials"]:  # type: ignore[no-any-unimported]
        """Get the Google credentials."""
        ...

    @property
    def host(self) -> str:
        """The host from which to get the credentials."""
        ...

    @property
    def port(self) -> int:
        """The port from which to get the credentials."""
        ...
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .toolkit import GoogleDriveToolkit
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveToolkit",
|
||||
]
|
||||
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from ..model import GoogleFileInfo
|
||||
|
||||
with optional_import_block():
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
|
||||
|
||||
__all__ = [
|
||||
"download_file",
|
||||
"list_files_and_folders",
|
||||
]
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-api",
)
def list_files_and_folders(service: Any, page_size: int, folder_id: Optional[str]) -> list[GoogleFileInfo]:
    """List Drive entries visible to `service`, optionally within a single folder.

    Args:
        service: An authorised Google Drive API service object.
        page_size: Maximum number of entries to request.
        folder_id: When given, restrict the listing to non-trashed children of this folder.

    Returns:
        The listed entries as `GoogleFileInfo` models.

    Raises:
        ValueError: If the API response does not contain a list under "files".
    """
    request_params: dict[str, Any] = {
        "pageSize": page_size,
        "fields": "nextPageToken, files(id, name, mimeType)",
    }
    if folder_id:
        # Restrict the query to non-trashed children of the given folder.
        request_params["q"] = f"'{folder_id}' in parents and trashed=false"

    response = service.files().list(**request_params).execute()
    raw_files = response.get("files", [])
    if not isinstance(raw_files, list):
        raise ValueError(f"Expected a list of files, but got {raw_files}")
    return [GoogleFileInfo(**entry) for entry in raw_files]
|
||||
|
||||
|
||||
def _get_file_extension(mime_type: str) -> Optional[str]:
|
||||
"""Returns the correct file extension for a given MIME type."""
|
||||
mime_extensions = {
|
||||
"application/vnd.google-apps.document": "docx", # Google Docs → Word
|
||||
"application/vnd.google-apps.spreadsheet": "csv", # Google Sheets → CSV
|
||||
"application/vnd.google-apps.presentation": "pptx", # Google Slides → PowerPoint
|
||||
"video/quicktime": "mov",
|
||||
"application/vnd.google.colaboratory": "ipynb",
|
||||
"application/pdf": "pdf",
|
||||
"image/jpeg": "jpg",
|
||||
"image/png": "png",
|
||||
"text/plain": "txt",
|
||||
"application/zip": "zip",
|
||||
}
|
||||
|
||||
return mime_extensions.get(mime_type)
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-api",
)
def download_file(
    service: Any,
    file_id: str,
    file_name: str,
    mime_type: str,
    download_folder: Path,
    subfolder_path: Optional[str] = None,
) -> str:
    """Download or export file based on its MIME type, optionally saving to a subfolder.

    Google-native files (Docs/Sheets/Slides) cannot be fetched with
    `files().get_media()`, so they are exported to an Office/CSV format
    instead; all other files are downloaded as-is.

    Args:
        service: An authorised Google Drive API service object.
        file_id: The Drive ID of the file to download.
        file_name: The name to save the file under (an extension matching the
            MIME type is appended if missing).
        mime_type: The file's MIME type, used to choose export vs. direct download.
        download_folder: The base folder to save into.
        subfolder_path: Optional subfolder relative to ``download_folder``;
            created if it does not exist.

    Returns:
        A human-readable success or failure message containing the relative path.
    """
    file_extension = _get_file_extension(mime_type)
    if file_extension and (not file_name.lower().endswith(file_extension.lower())):
        file_name = f"{file_name}.{file_extension}"

    # Define export formats for Google Docs, Sheets, and Slides
    export_mime_types = {
        "application/vnd.google-apps.document": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",  # Google Docs → Word
        "application/vnd.google-apps.spreadsheet": "text/csv",  # Google Sheets → CSV
        "application/vnd.google-apps.presentation": "application/vnd.openxmlformats-officedocument.presentationml.presentation",  # Google Slides → PowerPoint
    }

    # Google Docs, Sheets, and Slides cannot be downloaded directly using service.files().get_media() because they are Google-native files
    if mime_type in export_mime_types:
        request = service.files().export(fileId=file_id, mimeType=export_mime_types[mime_type])
    else:
        # Download normal files (videos, images, etc.)
        request = service.files().get_media(fileId=file_id)

    # Determine the final destination directory
    destination_dir = download_folder
    if subfolder_path:
        destination_dir = download_folder / subfolder_path
        # Ensure the subfolder exists, create it if necessary
        destination_dir.mkdir(parents=True, exist_ok=True)

    # Construct the full path for the file
    file_path = destination_dir / file_name
    # Path reported back to the caller; computed once for both outcomes
    # (was previously duplicated in the success and failure branches).
    relative_path = Path(subfolder_path) / file_name if subfolder_path else Path(file_name)

    # Save file
    try:
        with io.BytesIO() as buffer:
            downloader = MediaIoBaseDownload(buffer, request)
            done = False
            while not done:
                _, done = downloader.next_chunk()

            buffer.seek(0)

            with open(file_path, "wb") as f:
                f.write(buffer.getvalue())

        return f"✅ Downloaded: {relative_path}"

    except Exception as e:
        # Deliberate best-effort: report the failure instead of raising, so a
        # batch of downloads can continue past one bad file.
        return f"❌ FAILED to download {relative_path}: {e}"
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block
|
||||
from .... import Toolkit, tool
|
||||
from ..model import GoogleFileInfo
|
||||
from ..toolkit_protocol import GoogleToolkitProtocol
|
||||
from .drive_functions import download_file, list_files_and_folders
|
||||
|
||||
with optional_import_block():
|
||||
from google.oauth2.credentials import Credentials
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
__all__ = [
|
||||
"GoogleDriveToolkit",
|
||||
]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental.google.drive")
class GoogleDriveToolkit(Toolkit, GoogleToolkitProtocol):
    """A tool map for Google Drive."""

    def __init__(  # type: ignore[no-any-unimported]
        self,
        *,
        credentials: "Credentials",
        download_folder: Union[Path, str],
        exclude: Optional[list[Literal["list_drive_files_and_folders", "download_file_from_drive"]]] = None,
        api_version: str = "v3",
    ) -> None:
        """Initialize the Google Drive tool map.

        Args:
            credentials: The Google OAuth2 credentials.
            download_folder: The folder to download files to (created if missing).
            exclude: The tool names to exclude.
            api_version: The Google Drive API version to use.
        """
        self.service = build(serviceName="drive", version=api_version, credentials=credentials)

        # Normalise to a Path and make sure the target directory exists up front.
        if isinstance(download_folder, str):
            download_folder = Path(download_folder)
        download_folder.mkdir(parents=True, exist_ok=True)

        @tool(description="List files and folders in a Google Drive")
        def list_drive_files_and_folders(
            page_size: Annotated[int, "The number of files to list per page."] = 10,
            folder_id: Annotated[
                Optional[str],
                "The ID of the folder to list files from. If not provided, lists all files in the root folder.",
            ] = None,
        ) -> list[GoogleFileInfo]:
            # Thin wrapper: the authorised service object is captured by closure.
            return list_files_and_folders(service=self.service, page_size=page_size, folder_id=folder_id)

        @tool(description="download a file from Google Drive")
        def download_file_from_drive(
            file_info: Annotated[GoogleFileInfo, "The file info to download."],
            subfolder_path: Annotated[
                Optional[str],
                "The subfolder path to save the file in. If not provided, saves in the main download folder.",
            ] = None,
        ) -> str:
            # `download_folder` is captured from the enclosing __init__ scope,
            # not stored on the instance.
            return download_file(
                service=self.service,
                file_id=file_info.id,
                file_name=file_info.name,
                mime_type=file_info.mime_type,
                download_folder=download_folder,
                subfolder_path=subfolder_path,
            )

        if exclude is None:
            exclude = []

        # NOTE(review): the comprehension variable shadows the imported `tool`
        # decorator; harmless here since both tools are already registered,
        # but a different name would be clearer.
        tools = [tool for tool in [list_drive_files_and_folders, download_file_from_drive] if tool.name not in exclude]
        super().__init__(tools=tools)

    @classmethod
    def recommended_scopes(cls) -> list[str]:
        """Return the recommended scopes mandatory for using tools from this tool map."""
        return [
            "https://www.googleapis.com/auth/drive.readonly",
        ]
|
||||
17
mm_agents/coact/autogen/tools/experimental/google/model.py
Normal file
17
mm_agents/coact/autogen/tools/experimental/google/model.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = [
|
||||
"GoogleFileInfo",
|
||||
]
|
||||
|
||||
|
||||
class GoogleFileInfo(BaseModel):
    """Metadata for a single Google Drive entry, as returned by the Drive API.

    Instantiated from the API's camelCase payload — note the ``mimeType``
    alias on ``mime_type``.
    """

    name: Annotated[str, Field(description="The name of the file.")]
    # Field name mirrors the Drive API payload key; shadows builtins.id only
    # inside this class namespace.
    id: Annotated[str, Field(description="The ID of the file.")]
    mime_type: Annotated[str, Field(alias="mimeType", description="The MIME type of the file.")]
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
__all__ = [
|
||||
"GoogleToolkitProtocol",
|
||||
]
|
||||
|
||||
|
||||
@runtime_checkable
class GoogleToolkitProtocol(Protocol):
    """A protocol for Google tool maps.

    Implementations must advertise the OAuth scopes their tools need so the
    credential provider can request them up front.
    """

    @classmethod
    def recommended_scopes(cls) -> list[str]:
        """Defines a required static method without implementation."""
        ...
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .google_search import GoogleSearchTool
|
||||
from .youtube_search import YoutubeSearchTool
|
||||
|
||||
__all__ = ["GoogleSearchTool", "YoutubeSearchTool"]
|
||||
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ... import Depends, Tool
|
||||
from ...dependency_injection import on
|
||||
|
||||
with optional_import_block():
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "googleapiclient",
    ],
    "google-search",
)
def _execute_query(query: str, search_api_key: str, search_engine_id: str, num_results: int) -> Any:
    """Run one Google Custom Search request and return the raw API response."""
    search_service = build("customsearch", "v1", developerKey=search_api_key)
    request = search_service.cse().list(q=query, cx=search_engine_id, num=num_results)
    return request.execute()
|
||||
|
||||
|
||||
def _google_search(
    query: str,
    search_api_key: str,
    search_engine_id: str,
    num_results: int,
) -> list[dict[str, Any]]:
    """Search Google Custom Search and format each hit as title/link/snippet."""
    response = _execute_query(
        query=query, search_api_key=search_api_key, search_engine_id=search_engine_id, num_results=num_results
    )

    formatted: list[dict[str, Any]] = []
    for hit in response.get("items", []):
        formatted.append({
            "title": hit.get("title", ""),
            "link": hit.get("link", ""),
            "snippet": hit.get("snippet", ""),
        })
    return formatted
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental")
class GoogleSearchTool(Tool):
    """GoogleSearchTool is a tool that uses the Google Search API to perform a search."""

    def __init__(
        self,
        *,
        search_api_key: Optional[str] = None,
        search_engine_id: Optional[str] = None,
        use_internal_llm_tool_if_available: bool = True,
    ):
        """GoogleSearchTool is a tool that uses the Google Search API to perform a search.

        Args:
            search_api_key: The API key for the Google Search API.
            search_engine_id: The search engine ID for the Google Search API.
            use_internal_llm_tool_if_available: Whether to use the predefined (e.g. Gemini GenAI) search tool. Currently, this can only be used for agents with the Gemini (GenAI) configuration.
        """
        self.search_api_key = search_api_key
        self.search_engine_id = search_engine_id
        self.use_internal_llm_tool_if_available = use_internal_llm_tool_if_available

        # Explicit keys are mandatory when the LLM's built-in search is disabled.
        if not use_internal_llm_tool_if_available and (search_api_key is None or search_engine_id is None):
            raise ValueError(
                "search_api_key and search_engine_id must be provided if use_internal_llm_tool_if_available is False"
            )

        # Keys passed alongside the internal tool would be silently unused — warn.
        if use_internal_llm_tool_if_available and (search_api_key is not None or search_engine_id is not None):
            logging.warning("search_api_key and search_engine_id will be ignored if internal LLM tool is available")

        def google_search(
            query: Annotated[str, "The search query."],
            # Keys are injected via dependency injection so the LLM never sees them.
            search_api_key: Annotated[Optional[str], Depends(on(search_api_key))],
            search_engine_id: Annotated[Optional[str], Depends(on(search_engine_id))],
            num_results: Annotated[int, "The number of results to return."] = 10,
        ) -> list[dict[str, Any]]:
            """Perform a Google Custom Search and return title/link/snippet results."""
            if search_api_key is None or search_engine_id is None:
                raise ValueError(
                    "Your LLM is not configured to use prebuilt google-search tool.\n"
                    "Please provide search_api_key and search_engine_id.\n"
                )
            return _google_search(query, search_api_key, search_engine_id, num_results)

        super().__init__(
            # GeminiClient will look for a tool with the name "prebuilt_google_search"
            name="prebuilt_google_search" if use_internal_llm_tool_if_available else "google_search",
            description="Use the Google Search API to perform a search.",
            func_or_tool=google_search,
        )
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any, Dict, List, Optional
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ... import Depends, Tool
|
||||
from ...dependency_injection import on
|
||||
|
||||
with optional_import_block():
|
||||
import googleapiclient.errors
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
||||
@require_optional_import(
    ["googleapiclient"],
    "google-search",
)
def _execute_search_query(query: str, youtube_api_key: str, max_results: int) -> Any:
    """Execute a YouTube search query using the YouTube Data API.

    Args:
        query: The search query string.
        youtube_api_key: The API key for the YouTube Data API.
        max_results: The maximum number of results to return.

    Returns:
        The search response from the YouTube Data API.

    Raises:
        googleapiclient.errors.HttpError: Logged and re-raised on API failure.
    """
    client = build("youtube", "v3", developerKey=youtube_api_key)

    try:
        request = client.search().list(q=query, part="id,snippet", maxResults=max_results, type="video")
        return request.execute()
    except googleapiclient.errors.HttpError as e:
        logging.error(f"An HTTP error occurred: {e}")
        raise
|
||||
|
||||
|
||||
@require_optional_import(
    ["googleapiclient"],
    "google-search",
)
def _get_video_details(video_ids: List[str], youtube_api_key: str) -> Any:
    """Get detailed information about specific YouTube videos.

    Args:
        video_ids: List of YouTube video IDs.
        youtube_api_key: The API key for the YouTube Data API.

    Returns:
        The video details response from the YouTube Data API.

    Raises:
        googleapiclient.errors.HttpError: Logged and re-raised on API failure.
    """
    # Nothing to fetch — mirror the API's empty-response shape.
    if not video_ids:
        return {"items": []}

    client = build("youtube", "v3", developerKey=youtube_api_key)

    try:
        request = client.videos().list(id=",".join(video_ids), part="snippet,contentDetails,statistics")
        return request.execute()
    except googleapiclient.errors.HttpError as e:
        logging.error(f"An HTTP error occurred: {e}")
        raise
|
||||
|
||||
|
||||
def _youtube_search(
    query: str,
    youtube_api_key: str,
    max_results: int,
    include_video_details: bool = True,
) -> List[Dict[str, Any]]:
    """Search YouTube videos based on a query.

    Args:
        query: The search query string.
        youtube_api_key: The API key for the YouTube Data API.
        max_results: The maximum number of results to return.
        include_video_details: Whether to enrich each hit with statistics
            (views/likes/comments) and content details (duration/definition).

    Returns:
        A list of dictionaries containing information about the videos.
    """
    search_response = _execute_search_query(query=query, youtube_api_key=youtube_api_key, max_results=max_results)

    videos: List[Dict[str, Any]] = []
    found_ids: List[str] = []

    # Keep only video hits — the search endpoint can also return channels/playlists.
    for entry in search_response.get("items", []):
        if entry["id"]["kind"] != "youtube#video":
            continue
        vid = entry["id"]["videoId"]
        snippet = entry["snippet"]
        found_ids.append(vid)
        videos.append({
            "title": snippet["title"],
            "description": snippet["description"],
            "publishedAt": snippet["publishedAt"],
            "channelTitle": snippet["channelTitle"],
            "videoId": vid,
            "url": f"https://www.youtube.com/watch?v={vid}",
        })

    # Optionally enrich each hit with statistics and content details.
    if include_video_details and found_ids:
        details_response = _get_video_details(found_ids, youtube_api_key)
        details_by_id = {item["id"]: item for item in details_response.get("items", [])}

        for video in videos:
            details = details_by_id.get(video["videoId"])
            if details is None:
                continue
            stats = details["statistics"]
            content = details["contentDetails"]
            video.update({
                "viewCount": stats.get("viewCount"),
                "likeCount": stats.get("likeCount"),
                "commentCount": stats.get("commentCount"),
                "duration": content.get("duration"),
                "definition": content.get("definition"),
            })

    return videos
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental")
class YoutubeSearchTool(Tool):
    """YoutubeSearchTool is a tool that uses the YouTube Data API to search for videos."""

    def __init__(
        self,
        *,
        youtube_api_key: Optional[str] = None,
    ):
        """Initialize a YouTube search tool.

        Args:
            youtube_api_key: The API key for the YouTube Data API.

        Raises:
            ValueError: If ``youtube_api_key`` is not provided.
        """
        # Validate before mutating state so a failed construction leaves no
        # partially initialized instance behind (previously the attribute was
        # assigned before this check).
        if youtube_api_key is None:
            raise ValueError("youtube_api_key must be provided")

        self.youtube_api_key = youtube_api_key

        def youtube_search(
            query: Annotated[str, "The search query for YouTube videos."],
            # Injected via dependency injection; validated non-None above.
            youtube_api_key: Annotated[str, Depends(on(youtube_api_key))],
            max_results: Annotated[int, "The maximum number of results to return."] = 5,
            include_video_details: Annotated[bool, "Whether to include detailed video information."] = True,
        ) -> List[Dict[str, Any]]:
            """Search for YouTube videos based on a query.

            Args:
                query: The search query string.
                youtube_api_key: The API key for the YouTube Data API.
                max_results: The maximum number of results to return.
                include_video_details: Whether to include detailed video information.

            Returns:
                A list of dictionaries containing information about the videos.
            """
            # Defensive re-check in case the injected value is ever None.
            if youtube_api_key is None:
                raise ValueError("YouTube API key is required")

            return _youtube_search(query, youtube_api_key, max_results, include_video_details)

        super().__init__(
            name="youtube_search",
            description="Search for YouTube videos based on a query, optionally including detailed information.",
            func_or_tool=youtube_search,
        )
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .discord import DiscordRetrieveTool, DiscordSendTool
|
||||
from .slack import SlackRetrieveRepliesTool, SlackRetrieveTool, SlackSendTool
|
||||
from .telegram import TelegramRetrieveTool, TelegramSendTool
|
||||
|
||||
__all__ = [
|
||||
"DiscordRetrieveTool",
|
||||
"DiscordSendTool",
|
||||
"SlackRetrieveRepliesTool",
|
||||
"SlackRetrieveTool",
|
||||
"SlackSendTool",
|
||||
"TelegramRetrieveTool",
|
||||
"TelegramSendTool",
|
||||
]
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .discord import DiscordRetrieveTool, DiscordSendTool
|
||||
|
||||
__all__ = ["DiscordRetrieveTool", "DiscordSendTool"]
|
||||
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Union
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from .... import Tool
|
||||
from ....dependency_injection import Depends, on
|
||||
|
||||
__all__ = ["DiscordRetrieveTool", "DiscordSendTool"]
|
||||
|
||||
with optional_import_block():
|
||||
from discord import Client, Intents, utils
|
||||
|
||||
# Discord caps messages at 2000 characters; longer messages are sent in chunks.
MAX_MESSAGE_LENGTH = 2000
MAX_BATCH_RETRIEVE_MESSAGES = 100  # Discord's max per request
|
||||
|
||||
|
||||
@require_optional_import(["discord"], "commsagent-discord")
@export_module("autogen.tools.experimental")
class DiscordSendTool(Tool):
    """Sends a message to a Discord channel."""

    def __init__(self, *, bot_token: str, channel_name: str, guild_name: str) -> None:
        """
        Initialize the DiscordSendTool.

        Args:
            bot_token: The bot token to use for sending messages.
            channel_name: The name of the channel to send messages to.
            guild_name: The name of the guild for the channel.
        """

        # Function that sends the message, uses dependency injection for bot token / channel / guild
        async def discord_send_message(
            message: Annotated[str, "Message to send to the channel."],
            bot_token: Annotated[str, Depends(on(bot_token))],
            guild_name: Annotated[str, Depends(on(guild_name))],
            channel_name: Annotated[str, Depends(on(channel_name))],
        ) -> Any:
            """
            Sends a message to a Discord channel.

            Args:
                message: The message to send to the channel.
                bot_token: The bot token to use for Discord. (uses dependency injection)
                guild_name: The name of the server. (uses dependency injection)
                channel_name: The name of the channel. (uses dependency injection)
            """
            intents = Intents.default()
            intents.message_content = True
            intents.guilds = True
            intents.guild_messages = True

            client = Client(intents=intents)
            result_future: asyncio.Future[str] = asyncio.Future()  # Stores the result of the send

            # When the client is ready, we'll send the message
            @client.event  # type: ignore[misc]
            async def on_ready() -> None:
                try:
                    # Server
                    guild = utils.get(client.guilds, name=guild_name)
                    if guild:
                        # Channel
                        channel = utils.get(guild.text_channels, name=channel_name)
                        if channel:
                            # Send the message
                            if len(message) > MAX_MESSAGE_LENGTH:
                                # Chunk size is MAX_MESSAGE_LENGTH - 1 — presumably
                                # a safety margin below Discord's hard limit; TODO confirm.
                                chunks = [
                                    message[i : i + (MAX_MESSAGE_LENGTH - 1)]
                                    for i in range(0, len(message), (MAX_MESSAGE_LENGTH - 1))
                                ]
                                for i, chunk in enumerate(chunks):
                                    sent = await channel.send(chunk)

                                    # Store ID for the first chunk
                                    if i == 0:
                                        sent_message_id = str(sent.id)

                                result_future.set_result(
                                    f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
                                )
                            else:
                                sent = await channel.send(message)
                                result_future.set_result(f"Message sent successfully (ID: {sent.id}):\n{message}")
                        else:
                            # Lookup failures are reported as results, not exceptions,
                            # so the agent gets a readable message back.
                            result_future.set_result(f"Message send failed, could not find channel: {channel_name}")
                    else:
                        result_future.set_result(f"Message send failed, could not find guild: {guild_name}")

                except Exception as e:
                    # Surface the error to the awaiting caller via the future.
                    result_future.set_exception(e)
                finally:
                    # Always tear down the client so the event loop can exit.
                    try:
                        await client.close()
                    except Exception as e:
                        raise Exception(f"Unable to close Discord client: {e}")

            # Start the client and when it's ready it'll send the message in on_ready
            try:
                await client.start(bot_token)

                # Capture the result of the send
                return await result_future
            except Exception as e:
                raise Exception(f"Failed to start Discord client: {e}")

        super().__init__(
            name="discord_send",
            description="Sends a message to a Discord channel.",
            func_or_tool=discord_send_message,
        )
|
||||
|
||||
|
||||
@require_optional_import(["discord"], "commsagent-discord")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class DiscordRetrieveTool(Tool):
|
||||
"""Retrieves messages from a Discord channel."""
|
||||
|
||||
def __init__(self, *, bot_token: str, channel_name: str, guild_name: str) -> None:
|
||||
"""
|
||||
Initialize the DiscordRetrieveTool.
|
||||
|
||||
Args:
|
||||
bot_token: The bot token to use for retrieving messages.
|
||||
channel_name: The name of the channel to retrieve messages from.
|
||||
guild_name: The name of the guild for the channel.
|
||||
"""
|
||||
|
||||
async def discord_retrieve_messages(
|
||||
bot_token: Annotated[str, Depends(on(bot_token))],
|
||||
guild_name: Annotated[str, Depends(on(guild_name))],
|
||||
channel_name: Annotated[str, Depends(on(channel_name))],
|
||||
messages_since: Annotated[
|
||||
Union[str, None],
|
||||
"Date to retrieve messages from (ISO format) OR Discord snowflake ID. If None, retrieves latest messages.",
|
||||
] = None,
|
||||
maximum_messages: Annotated[
|
||||
Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
|
||||
] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieves messages from a Discord channel.
|
||||
|
||||
Args:
|
||||
bot_token: The bot token to use for Discord. (uses dependency injection)
|
||||
guild_name: The name of the server. (uses dependency injection)
|
||||
channel_name: The name of the channel. (uses dependency injection)
|
||||
messages_since: ISO format date string OR Discord snowflake ID, to retrieve messages from. If None, retrieves latest messages.
|
||||
maximum_messages: Maximum number of messages to retrieve. If None, retrieves all messages since date.
|
||||
"""
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
intents.guilds = True
|
||||
intents.guild_messages = True
|
||||
|
||||
client = Client(intents=intents)
|
||||
result_future: asyncio.Future[list[dict[str, Any]]] = asyncio.Future()
|
||||
|
||||
messages_since_date: Union[str, None] = None
|
||||
if messages_since is not None:
|
||||
if DiscordRetrieveTool._is_snowflake(messages_since):
|
||||
messages_since_date = DiscordRetrieveTool._snowflake_to_iso(messages_since)
|
||||
else:
|
||||
messages_since_date = messages_since
|
||||
|
||||
@client.event # type: ignore[misc]
|
||||
async def on_ready() -> None:
|
||||
try:
|
||||
messages = []
|
||||
|
||||
# Get guild and channel
|
||||
guild = utils.get(client.guilds, name=guild_name)
|
||||
if not guild:
|
||||
result_future.set_result([{"error": f"Could not find guild: {guild_name}"}])
|
||||
return
|
||||
|
||||
channel = utils.get(guild.text_channels, name=channel_name)
|
||||
if not channel:
|
||||
result_future.set_result([{"error": f"Could not find channel: {channel_name}"}])
|
||||
return
|
||||
|
||||
# Setup retrieval parameters
|
||||
last_message_id = None
|
||||
messages_retrieved = 0
|
||||
|
||||
# Convert to ISO format
|
||||
after_date = None
|
||||
if messages_since_date:
|
||||
try:
|
||||
from datetime import datetime
|
||||
|
||||
after_date = datetime.fromisoformat(messages_since_date)
|
||||
except ValueError:
|
||||
result_future.set_result([
|
||||
{"error": f"Invalid date format: {messages_since_date}. Use ISO format."}
|
||||
])
|
||||
return
|
||||
|
||||
while True:
|
||||
# Setup fetch options
|
||||
fetch_options = {
|
||||
"limit": MAX_BATCH_RETRIEVE_MESSAGES,
|
||||
"before": last_message_id if last_message_id else None,
|
||||
"after": after_date if after_date else None,
|
||||
}
|
||||
|
||||
# Fetch batch of messages
|
||||
message_batch = []
|
||||
async for message in channel.history(**fetch_options): # type: ignore[arg-type]
|
||||
message_batch.append(message)
|
||||
messages_retrieved += 1
|
||||
|
||||
# Check if we've reached the maximum
|
||||
if maximum_messages and messages_retrieved >= maximum_messages:
|
||||
break
|
||||
|
||||
if not message_batch:
|
||||
break
|
||||
|
||||
# Process messages
|
||||
for msg in message_batch:
|
||||
messages.append({
|
||||
"id": str(msg.id),
|
||||
"content": msg.content,
|
||||
"author": str(msg.author),
|
||||
"timestamp": msg.created_at.isoformat(),
|
||||
})
|
||||
|
||||
# Update last message ID for pagination
|
||||
last_message_id = message_batch[-1] # Use message object directly as 'before' parameter
|
||||
|
||||
# Break if we've reached the maximum
|
||||
if maximum_messages and messages_retrieved >= maximum_messages:
|
||||
break
|
||||
|
||||
result_future.set_result(messages)
|
||||
|
||||
except Exception as e:
|
||||
result_future.set_exception(e)
|
||||
finally:
|
||||
try:
|
||||
await client.close()
|
||||
except Exception as e:
|
||||
raise Exception(f"Unable to close Discord client: {e}")
|
||||
|
||||
try:
|
||||
await client.start(bot_token)
|
||||
return await result_future
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to start Discord client: {e}")
|
||||
|
||||
super().__init__(
|
||||
name="discord_retrieve",
|
||||
description="Retrieves messages from a Discord channel based datetime/message ID and/or number of latest messages.",
|
||||
func_or_tool=discord_retrieve_messages,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_snowflake(value: str) -> bool:
|
||||
"""Check if a string is a valid Discord snowflake ID."""
|
||||
# Must be numeric and 17-20 digits
|
||||
if not value.isdigit():
|
||||
return False
|
||||
|
||||
digit_count = len(value)
|
||||
return 17 <= digit_count <= 20
|
||||
|
||||
@staticmethod
|
||||
def _snowflake_to_iso(snowflake: str) -> str:
|
||||
"""Convert a Discord snowflake ID to ISO timestamp string."""
|
||||
if not DiscordRetrieveTool._is_snowflake(snowflake):
|
||||
raise ValueError(f"Invalid snowflake ID: {snowflake}")
|
||||
|
||||
# Discord epoch (2015-01-01)
|
||||
discord_epoch = 1420070400000
|
||||
|
||||
# Convert ID to int and shift right 22 bits to get timestamp
|
||||
timestamp_ms = (int(snowflake) >> 22) + discord_epoch
|
||||
|
||||
# Convert to datetime and format as ISO string
|
||||
dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc)
|
||||
return dt.isoformat()
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .slack import SlackRetrieveRepliesTool, SlackRetrieveTool, SlackSendTool
|
||||
|
||||
__all__ = ["SlackRetrieveRepliesTool", "SlackRetrieveTool", "SlackSendTool"]
|
||||
@@ -0,0 +1,391 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, Optional, Tuple, Union
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from .... import Tool
|
||||
from ....dependency_injection import Depends, on
|
||||
|
||||
__all__ = ["SlackSendTool"]
|
||||
|
||||
with optional_import_block():
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
MAX_MESSAGE_LENGTH = 40000
|
||||
|
||||
|
||||
@require_optional_import(["slack_sdk"], "commsagent-slack")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class SlackSendTool(Tool):
|
||||
"""Sends a message to a Slack channel."""
|
||||
|
||||
def __init__(self, *, bot_token: str, channel_id: str) -> None:
|
||||
"""
|
||||
Initialize the SlackSendTool.
|
||||
|
||||
Args:
|
||||
bot_token: Bot User OAuth Token starting with "xoxb-".
|
||||
channel_id: Channel ID where messages will be sent.
|
||||
"""
|
||||
|
||||
# Function that sends the message, uses dependency injection for bot token / channel / guild
|
||||
async def slack_send_message(
|
||||
message: Annotated[str, "Message to send to the channel."],
|
||||
bot_token: Annotated[str, Depends(on(bot_token))],
|
||||
channel_id: Annotated[str, Depends(on(channel_id))],
|
||||
) -> Any:
|
||||
"""
|
||||
Sends a message to a Slack channel.
|
||||
|
||||
Args:
|
||||
message: The message to send to the channel.
|
||||
bot_token: The bot token to use for Slack. (uses dependency injection)
|
||||
channel_id: The ID of the channel. (uses dependency injection)
|
||||
"""
|
||||
try:
|
||||
web_client = WebClient(token=bot_token)
|
||||
|
||||
# Send the message
|
||||
if len(message) > MAX_MESSAGE_LENGTH:
|
||||
chunks = [
|
||||
message[i : i + (MAX_MESSAGE_LENGTH - 1)]
|
||||
for i in range(0, len(message), (MAX_MESSAGE_LENGTH - 1))
|
||||
]
|
||||
for i, chunk in enumerate(chunks):
|
||||
response = web_client.chat_postMessage(channel=channel_id, text=chunk)
|
||||
|
||||
if not response["ok"]:
|
||||
return f"Message send failed on chunk {i + 1}, Slack response error: {response['error']}"
|
||||
|
||||
# Store ID for the first chunk
|
||||
if i == 0:
|
||||
sent_message_id = response["ts"]
|
||||
|
||||
return f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
|
||||
else:
|
||||
response = web_client.chat_postMessage(channel=channel_id, text=message)
|
||||
|
||||
if not response["ok"]:
|
||||
return f"Message send failed, Slack response error: {response['error']}"
|
||||
|
||||
return f"Message sent successfully (ID: {response['ts']}):\n{message}"
|
||||
except SlackApiError as e:
|
||||
return f"Message send failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
|
||||
except Exception as e:
|
||||
return f"Message send failed, exception: {e}"
|
||||
|
||||
super().__init__(
|
||||
name="slack_send",
|
||||
description="Sends a message to a Slack channel.",
|
||||
func_or_tool=slack_send_message,
|
||||
)
|
||||
|
||||
|
||||
@require_optional_import(["slack_sdk"], "commsagent-slack")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class SlackRetrieveTool(Tool):
|
||||
"""Retrieves messages from a Slack channel."""
|
||||
|
||||
def __init__(self, *, bot_token: str, channel_id: str) -> None:
|
||||
"""
|
||||
Initialize the SlackRetrieveTool.
|
||||
|
||||
Args:
|
||||
bot_token: Bot User OAuth Token starting with "xoxb-".
|
||||
channel_id: Channel ID where messages will be sent.
|
||||
"""
|
||||
|
||||
async def slack_retrieve_messages(
|
||||
bot_token: Annotated[str, Depends(on(bot_token))],
|
||||
channel_id: Annotated[str, Depends(on(channel_id))],
|
||||
messages_since: Annotated[
|
||||
Union[str, None],
|
||||
"Date to retrieve messages from (ISO format) OR Slack message ID. If None, retrieves latest messages.",
|
||||
] = None,
|
||||
maximum_messages: Annotated[
|
||||
Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
|
||||
] = None,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieves messages from a Discord channel.
|
||||
|
||||
Args:
|
||||
bot_token: The bot token to use for Discord. (uses dependency injection)
|
||||
channel_id: The ID of the channel. (uses dependency injection)
|
||||
messages_since: ISO format date string OR Slack message ID, to retrieve messages from. If None, retrieves latest messages.
|
||||
maximum_messages: Maximum number of messages to retrieve. If None, retrieves all messages since date.
|
||||
"""
|
||||
try:
|
||||
web_client = WebClient(token=bot_token)
|
||||
|
||||
# Convert ISO datetime to Unix timestamp if needed
|
||||
oldest = None
|
||||
if messages_since:
|
||||
if "." in messages_since: # Likely a Slack message ID
|
||||
oldest = messages_since
|
||||
else: # Assume ISO format
|
||||
try:
|
||||
dt = datetime.fromisoformat(messages_since.replace("Z", "+00:00"))
|
||||
oldest = str(dt.timestamp())
|
||||
except ValueError as e:
|
||||
return f"Invalid date format. Please provide either a Slack message ID or ISO format date (e.g., '2025-01-25T00:00:00Z'). Error: {e}"
|
||||
|
||||
messages = []
|
||||
cursor = None
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Prepare API call parameters
|
||||
params = {
|
||||
"channel": channel_id,
|
||||
"limit": min(1000, maximum_messages) if maximum_messages else 1000,
|
||||
}
|
||||
if oldest:
|
||||
params["oldest"] = oldest
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
# Make API call
|
||||
response = web_client.conversations_history(**params) # type: ignore[arg-type]
|
||||
|
||||
if not response["ok"]:
|
||||
return f"Message retrieval failed, Slack response error: {response['error']}"
|
||||
|
||||
# Add messages to our list
|
||||
messages.extend(response["messages"])
|
||||
|
||||
# Check if we've hit our maximum
|
||||
if maximum_messages and len(messages) >= maximum_messages:
|
||||
messages = messages[:maximum_messages]
|
||||
break
|
||||
|
||||
# Check if there are more messages
|
||||
if not response["has_more"]:
|
||||
break
|
||||
|
||||
cursor = response["response_metadata"]["next_cursor"]
|
||||
|
||||
except SlackApiError as e:
|
||||
return f"Message retrieval failed on pagination, Slack API error: {e.response['error']}"
|
||||
|
||||
return {
|
||||
"message_count": len(messages),
|
||||
"messages": messages,
|
||||
"start_time": oldest or "latest",
|
||||
}
|
||||
|
||||
except SlackApiError as e:
|
||||
return f"Message retrieval failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
|
||||
except Exception as e:
|
||||
return f"Message retrieval failed, exception: {e}"
|
||||
|
||||
super().__init__(
|
||||
name="slack_retrieve",
|
||||
description="Retrieves messages from a Slack channel based datetime/message ID and/or number of latest messages.",
|
||||
func_or_tool=slack_retrieve_messages,
|
||||
)
|
||||
|
||||
|
||||
@require_optional_import(["slack_sdk"], "commsagent-slack")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class SlackRetrieveRepliesTool(Tool):
|
||||
"""Retrieves replies to a specific Slack message from both threads and the channel."""
|
||||
|
||||
def __init__(self, *, bot_token: str, channel_id: str) -> None:
|
||||
"""
|
||||
Initialize the SlackRetrieveRepliesTool.
|
||||
|
||||
Args:
|
||||
bot_token: Bot User OAuth Token starting with "xoxb-".
|
||||
channel_id: Channel ID where the parent message exists.
|
||||
"""
|
||||
|
||||
async def slack_retrieve_replies(
|
||||
message_ts: Annotated[str, "Timestamp (ts) of the parent message to retrieve replies for."],
|
||||
bot_token: Annotated[str, Depends(on(bot_token))],
|
||||
channel_id: Annotated[str, Depends(on(channel_id))],
|
||||
min_replies: Annotated[
|
||||
Optional[int],
|
||||
"Minimum number of replies to wait for before returning (thread + channel). If None, returns immediately.",
|
||||
] = None,
|
||||
timeout_seconds: Annotated[
|
||||
int, "Maximum time in seconds to wait for the requested number of replies."
|
||||
] = 60,
|
||||
poll_interval: Annotated[int, "Time in seconds between polling attempts when waiting for replies."] = 5,
|
||||
include_channel_messages: Annotated[
|
||||
bool, "Whether to include messages in the channel after the original message."
|
||||
] = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Retrieves replies to a specific Slack message, from both threads and the main channel.
|
||||
|
||||
Args:
|
||||
message_ts: The timestamp (ts) identifier of the parent message.
|
||||
bot_token: The bot token to use for Slack. (uses dependency injection)
|
||||
channel_id: The ID of the channel. (uses dependency injection)
|
||||
min_replies: Minimum number of combined replies to wait for before returning. If None, returns immediately.
|
||||
timeout_seconds: Maximum time in seconds to wait for the requested number of replies.
|
||||
poll_interval: Time in seconds between polling attempts when waiting for replies.
|
||||
include_channel_messages: Whether to include messages posted in the channel after the original message.
|
||||
"""
|
||||
try:
|
||||
web_client = WebClient(token=bot_token)
|
||||
|
||||
# Function to get current thread replies
|
||||
async def get_thread_replies() -> tuple[Optional[list[dict[str, Any]]], Optional[str]]:
|
||||
try:
|
||||
response = web_client.conversations_replies(
|
||||
channel=channel_id,
|
||||
ts=message_ts,
|
||||
)
|
||||
|
||||
if not response["ok"]:
|
||||
return None, f"Thread reply retrieval failed, Slack response error: {response['error']}"
|
||||
|
||||
# The first message is the parent message itself, so exclude it when counting replies
|
||||
replies = response["messages"][1:] if len(response["messages"]) > 0 else []
|
||||
return replies, None
|
||||
|
||||
except SlackApiError as e:
|
||||
return None, f"Thread reply retrieval failed, Slack API exception: {e.response['error']}"
|
||||
except Exception as e:
|
||||
return None, f"Thread reply retrieval failed, exception: {e}"
|
||||
|
||||
# Function to get messages in the channel after the original message
|
||||
async def get_channel_messages() -> Tuple[Optional[list[dict[str, Any]]], Optional[str]]:
|
||||
try:
|
||||
response = web_client.conversations_history(
|
||||
channel=channel_id,
|
||||
oldest=message_ts, # Start from the original message timestamp
|
||||
inclusive=False, # Don't include the original message
|
||||
)
|
||||
|
||||
if not response["ok"]:
|
||||
return None, f"Channel message retrieval failed, Slack response error: {response['error']}"
|
||||
|
||||
# Return all messages in the channel after the original message
|
||||
# We need to filter out any that are part of the thread we're already getting
|
||||
messages = []
|
||||
for msg in response["messages"]:
|
||||
# Skip if the message is part of the thread we're already retrieving
|
||||
if "thread_ts" in msg and msg["thread_ts"] == message_ts:
|
||||
continue
|
||||
messages.append(msg)
|
||||
|
||||
return messages, None
|
||||
|
||||
except SlackApiError as e:
|
||||
return None, f"Channel message retrieval failed, Slack API exception: {e.response['error']}"
|
||||
except Exception as e:
|
||||
return None, f"Channel message retrieval failed, exception: {e}"
|
||||
|
||||
# Function to get all replies (both thread and channel)
|
||||
async def get_all_replies() -> Tuple[
|
||||
Optional[list[dict[str, Any]]], Optional[list[dict[str, Any]]], Optional[str]
|
||||
]:
|
||||
thread_replies, thread_error = await get_thread_replies()
|
||||
if thread_error:
|
||||
return None, None, thread_error
|
||||
|
||||
channel_messages: list[dict[str, Any]] = []
|
||||
channel_error = None
|
||||
|
||||
if include_channel_messages:
|
||||
channel_results, channel_error = await get_channel_messages()
|
||||
if channel_error:
|
||||
return thread_replies, None, channel_error
|
||||
channel_messages = channel_results if channel_results is not None else []
|
||||
|
||||
return thread_replies, channel_messages, None
|
||||
|
||||
# If no waiting is required, just get replies and return
|
||||
if min_replies is None:
|
||||
thread_replies, channel_messages, error = await get_all_replies()
|
||||
if error:
|
||||
return error
|
||||
|
||||
thread_replies_list: list[dict[str, Any]] = [] if thread_replies is None else thread_replies
|
||||
channel_messages_list: list[dict[str, Any]] = [] if channel_messages is None else channel_messages
|
||||
|
||||
# Combine replies for counting but keep them separate in the result
|
||||
total_reply_count = len(thread_replies_list) + len(channel_messages_list)
|
||||
|
||||
return {
|
||||
"parent_message_ts": message_ts,
|
||||
"total_reply_count": total_reply_count,
|
||||
"thread_replies": thread_replies_list,
|
||||
"thread_reply_count": len(thread_replies_list),
|
||||
"channel_messages": channel_messages_list if include_channel_messages else None,
|
||||
"channel_message_count": len(channel_messages_list) if include_channel_messages else None,
|
||||
}
|
||||
|
||||
# Wait for the required number of replies with timeout
|
||||
start_time = datetime.now()
|
||||
end_time = start_time + timedelta(seconds=timeout_seconds)
|
||||
|
||||
while datetime.now() < end_time:
|
||||
thread_replies, channel_messages, error = await get_all_replies()
|
||||
if error:
|
||||
return error
|
||||
|
||||
thread_replies_current: list[dict[str, Any]] = [] if thread_replies is None else thread_replies
|
||||
channel_messages_current: list[dict[str, Any]] = (
|
||||
[] if channel_messages is None else channel_messages
|
||||
)
|
||||
|
||||
# Combine replies for counting
|
||||
total_reply_count = len(thread_replies_current) + len(channel_messages_current)
|
||||
|
||||
# If we have enough total replies, return them
|
||||
if total_reply_count >= min_replies:
|
||||
return {
|
||||
"parent_message_ts": message_ts,
|
||||
"total_reply_count": total_reply_count,
|
||||
"thread_replies": thread_replies_current,
|
||||
"thread_reply_count": len(thread_replies_current),
|
||||
"channel_messages": channel_messages_current if include_channel_messages else None,
|
||||
"channel_message_count": len(channel_messages_current)
|
||||
if include_channel_messages
|
||||
else None,
|
||||
"waited_seconds": (datetime.now() - start_time).total_seconds(),
|
||||
}
|
||||
|
||||
# Wait before checking again
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
# If we reach here, we timed out waiting for replies
|
||||
thread_replies, channel_messages, error = await get_all_replies()
|
||||
if error:
|
||||
return error
|
||||
|
||||
# Combine replies for counting
|
||||
total_reply_count = len(thread_replies or []) + len(channel_messages or [])
|
||||
|
||||
return {
|
||||
"parent_message_ts": message_ts,
|
||||
"total_reply_count": total_reply_count,
|
||||
"thread_replies": thread_replies or [],
|
||||
"thread_reply_count": len(thread_replies or []),
|
||||
"channel_messages": channel_messages or [] if include_channel_messages else None,
|
||||
"channel_message_count": len(channel_messages or []) if include_channel_messages else None,
|
||||
"timed_out": True,
|
||||
"waited_seconds": timeout_seconds,
|
||||
"requested_replies": min_replies,
|
||||
}
|
||||
|
||||
except SlackApiError as e:
|
||||
return f"Reply retrieval failed, Slack API exception: {e.response['error']} (See https://api.slack.com/automation/cli/errors#{e.response['error']})"
|
||||
except Exception as e:
|
||||
return f"Reply retrieval failed, exception: {e}"
|
||||
|
||||
super().__init__(
|
||||
name="slack_retrieve_replies",
|
||||
description="Retrieves replies to a specific Slack message, checking both thread replies and messages in the channel after the original message.",
|
||||
func_or_tool=slack_retrieve_replies,
|
||||
)
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .telegram import TelegramRetrieveTool, TelegramSendTool
|
||||
|
||||
__all__ = ["TelegramRetrieveTool", "TelegramSendTool"]
|
||||
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Union
|
||||
|
||||
from .....doc_utils import export_module
|
||||
from .....import_utils import optional_import_block, require_optional_import
|
||||
from .... import Tool
|
||||
from ....dependency_injection import Depends, on
|
||||
|
||||
__all__ = ["TelegramRetrieveTool", "TelegramSendTool"]
|
||||
|
||||
with optional_import_block():
|
||||
from telethon import TelegramClient
|
||||
from telethon.tl.types import InputMessagesFilterEmpty, Message, PeerChannel, PeerChat, PeerUser
|
||||
|
||||
MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
|
||||
@require_optional_import(["telethon", "telethon.tl.types"], "commsagent-telegram")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class BaseTelegramTool:
|
||||
"""Base class for Telegram tools containing shared functionality."""
|
||||
|
||||
def __init__(self, api_id: str, api_hash: str, session_name: str) -> None:
|
||||
self._api_id = api_id
|
||||
self._api_hash = api_hash
|
||||
self._session_name = session_name
|
||||
|
||||
def _get_client(self) -> "TelegramClient": # type: ignore[no-any-unimported]
|
||||
"""Get a fresh TelegramClient instance."""
|
||||
return TelegramClient(self._session_name, self._api_id, self._api_hash)
|
||||
|
||||
@staticmethod
|
||||
def _get_peer_from_id(chat_id: str) -> Union["PeerChat", "PeerChannel", "PeerUser"]: # type: ignore[no-any-unimported]
|
||||
"""Convert a chat ID string to appropriate Peer type."""
|
||||
try:
|
||||
# Convert string to integer
|
||||
id_int = int(chat_id)
|
||||
|
||||
# Channel/Supergroup: -100 prefix
|
||||
if str(chat_id).startswith("-100"):
|
||||
channel_id = int(str(chat_id)[4:]) # Remove -100 prefix
|
||||
return PeerChannel(channel_id)
|
||||
|
||||
# Group: negative number without -100 prefix
|
||||
elif id_int < 0:
|
||||
group_id = -id_int # Remove the negative sign
|
||||
return PeerChat(group_id)
|
||||
|
||||
# User/Bot: positive number
|
||||
else:
|
||||
return PeerUser(id_int)
|
||||
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid chat_id format: {chat_id}. Error: {str(e)}")
|
||||
|
||||
async def _initialize_entity(self, client: "TelegramClient", chat_id: str) -> Any: # type: ignore[no-any-unimported]
|
||||
"""Initialize and cache the entity by trying different methods."""
|
||||
peer = self._get_peer_from_id(chat_id)
|
||||
|
||||
try:
|
||||
# Try direct entity resolution first
|
||||
entity = await client.get_entity(peer)
|
||||
return entity
|
||||
except ValueError:
|
||||
try:
|
||||
# Get all dialogs (conversations)
|
||||
async for dialog in client.iter_dialogs():
|
||||
# For users/bots, we need to find the dialog with the user
|
||||
if (
|
||||
isinstance(peer, PeerUser)
|
||||
and dialog.entity.id == peer.user_id
|
||||
or dialog.entity.id == getattr(peer, "channel_id", getattr(peer, "chat_id", None))
|
||||
):
|
||||
return dialog.entity
|
||||
|
||||
# If we get here, we didn't find the entity in dialogs
|
||||
raise ValueError(f"Could not find entity {chat_id} in dialogs")
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Could not initialize entity for {chat_id}. "
|
||||
f"Make sure you have access to this chat. Error: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@require_optional_import(["telethon"], "commsagent-telegram")
|
||||
@export_module("autogen.tools.experimental")
|
||||
class TelegramSendTool(BaseTelegramTool, Tool):
|
||||
"""Sends a message to a Telegram channel, group, or user."""
|
||||
|
||||
def __init__(self, *, api_id: str, api_hash: str, chat_id: str) -> None:
|
||||
"""
|
||||
Initialize the TelegramSendTool.
|
||||
|
||||
Args:
|
||||
api_id: Telegram API ID from https://my.telegram.org/apps.
|
||||
api_hash: Telegram API hash from https://my.telegram.org/apps.
|
||||
chat_id: The ID of the destination (Channel, Group, or User ID).
|
||||
"""
|
||||
BaseTelegramTool.__init__(self, api_id, api_hash, "telegram_send_session")
|
||||
|
||||
async def telegram_send_message(
|
||||
message: Annotated[str, "Message to send to the chat."],
|
||||
chat_id: Annotated[str, Depends(on(chat_id))],
|
||||
) -> Any:
|
||||
"""
|
||||
Sends a message to a Telegram chat.
|
||||
|
||||
Args:
|
||||
message: The message to send.
|
||||
chat_id: The ID of the destination. (uses dependency injection)
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
async with client:
|
||||
# Initialize and cache the entity
|
||||
entity = await self._initialize_entity(client, chat_id)
|
||||
|
||||
if len(message) > MAX_MESSAGE_LENGTH:
|
||||
chunks = [
|
||||
message[i : i + (MAX_MESSAGE_LENGTH - 1)]
|
||||
for i in range(0, len(message), (MAX_MESSAGE_LENGTH - 1))
|
||||
]
|
||||
first_message: Union[Message, None] = None # type: ignore[no-any-unimported]
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
sent = await client.send_message(
|
||||
entity=entity,
|
||||
message=chunk,
|
||||
parse_mode="html",
|
||||
reply_to=first_message.id if first_message else None,
|
||||
)
|
||||
|
||||
# Store the first message to chain replies
|
||||
if i == 0:
|
||||
first_message = sent
|
||||
sent_message_id = str(sent.id)
|
||||
|
||||
return (
|
||||
f"Message sent successfully ({len(chunks)} chunks, first ID: {sent_message_id}):\n{message}"
|
||||
)
|
||||
else:
|
||||
sent = await client.send_message(entity=entity, message=message, parse_mode="html")
|
||||
return f"Message sent successfully (ID: {sent.id}):\n{message}"
|
||||
|
||||
except Exception as e:
|
||||
return f"Message send failed, exception: {str(e)}"
|
||||
|
||||
Tool.__init__(
|
||||
self,
|
||||
name="telegram_send",
|
||||
description="Sends a message to a personal channel, bot channel, group, or channel.",
|
||||
func_or_tool=telegram_send_message,
|
||||
)
|
||||
|
||||
|
||||
@require_optional_import(["telethon"], "commsagent-telegram")
@export_module("autogen.tools.experimental")
class TelegramRetrieveTool(BaseTelegramTool, Tool):
    """Retrieves messages from a Telegram channel, group, or bot chat."""

    def __init__(self, *, api_id: str, api_hash: str, chat_id: str) -> None:
        """
        Initialize the TelegramRetrieveTool.

        Args:
            api_id: Telegram API ID from https://my.telegram.org/apps.
            api_hash: Telegram API hash from https://my.telegram.org/apps.
            chat_id: The ID of the chat to retrieve messages from (Channel, Group, Bot Chat ID).
        """
        # Dedicated session name keeps this tool's Telethon session separate
        # from the send tool's session file.
        BaseTelegramTool.__init__(self, api_id, api_hash, "telegram_retrieve_session")
        self._chat_id = chat_id

        async def telegram_retrieve_messages(
            # chat_id is injected from the value bound at construction time via
            # Depends(on(...)), so callers/LLMs cannot redirect the retrieval
            # to a different chat.
            chat_id: Annotated[str, Depends(on(chat_id))],
            messages_since: Annotated[
                Union[str, None],
                "Date to retrieve messages from (ISO format) OR message ID. If None, retrieves latest messages.",
            ] = None,
            maximum_messages: Annotated[
                Union[int, None], "Maximum number of messages to retrieve. If None, retrieves all messages since date."
            ] = None,
            search: Annotated[Union[str, None], "Optional string to search for in messages."] = None,
        ) -> Any:
            """
            Retrieves messages from a Telegram chat.

            Args:
                chat_id: The ID of the chat. (uses dependency injection)
                messages_since: ISO format date string OR message ID to retrieve messages from.
                maximum_messages: Maximum number of messages to retrieve.
                search: Optional string to search for in messages.

            Returns:
                On success, a dict with "message_count", "messages" (list of
                per-message dicts), and "start_time"; on an invalid
                messages_since value a dict with an "error" key; on any other
                failure a plain error string (exceptions are not propagated).
            """
            try:
                client = self._get_client()
                async with client:
                    # Initialize and cache the entity
                    entity = await self._initialize_entity(client, chat_id)

                    # Setup retrieval parameters for Telethon's iter_messages.
                    params = {
                        "entity": entity,
                        "limit": maximum_messages if maximum_messages else None,
                        "search": search if search else None,
                        "filter": InputMessagesFilterEmpty(),
                        "wait_time": None,  # No wait time between requests
                    }

                    # Handle messages_since parameter: it may be a numeric
                    # message ID or an ISO-8601 date string.
                    if messages_since:
                        try:
                            # Try to parse as message ID first
                            msg_id = int(messages_since)
                            params["min_id"] = msg_id
                        except ValueError:
                            # Not a message ID, try as ISO date ("Z" suffix is
                            # normalized to "+00:00" for fromisoformat).
                            try:
                                date = datetime.fromisoformat(messages_since.replace("Z", "+00:00"))
                                params["offset_date"] = date
                                params["reverse"] = (
                                    True  # Need this because the date gets messages before a certain date by default
                                )
                            except ValueError:
                                return {
                                    "error": "Invalid messages_since format. Please provide either a message ID or ISO format date (e.g., '2025-01-25T00:00:00Z')"
                                }

                    # Retrieve messages
                    messages = []
                    # NOTE(review): count mirrors len(messages) and is never
                    # read after the loop.
                    count = 0
                    # For bot users, we need to get both sent and received messages
                    # NOTE(review): this branch only prints; consider the
                    # logging module instead of a bare print in library code.
                    if isinstance(self._get_peer_from_id(chat_id), PeerUser):
                        print(f"Retrieving messages for bot chat {chat_id}")

                    async for message in client.iter_messages(**params):
                        count += 1
                        # Flatten the Telethon Message object into a plain,
                        # JSON-serializable dict (ids as strings, dates as ISO).
                        messages.append({
                            "id": str(message.id),
                            "date": message.date.isoformat(),
                            "from_id": str(message.from_id) if message.from_id else None,
                            "text": message.text,
                            "reply_to_msg_id": str(message.reply_to_msg_id) if message.reply_to_msg_id else None,
                            "forward_from": str(message.forward.from_id) if message.forward else None,
                            "edit_date": message.edit_date.isoformat() if message.edit_date else None,
                            "media": bool(message.media),
                            "entities": [
                                {"type": e.__class__.__name__, "offset": e.offset, "length": e.length}
                                for e in message.entities
                            ]
                            if message.entities
                            else None,
                        })

                        # Check if we've hit the maximum (defensive: "limit"
                        # was already passed to iter_messages above).
                        if maximum_messages and len(messages) >= maximum_messages:
                            break

                    return {
                        "message_count": len(messages),
                        "messages": messages,
                        "start_time": messages_since or "latest",
                    }

            except Exception as e:
                # Best-effort contract: report failures as a string rather
                # than raising, so agent loops keep running.
                return f"Message retrieval failed, exception: {str(e)}"

        Tool.__init__(
            self,
            name="telegram_retrieve",
            description="Retrieves messages from a Telegram chat based on datetime/message ID and/or number of latest messages.",
            func_or_tool=telegram_retrieve_messages,
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .perplexity_search import PerplexitySearchTool
|
||||
|
||||
__all__ = ["PerplexitySearchTool"]
|
||||
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Module: perplexity_search_tool
|
||||
Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
This module provides classes for interacting with the Perplexity AI search API.
|
||||
It defines data models for responses and a tool for executing web and conversational searches.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from autogen.tools import Tool
|
||||
|
||||
|
||||
class Message(BaseModel):
    """
    Represents a single message in the chat conversation sent to / returned by
    the Perplexity API.

    Attributes:
        role (str): The role of the message sender (e.g., "system", "user").
        content (str): The text content of the message.
    """

    role: str
    content: str
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """
    Model representing token usage details of a Perplexity API response.

    Attributes:
        prompt_tokens (int): The number of tokens used for the prompt.
        completion_tokens (int): The number of tokens generated in the completion.
        total_tokens (int): The total number of tokens (prompt + completion).
        search_context_size (str): The search context size used (e.g., "high").
    """

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    search_context_size: str
|
||||
|
||||
|
||||
class Choice(BaseModel):
    """
    Represents one choice in the response from the Perplexity API.

    Attributes:
        index (int): The index of this choice.
        finish_reason (str): The reason why the API finished generating this choice.
        message (Message): The message object containing the response text.
    """

    index: int
    finish_reason: str
    message: Message
|
||||
|
||||
|
||||
class PerplexityChatCompletionResponse(BaseModel):
    """
    Represents the full chat completion response from the Perplexity API.

    Validating the raw JSON against this model is what turns an unexpected
    response shape into an explicit error instead of a silent KeyError.

    Attributes:
        id (str): Unique identifier for the response.
        model (str): The model name used for generating the response.
        created (int): Timestamp when the response was created.
        usage (Usage): Token usage details.
        citations (list[str]): List of citation strings included in the response.
        object (str): Type of the response object.
        choices (list[Choice]): List of choices returned by the API.
    """

    id: str
    model: str
    created: int
    usage: Usage
    citations: list[str]
    object: str
    choices: list[Choice]
|
||||
|
||||
|
||||
class SearchResponse(BaseModel):
    """
    Represents the result of a search query as returned to the agent.

    Exactly one of (content, citations) or error is expected to be populated:
    on success `error` is None; on failure `content` and `citations` are None.

    Attributes:
        content (Optional[str]): The textual content returned from the search.
        citations (Optional[list[str]]): A list of citation URLs relevant to the search result.
        error (Optional[str]): An error message if the search failed.
    """

    content: Union[str, None]
    citations: Union[list[str], None]
    error: Union[str, None]
|
||||
|
||||
|
||||
class PerplexitySearchTool(Tool):
    """
    Tool for interacting with the Perplexity AI search API.

    This tool uses the Perplexity API to perform web search, news search,
    and conversational search, returning concise and precise responses.

    Attributes:
        url (str): API endpoint URL.
        model (str): Name of the model to be used.
        api_key (str): API key for authenticating with the Perplexity API.
        max_tokens (int): Maximum tokens allowed for the API response.
        search_domain_filters (Optional[list[str]]): Optional list of domain filters for the search.
    """

    # Seconds to wait for the Perplexity API before aborting the request.
    _REQUEST_TIMEOUT = 10

    def __init__(
        self,
        model: str = "sonar",
        api_key: Optional[str] = None,
        max_tokens: int = 1000,
        search_domain_filter: Optional[list[str]] = None,
    ):
        """
        Initializes a new instance of the PerplexitySearchTool.

        Args:
            model (str, optional): The model to use. Defaults to "sonar".
            api_key (Optional[str], optional): API key for authentication. Falls back to
                the PERPLEXITY_API_KEY environment variable when not provided.
            max_tokens (int, optional): Maximum number of tokens for the response. Defaults to 1000.
            search_domain_filter (Optional[list[str]], optional): list of domain filters to restrict search.

        Raises:
            ValueError: If the API key is missing, the model is empty, max_tokens is not positive,
                or if search_domain_filter is not a list when provided.
        """
        self.api_key = api_key or os.getenv("PERPLEXITY_API_KEY")
        self._validate_tool_config(model, self.api_key, max_tokens, search_domain_filter)
        self.url = "https://api.perplexity.ai/chat/completions"
        self.model = model
        self.max_tokens = max_tokens
        self.search_domain_filters = search_domain_filter
        super().__init__(
            name="perplexity-search",
            description="Perplexity AI search tool for web search, news search, and conversational search "
            "for finding answers to everyday questions, conducting in-depth research and analysis.",
            func_or_tool=self.search,
        )

    @staticmethod
    def _validate_tool_config(
        model: str, api_key: Union[str, None], max_tokens: int, search_domain_filter: Union[list[str], None]
    ) -> None:
        """
        Validates the configuration parameters for the search tool.

        Args:
            model (str): The model to use.
            api_key (Union[str, None]): The API key for authentication.
            max_tokens (int): Maximum tokens allowed.
            search_domain_filter (Union[list[str], None]): Domain filters for search.

        Raises:
            ValueError: If the API key is missing, model is empty, max_tokens is not positive,
                or search_domain_filter is not a list.
        """
        if not api_key:
            raise ValueError("Perplexity API key is missing")
        if not model:
            raise ValueError("model cannot be empty")
        if max_tokens <= 0:
            raise ValueError("max_tokens must be positive")
        if search_domain_filter is not None and not isinstance(search_domain_filter, list):
            raise ValueError("search_domain_filter must be a list")

    def _execute_query(self, payload: dict[str, Any]) -> "PerplexityChatCompletionResponse":
        """
        Executes a query by sending a POST request to the Perplexity API.

        Args:
            payload (dict[str, Any]): The payload to send in the API request.

        Returns:
            PerplexityChatCompletionResponse: Parsed response from the Perplexity API.

        Raises:
            RuntimeError: If there is a network error, timeout, HTTP error, JSON parsing
                error, or if the response cannot be parsed into a
                PerplexityChatCompletionResponse.
        """
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        # BUG FIX: the request itself must live inside the try block.
        # Timeout and connection errors are raised by requests.post, not by
        # raise_for_status(), so the original handlers could never catch them
        # (and the Timeout handler dereferenced a `response` that would not
        # exist when the request timed out).
        try:
            response = requests.post(self.url, json=payload, headers=headers, timeout=self._REQUEST_TIMEOUT)
            response.raise_for_status()
        except requests.exceptions.Timeout as e:
            raise RuntimeError(
                f"Perplexity API => Request timed out after {self._REQUEST_TIMEOUT} seconds"
            ) from e
        except requests.exceptions.HTTPError as e:
            # `response` is guaranteed bound here: raise_for_status() only
            # raises after the request has returned.
            raise RuntimeError(
                f"Perplexity API => HTTP error occurred: {response.text}. Status code: {response.status_code}"
            ) from e
        except requests.exceptions.RequestException as e:
            # Covers connection errors etc.; `response` may be unbound, so do
            # not reference it in the message.
            raise RuntimeError(f"Perplexity API => Error during request: {e}") from e

        try:
            response_json = response.json()
        except json.JSONDecodeError as e:
            raise RuntimeError(f"Perplexity API => Invalid JSON response received. Error: {e}") from e

        try:
            # This may raise a pydantic.ValidationError if the response structure is not as expected.
            perp_resp = PerplexityChatCompletionResponse(**response_json)
        except ValidationError as e:
            raise RuntimeError("Perplexity API => Validation error when parsing API response: " + str(e)) from e
        except Exception as e:
            raise RuntimeError(
                "Perplexity API => Failed to parse API response into PerplexityChatCompletionResponse: " + str(e)
            ) from e

        return perp_resp

    def search(self, query: str) -> "SearchResponse":
        """
        Perform a search query using the Perplexity AI API.

        Constructs the payload, executes the query, and parses the response to return
        a concise search result along with any provided citations.

        Args:
            query (str): The search query.

        Returns:
            SearchResponse: A model containing the search result content and citations,
                or an `error` message when anything goes wrong (no exception escapes).

        Raises:
            ValueError: If the search query is invalid.
        """
        if not query or not isinstance(query, str):
            raise ValueError("A valid non-empty query string must be provided.")

        payload = {
            "model": self.model,
            "messages": [{"role": "system", "content": "Be precise and concise."}, {"role": "user", "content": query}],
            "max_tokens": self.max_tokens,
            "search_domain_filter": self.search_domain_filters,
            "web_search_options": {"search_context_size": "high"},
        }

        try:
            perplexity_response = self._execute_query(payload)
            content = perplexity_response.choices[0].message.content
            citations = perplexity_response.citations
            return SearchResponse(content=content, citations=citations, error=None)
        except Exception as e:
            # Return a SearchResponse with an error message if something goes wrong.
            return SearchResponse(content=None, citations=None, error=f"{e}")
|
||||
@@ -0,0 +1,10 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from .reliable import ReliableTool, ReliableToolError, SuccessfulExecutionParameters, ToolExecutionDetails
|
||||
|
||||
__all__ = ["ReliableTool", "ReliableToolError", "SuccessfulExecutionParameters", "ToolExecutionDetails"]
|
||||
1316
mm_agents/coact/autogen/tools/experimental/reliable/reliable.py
Normal file
1316
mm_agents/coact/autogen/tools/experimental/reliable/reliable.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .tavily_search import TavilySearchTool
|
||||
|
||||
__all__ = ["TavilySearchTool"]
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from typing import Annotated, Any, Optional, Union
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from ... import Depends, Tool
|
||||
from ...dependency_injection import on
|
||||
|
||||
with optional_import_block():
|
||||
from tavily import TavilyClient
|
||||
|
||||
|
||||
@require_optional_import(
    [
        "tavily",
    ],
    "tavily",
)
def _execute_tavily_query(
    query: str,
    tavily_api_key: str,
    search_depth: str = "basic",
    topic: str = "general",
    include_answer: str = "basic",
    include_raw_content: bool = False,
    # BUG FIX: a mutable default argument ([]) is shared across all calls;
    # use None as the sentinel and substitute a fresh list per call.
    include_domains: Optional[list[str]] = None,
    num_results: int = 5,
) -> Any:
    """
    Execute a search query using the Tavily API.

    Args:
        query (str): The search query string.
        tavily_api_key (str): The API key for Tavily.
        search_depth (str, optional): The depth of the search ('basic' or 'advanced'). Defaults to "basic".
        topic (str, optional): The topic of the search. Defaults to "general".
        include_answer (str, optional): Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
        include_raw_content (bool, optional): Whether to include raw content in the results. Defaults to False.
        include_domains (Optional[list[str]], optional): A list of domains to include in the search.
            Defaults to None, treated as an empty list.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        Any: The raw response object from the Tavily API client.
    """
    tavily_client = TavilyClient(api_key=tavily_api_key)
    return tavily_client.search(
        query=query,
        search_depth=search_depth,
        topic=topic,
        include_answer=include_answer,
        include_raw_content=include_raw_content,
        # Preserve the original API behavior: an empty list when unspecified.
        include_domains=include_domains if include_domains is not None else [],
        max_results=num_results,
    )
|
||||
|
||||
|
||||
def _tavily_search(
    query: str,
    tavily_api_key: str,
    search_depth: str = "basic",
    topic: str = "general",
    include_answer: str = "basic",
    include_raw_content: bool = False,
    # BUG FIX: a mutable default argument ([]) is shared across all calls;
    # use None as the sentinel and substitute a fresh list per call.
    include_domains: Optional[list[str]] = None,
    num_results: int = 5,
) -> list[dict[str, Any]]:
    """
    Perform a Tavily search and format the results.

    This function takes search parameters, executes the query using `_execute_tavily_query`,
    and formats the results into a list of dictionaries containing title, link, and snippet.

    Args:
        query (str): The search query string.
        tavily_api_key (str): The API key for Tavily.
        search_depth (str, optional): The depth of the search ('basic' or 'advanced'). Defaults to "basic".
        topic (str, optional): The topic of the search. Defaults to "general".
        include_answer (str, optional): Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
        include_raw_content (bool, optional): Whether to include raw content in the results. Defaults to False.
        include_domains (Optional[list[str]], optional): A list of domains to include in the search.
            Defaults to None, treated as an empty list.
        num_results (int, optional): The maximum number of results to return. Defaults to 5.

    Returns:
        list[dict[str, Any]]: A list of dictionaries, where each dictionary represents a search result
            with keys 'title', 'link', and 'snippet'. Returns an empty list if no results are found.
    """
    res = _execute_tavily_query(
        query=query,
        tavily_api_key=tavily_api_key,
        search_depth=search_depth,
        topic=topic,
        include_answer=include_answer,
        include_raw_content=include_raw_content,
        # Resolve the sentinel here so this function does not depend on how
        # the callee handles None.
        include_domains=include_domains if include_domains is not None else [],
        num_results=num_results,
    )

    # Normalize Tavily's result items into a uniform title/link/snippet shape.
    return [
        {"title": item.get("title", ""), "link": item.get("url", ""), "snippet": item.get("content", "")}
        for item in res.get("results", [])
    ]
|
||||
|
||||
|
||||
@export_module("autogen.tools.experimental")
class TavilySearchTool(Tool):
    """
    TavilySearchTool is a tool that uses the Tavily Search API to perform a search.

    This tool allows agents to leverage the Tavily search engine for information retrieval.
    It requires a Tavily API key, which can be provided during initialization or set as
    an environment variable `TAVILY_API_KEY`.

    Attributes:
        tavily_api_key (str): The API key used for authenticating with the Tavily API.
    """

    def __init__(
        self, *, llm_config: Optional[Union[LLMConfig, dict[str, Any]]] = None, tavily_api_key: Optional[str] = None
    ):
        """
        Initializes the TavilySearchTool.

        Args:
            llm_config (Optional[Union[LLMConfig, dict[str, Any]]]): LLM configuration. (Currently unused but kept for potential future integration).
            tavily_api_key (Optional[str]): The API key for the Tavily Search API. If not provided,
                it attempts to read from the `TAVILY_API_KEY` environment variable.

        Raises:
            ValueError: If `tavily_api_key` is not provided either directly or via the environment variable.
        """
        self.tavily_api_key = tavily_api_key or os.getenv("TAVILY_API_KEY")

        if self.tavily_api_key is None:
            raise ValueError("tavily_api_key must be provided either as an argument or via TAVILY_API_KEY env var")

        def tavily_search(
            query: Annotated[str, "The search query."],
            # The key resolved above is injected so the LLM never sees or
            # supplies credentials.
            tavily_api_key: Annotated[Optional[str], Depends(on(self.tavily_api_key))],
            search_depth: Annotated[Optional[str], "Either 'advanced' or 'basic'"] = "basic",
            include_answer: Annotated[Optional[str], "Either 'advanced' or 'basic'"] = "basic",
            include_raw_content: Annotated[Optional[bool], "Include the raw contents"] = False,
            # BUG FIX: mutable default ([]) replaced with None; the body
            # already normalizes None via `include_domains or []`.
            include_domains: Annotated[Optional[list[str]], "Specific web domains to search"] = None,
            num_results: Annotated[int, "The number of results to return."] = 5,
        ) -> list[dict[str, Any]]:
            """
            Performs a search using the Tavily API and returns formatted results.

            Args:
                query: The search query string.
                tavily_api_key: The API key for Tavily (injected dependency).
                search_depth: The depth of the search ('basic' or 'advanced'). Defaults to "basic".
                include_answer: Whether to include an AI-generated answer ('basic' or 'advanced'). Defaults to "basic".
                include_raw_content: Whether to include raw content in the results. Defaults to False.
                include_domains: A list of domains to include in the search. Defaults to None (no restriction).
                num_results: The maximum number of results to return. Defaults to 5.

            Returns:
                A list of dictionaries, each containing 'title', 'link', and 'snippet' of a search result.

            Raises:
                ValueError: If the Tavily API key is not available.
            """
            if tavily_api_key is None:
                raise ValueError("Tavily API key is missing.")
            # The `or` fallbacks guard against explicit None arguments from
            # the LLM, since every parameter is declared Optional.
            return _tavily_search(
                query=query,
                tavily_api_key=tavily_api_key,
                search_depth=search_depth or "basic",
                include_answer=include_answer or "basic",
                include_raw_content=include_raw_content or False,
                include_domains=include_domains or [],
                num_results=num_results,
            )

        super().__init__(
            name="tavily_search",
            description="Use the Tavily Search API to perform a search.",
            func_or_tool=tavily_search,
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .web_search_preview import WebSearchPreviewTool
|
||||
|
||||
__all__ = ["WebSearchPreviewTool"]
|
||||
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Annotated, Any, Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ....doc_utils import export_module
|
||||
from ....import_utils import optional_import_block, require_optional_import
|
||||
from ....llm_config import LLMConfig
|
||||
from ... import Tool
|
||||
|
||||
with optional_import_block():
|
||||
from openai import OpenAI
|
||||
from openai.types.responses import WebSearchToolParam
|
||||
from openai.types.responses.web_search_tool import UserLocation
|
||||
|
||||
|
||||
@require_optional_import("openai>=1.66.2", "openai")
@export_module("autogen.tools.experimental")
class WebSearchPreviewTool(Tool):
    """WebSearchPreviewTool is a tool that uses OpenAI's web_search_preview tool to perform a search."""

    def __init__(
        self,
        *,
        llm_config: Union[LLMConfig, dict[str, Any]],
        search_context_size: Literal["low", "medium", "high"] = "medium",
        user_location: Optional[dict[str, str]] = None,
        instructions: Optional[str] = None,
        text_format: Optional[Type[BaseModel]] = None,
    ):
        """Initialize the WebSearchPreviewTool.

        Args:
            llm_config: The LLM configuration to use. This should be a dictionary
                containing the model name and other parameters.
            search_context_size: The size of the search context. One of `low`, `medium`, or `high`.
                `medium` is the default.
            user_location: The location of the user. This should be a dictionary containing
                the city, country, region, and timezone.
            instructions: Inserts a system (or developer) message as the first item in the model's context.
            text_format: The format of the text to be returned. This should be a subclass of `BaseModel`.
                The default is `None`, which means the text will be returned as a string.

        Raises:
            ValueError: If llm_config has no 'config_list', or no OpenAI
                'gpt-4*' model is present in it.
        """
        self.web_search_tool_param = WebSearchToolParam(
            type="web_search_preview",
            search_context_size=search_context_size,
            user_location=UserLocation(**user_location) if user_location else None,  # type: ignore[typeddict-item]
        )
        self.instructions = instructions
        self.text_format = text_format

        if isinstance(llm_config, LLMConfig):
            llm_config = llm_config.model_dump()

        # Deep-copy so the mutations below never leak into the caller's config.
        llm_config = copy.deepcopy(llm_config)

        if "config_list" not in llm_config:
            raise ValueError("llm_config must contain 'config_list' key")

        # Find first OpenAI model which starts with "gpt-4"
        self.model = None
        self.api_key = None
        for model in llm_config["config_list"]:
            if model["model"].startswith("gpt-4") and model.get("api_type", "openai") == "openai":
                self.model = model["model"]
                self.api_key = model.get("api_key", os.getenv("OPENAI_API_KEY"))
                break
        if self.model is None:
            raise ValueError(
                "No OpenAI model starting with 'gpt-4' found in llm_config, other models do not support web_search_preview"
            )

        if not self.model.startswith("gpt-4.1") and not self.model.startswith("gpt-4o-search-preview"):
            logging.warning(
                f"We recommend using a model starting with 'gpt-4.1' or 'gpt-4o-search-preview' for web_search_preview, but found {self.model}. "
                "This may result in suboptimal performance."
            )

        def web_search_preview(
            query: Annotated[str, "The search query. Add all relevant context to the query."],
        ) -> Union[str, Optional[BaseModel]]:
            """Run one web_search_preview request and return the text or parsed model.

            Args:
                query: The search query, including all relevant context.

            Returns:
                The raw output text when no `text_format` was configured,
                otherwise the response parsed into that `text_format` model.
            """
            # BUG FIX: pass the api_key selected from llm_config above.
            # A bare OpenAI() silently falls back to the OPENAI_API_KEY env
            # var and ignores the key the user configured.
            client = OpenAI(api_key=self.api_key)

            if not self.text_format:
                response = client.responses.create(
                    model=self.model,  # type: ignore[arg-type]
                    tools=[self.web_search_tool_param],
                    input=query,
                    instructions=self.instructions,
                )
                return response.output_text

            else:
                response = client.responses.parse(
                    model=self.model,  # type: ignore[arg-type]
                    tools=[self.web_search_tool_param],
                    input=query,
                    instructions=self.instructions,
                    text_format=self.text_format,
                )
                return response.output_parsed

        super().__init__(
            name="web_search_preview",
            description="Tool used to perform a web search. It can be used as google search or directly searching a specific website.",
            func_or_tool=web_search_preview,
        )
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .wikipedia import WikipediaPageLoadTool, WikipediaQueryRunTool
|
||||
|
||||
__all__ = ["WikipediaPageLoadTool", "WikipediaQueryRunTool"]
|
||||
@@ -0,0 +1,287 @@
|
||||
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from autogen.import_utils import optional_import_block, require_optional_import
|
||||
from autogen.tools import Tool
|
||||
|
||||
with optional_import_block():
|
||||
import wikipediaapi
|
||||
|
||||
# Maximum allowed length for a query string.
|
||||
MAX_QUERY_LENGTH = 300
|
||||
# Maximum number of pages to retrieve from a search.
|
||||
MAX_PAGE_RETRIEVE = 100
|
||||
# Maximum number of characters to return from a Wikipedia page.
|
||||
MAX_ARTICLE_LENGTH = 10000
|
||||
|
||||
|
||||
class Document(BaseModel):
    """Pydantic model representing a Wikipedia document.

    Attributes:
        page_content (str): Textual content of the Wikipedia page
            (possibly truncated to MAX_ARTICLE_LENGTH).
        metadata (dict[str, str]): Additional info, including:
            - source URL
            - title
            - pageid
            - timestamp
            - word count
            - size
    """

    # Body text of the page; producers may truncate it to MAX_ARTICLE_LENGTH.
    page_content: str
    # All values are stored as strings, even numeric ones like pageid/size.
    metadata: dict[str, str]
|
||||
|
||||
|
||||
class WikipediaClient:
|
||||
"""Client for interacting with the Wikipedia API.
|
||||
|
||||
Supports searching and page retrieval on a specified language edition.
|
||||
|
||||
Public methods:
|
||||
search(query: str, limit: int) -> list[dict[str, Any]]
|
||||
get_page(title: str) -> Optional[wikipediaapi.WikipediaPage]
|
||||
|
||||
Attributes:
|
||||
base_url (str): URL of the MediaWiki API endpoint.
|
||||
headers (dict[str, str]): HTTP headers, including User-Agent.
|
||||
wiki (wikipediaapi.Wikipedia): Low-level Wikipedia API client.
|
||||
"""
|
||||
|
||||
def __init__(self, language: str = "en", tool_name: str = "wikipedia-client") -> None:
|
||||
"""Initialize the WikipediaClient.
|
||||
|
||||
Args:
|
||||
language (str): ISO code of the Wikipedia edition (e.g., 'en', 'es').
|
||||
tool_name (str): Identifier for User-Agent header.
|
||||
"""
|
||||
self.base_url = f"https://{language}.wikipedia.org/w/api.php"
|
||||
self.headers = {"User-Agent": f"autogen.Agent ({tool_name})"}
|
||||
self.wiki = wikipediaapi.Wikipedia(
|
||||
language=language,
|
||||
extract_format=wikipediaapi.ExtractFormat.WIKI,
|
||||
user_agent=f"autogen.Agent ({tool_name})",
|
||||
)
|
||||
|
||||
def search(self, query: str, limit: int = 3) -> Any:
|
||||
"""Search Wikipedia for pages matching a query string.
|
||||
|
||||
Args:
|
||||
query (str): The search keywords.
|
||||
limit (int): Max number of results to return.
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]: Each dict has keys:
|
||||
- 'title' (str)
|
||||
- 'size' (int)
|
||||
- 'wordcount' (int)
|
||||
- 'timestamp' (str)
|
||||
|
||||
Raises:
|
||||
requests.HTTPError: If the HTTP request to the API fails.
|
||||
"""
|
||||
params = {
|
||||
"action": "query",
|
||||
"format": "json",
|
||||
"list": "search",
|
||||
"srsearch": query,
|
||||
"srlimit": str(limit),
|
||||
"srprop": "size|wordcount|timestamp",
|
||||
}
|
||||
|
||||
response = requests.get(url=self.base_url, params=params, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
search_data = data.get("query", {}).get("search", [])
|
||||
return search_data
|
||||
|
||||
def get_page(self, title: str) -> Optional[Any]:
    """Retrieve a WikipediaPage object by title.

    Args:
        title (str): Title of the Wikipedia page.

    Returns:
        wikipediaapi.WikipediaPage | None: The page object if it exists,
            otherwise None.

    Raises:
        wikipediaapi.WikipediaException: On lower-level API errors.
    """
    candidate = self.wiki.page(title)
    # Missing articles come back as non-existent page objects, not None.
    return candidate if candidate.exists() else None
|
||||
|
||||
|
||||
@require_optional_import(["wikipediaapi"], "wikipedia")
class WikipediaQueryRunTool(Tool):
    """Tool that searches Wikipedia and returns page summaries.

    Backed by the `wikipediaapi` package, it queries one language edition
    of Wikipedia and yields at most `top_k` summaries per call.

    Public methods:
        query_run(query: str) -> list[str] | str

    Attributes:
        language (str): Language code for the Wikipedia edition (e.g., 'en', 'es').
        top_k (int): Max number of page summaries returned (capped at MAX_PAGE_RETRIEVE).
        verbose (bool): If True, prints debug information to stdout.
        wiki_cli (WikipediaClient): Internal client for Wikipedia API calls.
    """

    def __init__(self, language: str = "en", top_k: int = 3, verbose: bool = False) -> None:
        """Initialize the WikipediaQueryRunTool.

        Args:
            language (str): ISO code of the Wikipedia edition to query.
            top_k (int): Desired number of summaries (capped by MAX_PAGE_RETRIEVE).
            verbose (bool): If True, print debug information during searches.
        """
        self.tool_name = "wikipedia-query-run"
        self.language = language
        self.top_k = min(top_k, MAX_PAGE_RETRIEVE)
        self.verbose = verbose
        self.wiki_cli = WikipediaClient(language, self.tool_name)
        super().__init__(
            name=self.tool_name,
            description="Run a Wikipedia query and return page summaries.",
            func_or_tool=self.query_run,
        )

    def query_run(self, query: str) -> Union[list[str], str]:
        """Search Wikipedia and return formatted page summaries.

        The query is truncated to MAX_QUERY_LENGTH before searching.

        Args:
            query (str): Search term(s) to look up in Wikipedia.

        Returns:
            list[str]: Each element is "Page: <title>\nSummary: <text>".
            str: Error message if no results are found or on exception.

        Note:
            API exceptions are caught and returned as error strings so the
            calling agent never sees a raised exception.
        """
        try:
            trimmed = query[:MAX_QUERY_LENGTH]
            if self.verbose:
                print(f"INFO\t [{self.tool_name}] search query='{trimmed}' top_k={self.top_k}")
            hits = self.wiki_cli.search(trimmed, limit=self.top_k)
            results: list[str] = []
            for hit in hits:
                name = hit["title"]
                page = self.wiki_cli.get_page(name)
                # Skip pages that no longer exist or carry no summary text.
                if page is None or not page.summary:
                    continue
                results.append(f"Page: {name}\nSummary: {page.summary}")
            return results if results else "No good Wikipedia Search Result was found"
        except Exception as e:
            return f"wikipedia search failed: {str(e)}"
|
||||
|
||||
|
||||
@require_optional_import(["wikipediaapi"], "wikipedia")
class WikipediaPageLoadTool(Tool):
    """Tool that loads truncated Wikipedia page content plus metadata.

    Uses a language-specific Wikipedia client to find relevant articles and
    returns Document objects holding up to `truncate` characters of page
    text together with metadata (source URL, title, page ID, timestamp,
    word count, and size). Suited to agents that need structured Wikipedia
    data for research, summarization, or contextual enrichment.

    Attributes:
        language (str): Wikipedia language code (default: "en").
        top_k (int): Maximum number of pages to retrieve per query (default: 3).
        truncate (int): Maximum number of characters of content per page (default: 4000).
        verbose (bool): If True, prints debug information (default: False).
        tool_name (str): Identifier used in the User-Agent header.
        wiki_cli (WikipediaClient): Client for interacting with the Wikipedia API.
    """

    def __init__(self, language: str = "en", top_k: int = 3, truncate: int = 4000, verbose: bool = False) -> None:
        """Initialize the WikipediaPageLoadTool.

        Args:
            language (str): The language code for the Wikipedia edition (default "en").
            top_k (int): Max pages to retrieve per query (default 3; capped at
                MAX_PAGE_RETRIEVE).
            truncate (int): Max characters extracted from each page (default 4000;
                capped at MAX_ARTICLE_LENGTH).
            verbose (bool): If True, enables verbose/debug logging (default False).
        """
        self.tool_name = "wikipedia-page-load"
        self.language = language
        self.top_k = min(top_k, MAX_PAGE_RETRIEVE)
        self.truncate = min(truncate, MAX_ARTICLE_LENGTH)
        self.verbose = verbose
        self.wiki_cli = WikipediaClient(language, self.tool_name)
        super().__init__(
            name=self.tool_name,
            description=(
                "Search Wikipedia for relevant pages using a language-specific client. "
                "Returns a list of documents with truncated content and metadata including title, URL, "
                "page ID, timestamp, word count, and page size. Configure number of results with the 'top_k' parameter "
                "and content length with 'truncate'. Useful for research, summarization, or contextual enrichment."
            ),
            func_or_tool=self.content_search,
        )

    def content_search(self, query: str) -> Union[list[Document], str]:
        """Run a Wikipedia search and return page content plus metadata.

        Args:
            query (str): The search term to query Wikipedia.

        Returns:
            Union[list[Document], str]:
                - list[Document]: Documents with up to `truncate` characters of
                  page text and metadata, when pages are found.
                - str: Error message if the search fails or no pages are found.

        Notes:
            - Errors are caught internally and returned as strings.
            - If no matching pages have text content, returns
              "No good Wikipedia Search Result was found".
        """
        try:
            trimmed = query[:MAX_QUERY_LENGTH]
            if self.verbose:
                print(f"INFO\t [{self.tool_name}] search query='{trimmed}' top_k={self.top_k}")
            hits = self.wiki_cli.search(trimmed, limit=self.top_k)
            documents: list[Document] = []
            for hit in hits:
                page = self.wiki_cli.get_page(hit["title"])
                # Only keep pages that exist and actually have body text.
                if page is None or not page.text:
                    continue
                metadata = {
                    "source": f"https://{self.language}.wikipedia.org/?curid={hit['pageid']}",
                    "title": hit["title"],
                    "pageid": str(hit["pageid"]),
                    "timestamp": str(hit["timestamp"]),
                    "wordcount": str(hit["wordcount"]),
                    "size": str(hit["size"]),
                }
                documents.append(Document(page_content=page.text[: self.truncate], metadata=metadata))
            return documents if documents else "No good Wikipedia Search Result was found"
        except Exception as e:
            return f"wikipedia search failed: {str(e)}"
|
||||
Reference in New Issue
Block a user